chatglm_openapi.py
import requests

from typing import Any, Dict, List

from langchain_openai import OpenAI


class ChatGLMSerLLM(OpenAI):
    """LLM wrapper for a ChatGLM model served behind an OpenAI-compatible API."""

    def get_token_ids(self, text: str) -> List[int]:
        """Return a single-element list holding the token count for ``text``.

        For ChatGLM models the count is requested from the serving API's
        ``/num_tokens`` endpoint; for other models the character length is
        used as a rough fallback.
        """
        if "chatglm" in self.model_name:
            # Ask the ChatGLM server for the token count over HTTP.
            url = f"{self.openai_api_base}/num_tokens"
            query = {"prompt": text, "model": self.model_name}
            api_key = self.openai_api_key
            # Depending on the langchain_openai version the key may be stored
            # as a pydantic SecretStr, so unwrap it before building the header.
            if hasattr(api_key, "get_secret_value"):
                api_key = api_key.get_secret_value()
            _headers = {"Content-Type": "application/json",
                        "Authorization": "chatglm " + api_key}
            resp = self._post(url=url, query=query, headers=_headers)
            if resp.status_code == 200:
                resp_json = resp.json()
                # The count comes back as a string in choices[0].text.
                predictions = resp_json['choices'][0]['text']
                return [int(predictions)]
        return [len(text)]

    @classmethod
    def _post(cls, url: str,
              query: Dict, headers: Dict) -> Any:
        """Send a JSON POST request and return the raw ``requests`` response."""
        _headers = {"Content-Type": "application/json"}
        _headers.update(headers)
        with requests.Session() as sess:
            resp = sess.post(url,
                             json=query,
                             headers=_headers,
                             timeout=300)
        return resp
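
# A minimal usage sketch: the model name, server URL and API key below are
# placeholder assumptions, not values taken from this repository. With a
# "chatglm" model name, get_token_ids queries the server's /num_tokens
# endpoint (so the service must be reachable); any other model name falls
# back to the character length.
if __name__ == "__main__":
    llm = ChatGLMSerLLM(
        model_name="chatglm3-6b",                    # hypothetical model name
        openai_api_base="http://127.0.0.1:8000/v1",  # hypothetical server URL
        openai_api_key="EMPTY",                      # placeholder key
    )
    print(llm.get_token_ids("hello world"))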