from typing import Any, Dict, List, Optional

import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import root_validator
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig


class BaichuanLLM(LLM):
    """LangChain LLM wrapper around a locally loaded Baichuan chat model."""

    model_name: str = "baichuan-inc/Baichuan-13B-Chat"
    quantization_bit: Optional[int] = None

    tokenizer: Any = None
    model: Any = None

    @property
    def _llm_type(self) -> str:
        return "baichuan_local"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Load the tokenizer and model once, when the wrapper is instantiated."""
        if not values["model_name"]:
            raise ValueError("No model name provided.")

        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_fast=False, trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            # device_map="auto",
            trust_remote_code=True,
        )
        # Baichuan ships its chat generation settings (sampling parameters, etc.)
        # alongside the weights.
        model.generation_config = GenerationConfig.from_pretrained(model_name)

        if values["quantization_bit"]:
            # Baichuan's custom modeling code provides its own quantize() helper.
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"]).cuda()
        else:
            model = model.half().cuda()
        model = model.eval()

        values["tokenizer"] = tokenizer
        values["model"] = model
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Baichuan's chat() expects a list of {"role", "content"} messages;
        # `stop` is not supported by this interface and is ignored.
        messages = [{"role": "user", "content": prompt}]
        resp = self.model.chat(self.tokenizer, messages)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp
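

# Minimal usage sketch, not part of the wrapper itself. It assumes the
# Baichuan-13B-Chat weights can be fetched from the Hugging Face Hub (or are
# cached locally) and that a CUDA GPU with enough memory is available; the
# example prompt is purely illustrative.
if __name__ == "__main__":
    # Loading happens inside validate_environment(), so construction is slow.
    llm = BaichuanLLM(quantization_bit=8)
    print(llm("Briefly introduce the Baichuan-13B-Chat model."))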