# chatglm.py
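"""LangChain LLM wrappers for ChatGLM.

ChatGLMLocLLM loads the model locally through transformers (optionally applying
a P-tuning v2 prefix checkpoint); ChatGLMSerLLM calls a remote ChatGLM HTTP
service and supports both blocking and streamed responses.
"""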
import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
4 5 6
from pydantic import root_validator

import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

import langchain
from langchain_core.language_models import BaseLLM, LLM
from langchain_community.cache import InMemoryCache
from langchain.callbacks.manager import Callbacks
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
import aiohttp
import asyncio


# Enable the LLM response cache (currently disabled)
# langchain.llm_cache = InMemoryCache()

class ChatGLMLocLLM(LLM):
    model_name: str = "THUDM/chatglm-6b"
    ptuning_checkpoint: Optional[str] = None
    quantization_bit: Optional[int] = None
    pre_seq_len: Optional[int] = None
    prefix_projection: bool = False

    tokenizer: Any = None
    model: Any = None

    @property
    def _llm_type(self) -> str:
        return "chatglm_local"

    # Loads the tokenizer and model when the object is constructed.
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")

        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        # model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        if values["pre_seq_len"]:
            config.pre_seq_len = values["pre_seq_len"]
        if values["prefix_projection"]:
            config.prefix_projection = values["prefix_projection"]
        if values["ptuning_checkpoint"]:
            ptuning_checkpoint = values["ptuning_checkpoint"]
            print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
            model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
            prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
            new_prefix_state_dict = {}
            for k, v in prefix_state_dict.items():
                if k.startswith("transformer.prefix_encoder."):
                    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v

            model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        else:
            model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True).half().cuda()

        if values["pre_seq_len"]:
            # P-tuning v2
            model = model.half().cuda()
            model.transformer.prefix_encoder.float().cuda()

        if values["quantization_bit"]:
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"])

        model = model.eval()

        values["tokenizer"] = tokenizer
        values["model"] = model
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
        resp, _ = self.model.chat(self.tokenizer, prompt)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp


class ChatGLMSerLLM(LLM):
    # Base URL of the ChatGLM model service
    url: str = "http://127.0.0.1:8000"
    chat_history: List = []
    out_stream: bool = False
    cache: bool = False

    @property
    def _llm_type(self) -> str:
        return "chatglm3-6b"

    def get_num_tokens(self, text: str) -> int:
        resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
        if resp.status_code == 200:
            resp_json = resp.json()
            predictions = resp_json['response']
            # display(self.convert_data(resp_json['history']))
            return predictions
        else:
            return len(text)

    @staticmethod
    def convert_data(data):
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
        """构造请求体
        """
        # self.chat_history.append({"role": "user", "content": prompt})
        query = {
            "prompt": prompt,
            "history": self.chat_history,
            "max_length": 4096,
            "top_p": 0.7,
            "temperature": temperature
        }
        return query
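
    # The service behind `url` is assumed to accept the payload built above and
    # to reply with JSON of the form {"response": ..., "history": [...]}, which
    # is what _call and get_num_tokens read back; the exact schema depends on
    # the deployed ChatGLM API server.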

    @classmethod
    def _post(cls, url: str,
              query: Dict) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}
        with requests.session() as sess:
            resp = sess.post(url,
                             json=query,
                             headers=_headers,
                             timeout=300)
        return resp

    @staticmethod
    async def _post_stream(url: str,
                           query: Dict,
                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, stream=False) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}

        async with aiohttp.ClientSession() as sess:
            async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
                if response.status == 200:
                    if stream and not run_manager:
                        print('stream=True but no run_manager is attached; streamed chunks will be dropped')
                    if run_manager:
                        for _callable in run_manager.get_sync().handlers:
                            await _callable.on_llm_start(None, None)
                    async for chunk in response.content.iter_any():
                        # Forward each streamed chunk to the callback handlers
                        if chunk and run_manager:
                            for _callable in run_manager.get_sync().handlers:
                                # print(chunk.decode("utf-8"),end="")
                                await _callable.on_llm_new_token(chunk.decode("utf-8"))
                    if run_manager:
                        for _callable in run_manager.get_sync().handlers:
                            await _callable.on_llm_end(None)
                else:
                    raise ValueError(f'ChatGLM request failed, HTTP status code: {response.status}')

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs.get("temperature", 0.95))
        # display("==============================")
        # display(query)
        # post
        if stream or self.out_stream:
            async def _post_stream():
                await self._post_stream(url=self.url + "/stream",
                                        query=query, run_manager=run_manager, stream=stream or self.out_stream)

            asyncio.run(_post_stream())
            return ''
        else:
            resp = self._post(url=self.url,
                              query=query)

            if resp.status_code == 200:
                resp_json = resp.json()
                # self.chat_history.append({'q': prompt, 'a': resp_json['response']})
                predictions = resp_json['response']
                # display(self.convert_data(resp_json['history']))
                return predictions
            else:
                raise ValueError(f'ChatGLM request failed, HTTP status code: {resp.status_code}')

    async def _acall(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs.get("temperature", 0.95))
        await self._post_stream(url=self.url + "/stream",
                                query=query, run_manager=run_manager, stream=self.out_stream)
        return ''

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.
        """
        _param_dict = {
            "url": self.url
        }
        return _param_dict
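

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): the
# model name, service URL, and prompt below are assumptions. ChatGLMSerLLM
# expects a ChatGLM HTTP server reachable at `url`, with "/stream" and
# "/tokens" endpoints alongside the blocking root endpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Remote service wrapper: one blocking request against a local test server.
    remote_llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
    print(remote_llm.invoke("Hello, please introduce yourself in one sentence."))

    # Local wrapper: loads the weights through transformers and requires a GPU.
    # local_llm = ChatGLMLocLLM(model_name="THUDM/chatglm-6b")
    # print(local_llm.invoke("Hello"))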