import logging
import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain_core.language_models import BaseLLM, LLM
from langchain_openai import OpenAI
from langchain_community.cache import InMemoryCache
from langchain.callbacks.manager import Callbacks
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun

logger = logging.getLogger(__name__)


class ChatGLMSerLLM(OpenAI):
    """OpenAI-compatible LLM wrapper that asks a ChatGLM server for token counts.

    For model names containing ``"chatglm"`` the token count is fetched by
    POSTing ``{"prompt": text, "model": model_name}`` to
    ``{openai_api_base}/num_tokens``. For any other model, or when that
    request does not return HTTP 200, the character length of the text is
    used as a rough stand-in.
    """

    def _count_tokens(self, text: str) -> int:
        """Return the token count of *text* (server count or character length)."""
        if "chatglm" in self.model_name:
            # Strip a trailing slash so "http://host/v1/" does not yield "//num_tokens".
            url = f"{str(self.openai_api_base).rstrip('/')}/num_tokens"
            query = {"prompt": text, "model": self.model_name}
            api_key = self.openai_api_key
            # langchain_openai may hold the key as a pydantic SecretStr, which
            # cannot be concatenated with str; unwrap it (plain str / None also accepted).
            if hasattr(api_key, "get_secret_value"):
                api_key = api_key.get_secret_value()
            headers = {
                "Content-Type": "application/json",
                "Authorization": "chatglm " + (api_key or ""),
            }
            resp = self._post(url=url, query=query, headers=headers)
            if resp.status_code == 200:
                resp_json = resp.json()
                logger.debug("num_tokens response: %s", resp_json)
                # The count is carried as text in choices[0].text; convert it to int.
                return int(resp_json["choices"][0]["text"])
            logger.warning(
                "num_tokens request to %s failed with HTTP %s; "
                "falling back to character count",
                url,
                resp.status_code,
            )
        return len(text)

    def get_token_ids(self, text: str) -> List[int]:
        """Return a single-element list holding the token count of *text*.

        NOTE: these are not real token ids; the list only carries the count
        (``result[0]``). Use ``get_num_tokens`` to obtain the count directly.
        """
        return [self._count_tokens(text)]

    def get_num_tokens(self, text: str) -> int:
        """Return the token count of *text*.

        Overridden because the base-class default computes
        ``len(self.get_token_ids(text))``, which would always be 1 here
        since ``get_token_ids`` returns a single-element list.
        """
        return self._count_tokens(text)

    @classmethod
    def _post(cls, url: str, query: Dict, headers: Dict, timeout: float = 300) -> Any:
        """Send *query* as a JSON POST body to *url* and return the response.

        Args:
            url: Endpoint to call.
            query: JSON-serialisable request body.
            headers: Extra headers; entries override the default Content-Type.
            timeout: Request timeout in seconds.

        Returns:
            The ``requests.Response``; the status code is not checked here.
        """
        merged_headers = {"Content-Type": "application/json"}
        merged_headers.update(headers)
        with requests.Session() as sess:
            return sess.post(url, json=query, headers=merged_headers, timeout=timeout)