import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator

import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

import langchain
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import aiohttp
import asyncio


# Enable the LLM response cache (left disabled here)
# langchain.llm_cache = InMemoryCache()

class ChatGLMLocLLM(LLM):
    model_name: str = "THUDM/chatglm-6b"
    ptuning_checkpoint: Optional[str] = None
    quantization_bit: Optional[int] = None
    pre_seq_len: Optional[int] = None
    prefix_projection: bool = False

    tokenizer: AutoTokenizer = None
    model: AutoModel = None

    @property
    def _llm_type(self) -> str:
        return "chatglm_local"

    # @root_validator()  # uncomment to load the tokenizer/model automatically at construction time
    @staticmethod
    def validate_environment(values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")

        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        # model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        if values["pre_seq_len"]:
            config.pre_seq_len = values["pre_seq_len"]
        if values["prefix_projection"]:
            config.prefix_projection = values["prefix_projection"]
        if values["ptuning_checkpoint"]:
            ptuning_checkpoint = values["ptuning_checkpoint"]
            print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
            model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
            prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
            new_prefix_state_dict = {}
            for k, v in prefix_state_dict.items():
                if k.startswith("transformer.prefix_encoder."):
                    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v

            model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        else:
            model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True).half().cuda()

        if values["pre_seq_len"]:
            # P-tuning v2
            model = model.half().cuda()
            model.transformer.prefix_encoder.float().cuda()

        if values["quantization_bit"]:
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"])

        model = model.eval()

        values["tokenizer"] = tokenizer
        values["model"] = model
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
        resp, his = self.model.chat(self.tokenizer, prompt)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp

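# A minimal usage sketch for ChatGLMLocLLM (an assumption, not part of the original
# file): because the @root_validator decorator above is commented out, the tokenizer
# and model are not loaded automatically, so validate_environment would have to be
# applied by hand, e.g.:
#
#     llm = ChatGLMLocLLM(model_name="THUDM/chatglm-6b")
#     values = ChatGLMLocLLM.validate_environment(llm.dict())
#     llm.tokenizer, llm.model = values["tokenizer"], values["model"]
#     print(llm("Hello"))
#
# This assumes a CUDA device is available, since the model is moved to the GPU above.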

class ChatGLMSerLLM(LLM):
    # Model service URL
    url: str = "http://127.0.0.1:8000"
    chat_history: List = []
    out_stream: bool = False
    cache: bool = False

    @property
    def _llm_type(self) -> str:
        return "chatglm3-6b"

    def get_num_tokens(self, text: str) -> int:
        resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
        if resp.status_code == 200:
            resp_json = resp.json()
            predictions = resp_json['response']
            # display(self.convert_data(resp_json['history']))
            return predictions
        else:
            return len(text)

    @staticmethod
    def convert_data(data):
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
        """构造请求体
        """
        # self.chat_history.append({"role": "user", "content": prompt})
        query = {
            "prompt": prompt,
            "history": self.chat_history,
            "max_length": 4096,
            "top_p": 0.7,
            "temperature": temperature
        }
        return query

    @classmethod
    def _post(cls, url: str,
              query: Dict) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}
        with requests.session() as sess:
            resp = sess.post(url,
                             json=query,
                             headers=_headers,
                             timeout=300)
        return resp

    @staticmethod
    async def _post_stream(url: str,
                           query: Dict,
                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, stream=False) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}

        async with aiohttp.ClientSession() as sess:
            async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
                if response.status == 200:
                    if stream and not run_manager:
                        print('streaming requested but no run_manager callbacks are available')
                    if run_manager:
                        for _callable in run_manager.get_sync().handlers:
                            await _callable.on_llm_start(None, None)
                    async for chunk in response.content.iter_any():
                        # Process each streamed chunk
                        if chunk and run_manager:
                            for _callable in run_manager.get_sync().handlers:
                                # print(chunk.decode("utf-8"),end="")
                                await _callable.on_llm_new_token(chunk.decode("utf-8"))
                    if run_manager:
                        for _callable in run_manager.get_sync().handlers:
                            await _callable.on_llm_end(None)
                else:
                    raise ValueError(f'glm request failed, http code: {response.status}')

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs.get("temperature", 0.95))
        # display("==============================")
        # display(query)
        # post
        if stream or self.out_stream:
            async def _post_stream():
                await self._post_stream(url=self.url + "/stream",
                                        query=query, run_manager=run_manager, stream=stream or self.out_stream)

            asyncio.run(_post_stream())
            return ''
        else:
            resp = self._post(url=self.url,
                              query=query)

            if resp.status_code == 200:
                resp_json = resp.json()
                # self.chat_history.append({'q': prompt, 'a': resp_json['response']})
                predictions = resp_json['response']
                # display(self.convert_data(resp_json['history']))
                return predictions
            else:
                raise ValueError(f'glm request failed, http code: {resp.status_code}')

    async def _acall(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs.get("temperature", 0.95))
        await self._post_stream(url=self.url + "/stream",
                                query=query, run_manager=run_manager, stream=self.out_stream)
        return ''

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.
        """
        _param_dict = {
            "url": self.url
        }
        return _param_dict
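

# A minimal usage sketch for ChatGLMSerLLM (an assumption, not part of the original
# file): it presumes a ChatGLM HTTP service exposing the "/", "/tokens" and "/stream"
# endpoints used above is reachable at the configured url.
if __name__ == "__main__":
    ser_llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
    print(ser_llm.get_num_tokens("Hello"))  # token count as reported by the /tokens endpoint
    print(ser_llm("Hello, please introduce yourself."))  # blocking call via POST /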