import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer, DataCollatorForSeq2Seq
from peft import PeftModel


class ModelLoader:
    """Load a remote-code base model (e.g. ChatGLM) together with optional
    LoRA adapters or a P-Tuning v2 prefix encoder."""

    def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
        self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        # Only configure prefix tuning when a positive prefix length is requested.
        if pre_seq_len is not None and pre_seq_len > 0:
            self.config.pre_seq_len = pre_seq_len
            self.config.prefix_projection = prefix_projection
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=self.config, trust_remote_code=True
        ).half()
        # self.model = self.model.cuda()
        # Keep a handle to the unwrapped model so adapters can be attached to it later.
        self.base_model = self.model

    def quantize(self, quantization_bit):
        # Relies on the quantize() method shipped with the model's remote code
        # (e.g. ChatGLM); not every model exposes one.
        if quantization_bit is not None:
            print(f"Quantized to {quantization_bit} bit")
            self.model = self.model.quantize(quantization_bit)
        return self.model

    def models(self):
        return self.model, self.tokenizer

    def collator(self):
        return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)

    def load_lora(self, ckpt_path, name="default"):
        # Merge the adapter into the base weights to save GPU memory during training.
        peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
        self.model = peft_loaded.merge_and_unload()
        print("Loaded LoRA model successfully!")

    def load_loras(self, ckpt_paths, name="default"):
        """Attach several LoRA adapters at once and activate the one called `name`."""
        if not ckpt_paths:
            return
        peft_loaded = None
        for adapter_name, path in ckpt_paths.items():
            print(f"Loading adapter {adapter_name} from {path}")
            if peft_loaded is None:
                # The first adapter wraps the base model in a PeftModel.
                peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=adapter_name)
            else:
                # Later adapters are attached to the same wrapper.
                peft_loaded.load_adapter(path, adapter_name=adapter_name)
        peft_loaded.set_adapter(name)
        self.model = peft_loaded

    def load_prefix(self, ckpt_path):
        """Load a P-Tuning v2 prefix-encoder checkpoint into the current model."""
        prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
        # Keep only prefix-encoder weights and strip the module prefix from the keys.
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        # The prefix encoder runs in fp32 even though the base model is fp16.
        self.model.transformer.prefix_encoder.float()
        print("Loaded prefix encoder successfully!")
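
# --- Usage sketch (illustrative, not part of the loader itself) ---
# A minimal example of how the loader might be driven. The model name and the
# adapter checkpoint paths below are hypothetical placeholders; substitute your
# own. Loading the model downloads weights and needs enough memory for an fp16 copy.
if __name__ == "__main__":
    loader = ModelLoader("THUDM/chatglm2-6b")  # hypothetical base checkpoint

    # Option 1: merge a single LoRA adapter into the base weights.
    # loader.load_lora("output/lora-ckpt")  # placeholder path

    # Option 2: keep several adapters attached and pick the active one.
    # loader.load_loras({"task_a": "output/task_a",   # placeholder paths
    #                    "task_b": "output/task_b"},
    #                   name="task_a")

    # Option 3: restore a P-Tuning v2 prefix encoder (requires a positive
    # pre_seq_len at construction time, e.g. ModelLoader(..., pre_seq_len=128)).
    # loader.load_prefix("output/ptuning-ckpt")  # placeholder path

    model, tokenizer = loader.models()
    inputs = tokenizer("Hello", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(type(outputs))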