# loader.py
import os
import transformers
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq
from peft import PeftModel


class ModelLoader:
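    """Wraps loading of a Hugging Face model and tokenizer (with trust_remote_code),
    plus optional prefix-tuning configuration, quantization, LoRA adapters and
    P-Tuning v2 prefix weights."""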
    def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
        self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
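        # A positive pre_seq_len enables the prefix encoder (P-Tuning v2) on models that support it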
        if pre_seq_len is not None and pre_seq_len > 0:
            self.config.pre_seq_len = pre_seq_len
            self.config.prefix_projection = prefix_projection
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name_or_path, config=self.config, trust_remote_code=True).half()
        # self.model = self.model.cuda()
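        # Keep a reference to the unwrapped model so LoRA adapters can later be attached to it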
        self.base_model = self.model

    def quantize(self, quantization_bit):
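        # Relies on the model exposing a quantize() method (e.g. ChatGLM-style remote code);
        # other checkpoints may not provide it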
        if quantization_bit is not None:
            print(f"Quantized to {quantization_bit} bit")
            self.model = self.model.quantize(quantization_bit)
        return self.model

    def models(self):
        return self.model, self.tokenizer

    def collator(self):
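        # Seq2seq collator that pads inputs and labels dynamically per batch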
        return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)

    def load_lora(self, ckpt_path, name="default"):
        # Merge the LoRA weights into the base model to save GPU memory during training
        _peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
        self.model = _peft_loaded.merge_and_unload()
        print("Loaded LoRA model successfully!")

    def load_loras(self, ckpt_paths, name="default"):
        # Attach several LoRA adapters to the base model and activate one of them.
        # ckpt_paths maps adapter names to checkpoint directories.
        if len(ckpt_paths) == 0:
            return
        peft_loaded = None
        for adapter_name, path in ckpt_paths.items():
            print(f"Load {adapter_name} from {path}")
            if peft_loaded is None:
                # The first adapter wraps the base model in a PeftModel
                peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=adapter_name)
            else:
                # Further adapters are added to the same PeftModel
                peft_loaded.load_adapter(path, adapter_name=adapter_name)
        # Activate the requested adapter, falling back to the last one loaded
        peft_loaded.set_adapter(name if name in ckpt_paths else adapter_name)
        self.model = peft_loaded

    def load_prefix(self, ckpt_path):
        # Load the saved prefix-encoder weights (P-Tuning v2 checkpoint)
        prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            # Keep only the prefix-encoder parameters, stripping the module prefix from the keys
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        # Keep the prefix encoder in fp32 while the rest of the model stays in half precision
        self.model.transformer.prefix_encoder.float()
        print("Loaded prefix model successfully!")