from typing import Dict, List, Optional

import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import root_validator
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig


class BaichuanLLM(LLM):
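    """LangChain LLM wrapper that loads a Baichuan chat model locally via transformers."""
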
    model_name: str = "baichuan-inc/Baichuan-13B-Chat"
    quantization_bit: Optional[int] = None

    tokenizer: AutoTokenizer = None
    model: AutoModel = None

    @property
    def _llm_type(self) -> str:
        return "baichuan_local"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Load the tokenizer and model when the wrapper is instantiated."""
        if not values["model_name"]:
            raise ValueError("No model name provided.")

        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
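        # Load the weights in float16; the model is moved to the GPU below,
        # optionally after quantization.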
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            # device_map="auto",
            trust_remote_code=True
        )
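        # Attach the default generation settings shipped with the checkpoint.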
        model.generation_config = GenerationConfig.from_pretrained(
            model_name
        )
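        # The Baichuan remote code exposes quantize(); otherwise run in plain fp16 on GPU.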
        if values["quantization_bit"]:
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"]).cuda()
        else:
            model = model.half().cuda()

        model = model.eval()

        values["tokenizer"] = tokenizer
        values["model"] = model
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
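        # Baichuan's chat() expects an OpenAI-style list of role/content messages.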
        message = [{"role": "user", "content": prompt}]
        resp = self.model.chat(self.tokenizer, message)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp
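

if __name__ == "__main__":
    # Minimal usage sketch, assuming a CUDA GPU and access to the
    # "baichuan-inc/Baichuan-13B-Chat" weights. quantization_bit=8 is an
    # illustrative value, not something the original file specifies.
    llm = BaichuanLLM(quantization_bit=8)
    print(llm("Introduce yourself in one sentence."))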