ernie_with_sdk.py
import os

from typing import Dict, Optional, List, Any

from langchain_core.language_models import LLM
from langchain_community.cache import InMemoryCache
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun

import qianfan
from qianfan import ChatCompletion


# Enable the LLM cache
# langchain.llm_cache = InMemoryCache()
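# Recent LangChain versions set the cache through a helper instead of the
# module attribute (a sketch, assuming langchain_core is installed):
# from langchain_core.globals import set_llm_cache
# set_llm_cache(InMemoryCache())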

class ChatERNIESerLLM(LLM):
    """LangChain LLM wrapper around Baidu Qianfan's ERNIE-Bot chat API."""

    # Qianfan chat-completion client; must be supplied by the caller.
    chat_completion: Optional[ChatCompletion] = None

    # Model service URL (unused; kept for reference)
    # url: str = "http://127.0.0.1:8000"
    chat_history: List = []
    out_stream: bool = False
    cache: bool = False
    model_name: str = "ERNIE-Bot"

    # The qianfan SDK reads credentials from the QIANFAN_AK / QIANFAN_SK
    # environment variables; avoid hard-coding ak/sk in source control.
    # def __init__(self):
    #     self.chat_completion = qianfan.ChatCompletion(ak=os.getenv("QIANFAN_AK"), sk=os.getenv("QIANFAN_SK"))

    @property
    def _llm_type(self) -> str:
        return self.model_name

    def get_num_tokens(self, text: str) -> int:
        # Rough approximation: count one token per character.
        return len(text)

    @staticmethod
    def convert_data(data):
        """Convert (question, answer) pairs into [{'q': ..., 'a': ...}, ...]."""
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        """Send a single-turn, non-streaming chat request and return the reply."""
        resp = self.chat_completion.do(model=self.model_name, messages=[{
            "role": "user",
            "content": prompt
        }])
        assert resp.code == 200
        return resp.body["result"]

    async def _post_stream(self,
                           query: Dict,
                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                           stream=False) -> None:
        """POST the request and forward streamed chunks to the callbacks."""
        async for r in await self.chat_completion.ado(model=self.model_name, messages=[query], stream=stream):
            assert r.code == 200
            if run_manager:
                # Emit each streamed chunk as a new token through the async
                # callback manager (which also drives any sync handlers).
                await run_manager.on_llm_new_token(r.body["result"])

    async def _acall(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """Async call: stream the response, emitting tokens via callbacks."""
        await self._post_stream(query={
            "role": "user",
            "content": prompt
        }, stream=True, run_manager=run_manager)
        # The full text is delivered through the streaming callbacks above.
        return ''
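

# Minimal usage sketch, assuming QIANFAN_AK / QIANFAN_SK are exported in the
# environment (the qianfan SDK reads them by default); prompts are illustrative.
if __name__ == "__main__":
    llm = ChatERNIESerLLM(chat_completion=qianfan.ChatCompletion())
    # Blocking single-turn call, routed through _call().
    print(llm.invoke("Hello, please introduce yourself."))
    # Streaming variant, routed through _acall(); tokens arrive via callbacks:
    # import asyncio
    # from langchain_core.callbacks import StreamingStdOutCallbackHandler
    # asyncio.run(llm.ainvoke("Hello", config={"callbacks": [StreamingStdOutCallbackHandler()]}))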