import asyncio
import os
from typing import Any, Dict, List, Mapping, Optional

import aiohttp
import requests
import torch
from pydantic import root_validator
from transformers import AutoConfig, AutoModel, AutoTokenizer

import langchain
from langchain.cache import InMemoryCache
from langchain.llms.base import LLM
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)

# Enable the LLM cache if needed
# langchain.llm_cache = InMemoryCache()


class ChatGLMLocLLM(LLM):
    """ChatGLM loaded locally via transformers (optionally with a P-tuning v2 checkpoint)."""

    model_name: str = "THUDM/chatglm-6b"
    ptuning_checkpoint: Optional[str] = None
    quantization_bit: Optional[int] = None
    pre_seq_len: Optional[int] = None
    prefix_projection: bool = False

    # Filled in by validate_environment()
    tokenizer: Any = None
    model: Any = None

    @property
    def _llm_type(self) -> str:
        return "chatglm_local"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Load the tokenizer and model when the object is constructed."""
        if not values["model_name"]:
            raise ValueError("No model name provided.")
        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        # model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        if values["pre_seq_len"]:
            config.pre_seq_len = values["pre_seq_len"]
        if values["prefix_projection"]:
            config.prefix_projection = values["prefix_projection"]
        if values["ptuning_checkpoint"]:
            # Load the prefix-encoder weights produced by P-tuning v2.
            ptuning_checkpoint = values["ptuning_checkpoint"]
            print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
            model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
            prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
            new_prefix_state_dict = {}
            for k, v in prefix_state_dict.items():
                if k.startswith("transformer.prefix_encoder."):
                    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
            model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        else:
            model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True).half().cuda()
        if values["pre_seq_len"]:
            # P-tuning v2: run the model in fp16 but keep the prefix encoder in fp32.
            model = model.half().cuda()
            model.transformer.prefix_encoder.float().cuda()
        if values["quantization_bit"]:
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"])
        model = model.eval()
        values["tokenizer"] = tokenizer
        values["model"] = model
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> str:
        resp, _history = self.model.chat(self.tokenizer, prompt)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp


class ChatGLMSerLLM(LLM):
    # URL of the model serving endpoint
    url: str = "http://127.0.0.1:8000"
    chat_history: List = []
    out_stream: bool = False
    cache: bool = False

    @property
    def _llm_type(self) -> str:
        return "chatglm3-6b"

    def get_num_tokens(self, text: str) -> int:
        resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
        if resp.status_code == 200:
            resp_json = resp.json()
            predictions = resp_json['response']
            # display(self.convert_data(resp_json['history']))
            return predictions
        # Fall back to a rough character count if the service is unavailable.
        return len(text)

    @staticmethod
    def convert_data(data):
        """Convert [(question, answer), ...] history pairs into dicts."""
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _construct_query(self, prompt: str, temperature: float = 0.95) -> Dict:
        """Build the request body."""
        # self.chat_history.append({"role": "user", "content": prompt})
        query = {
            "prompt": prompt,
            "history": self.chat_history,
            "max_length": 4096,
            "top_p": 0.7,
            "temperature": temperature
        }
        return query

    @classmethod
    def _post(cls, url: str, query: Dict) -> Any:
        """Synchronous POST request."""
        _headers = {"Content-Type": "application/json"}
        with requests.Session() as sess:
            resp = sess.post(url, json=query, headers=_headers, timeout=300)
        return resp

    @staticmethod
    async def _post_stream(url: str, query: Dict,
                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                           stream: bool = False) -> Any:
        """Asynchronous POST request that forwards streamed chunks to the callback handlers."""
        _headers = {"Content-Type": "application/json"}
        async with aiohttp.ClientSession() as sess:
            async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
                if response.status == 200:
                    if stream and not run_manager:
                        print('not callable')
                    if run_manager:
                        for _callable in run_manager.get_sync().handlers:
                            await _callable.on_llm_start(None, None)
                    async for chunk in response.content.iter_any():
                        # Forward each streamed chunk to the callbacks.
                        if chunk and run_manager:
                            for _callable in run_manager.get_sync().handlers:
                                # print(chunk.decode("utf-8"), end="")
                                await _callable.on_llm_new_token(chunk.decode("utf-8"))
                    if run_manager:
                        for _callable in run_manager.get_sync().handlers:
                            await _callable.on_llm_end(None)
                else:
                    raise ValueError(f"glm request failed, http code: {response.status}")

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream: bool = False, **kwargs: Any) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs.get("temperature", 0.95))
        # display("==============================")
        # display(query)
        if stream or self.out_stream:
            # Streaming mode: drive the async helper and emit tokens via callbacks.
            async def _stream_call():
                await self._post_stream(url=self.url + "/stream",
                                        query=query, run_manager=run_manager,
                                        stream=stream or self.out_stream)

            asyncio.run(_stream_call())
            return ''
        else:
            resp = self._post(url=self.url, query=query)
            if resp.status_code == 200:
                resp_json = resp.json()
                # self.chat_history.append({'q': prompt, 'a': resp_json['response']})
                predictions = resp_json['response']
                # display(self.convert_data(resp_json['history']))
                return predictions
            else:
                raise ValueError(f"glm request failed, http code: {resp.status_code}")

    async def _acall(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs.get("temperature", 0.95))
        await self._post_stream(url=self.url + "/stream", query=query,
                                run_manager=run_manager, stream=self.out_stream)
        return ''

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"url": self.url}
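

# A minimal usage sketch, not part of the original module: it assumes a ChatGLM
# HTTP service is already running at http://127.0.0.1:8000 and exposes the "/",
# "/tokens" and "/stream" routes used above, and that a CUDA device plus the
# THUDM/chatglm-6b weights are available for the local variant. The prompts and
# the example URL are placeholders.
if __name__ == "__main__":
    # Remote service: plain (non-streaming) completion via the LangChain LLM interface.
    ser_llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
    print(ser_llm("Hello, please introduce yourself briefly."))
    print("estimated tokens:", ser_llm.get_num_tokens("Hello, please introduce yourself briefly."))

    # Local model: loads weights at construction time, so it needs a GPU.
    # loc_llm = ChatGLMLocLLM(model_name="THUDM/chatglm-6b")
    # print(loc_llm("Hello"))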