from typing import Dict, Optional, List, Any
from langchain.llms.base import LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
import qianfan
from qianfan import ChatCompletion


# Enable the LLM response cache (left disabled by default)
# langchain.llm_cache = InMemoryCache()

class ChatERNIESerLLM(LLM):
    # Qianfan ChatCompletion client for the model service
    chat_completion: ChatCompletion = None

    # url: str = "http://127.0.0.1:8000"
    chat_history: List[Any] = []
    out_stream: bool = False
    cache: bool = False
    model_name: str = "ERNIE-Bot"

    # def __init__(self):
    #     self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")

    @property
    def _llm_type(self) -> str:
        return self.model_name

    def get_num_tokens(self, text: str) -> int:
        # Rough approximation: treat every character as one token.
        return len(text)

    @staticmethod
    def convert_data(data):
        """Convert (question, answer) pairs into the {'q': ..., 'a': ...} history format,
        e.g. [("hi", "hello")] -> [{'q': 'hi', 'a': 'hello'}]."""
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              stream: bool = False,
              **kwargs: Any) -> str:
        """Blocking call: send the prompt as a single user message and return the generated text."""
        resp = self.chat_completion.do(model=self.model_name, messages=[{
            "role": "user",
            "content": prompt
        }])
        print(resp)
        assert resp.code == 200
        return resp.body["result"]

    async def _post_stream(self,
                           query: Dict,
                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                           stream: bool = False) -> str:
        """Stream the request to the Qianfan service, forwarding each chunk to the
        callback manager and returning the concatenated result."""
        result = ""
        async for r in await self.chat_completion.ado(model=self.model_name, messages=[query], stream=stream):
            assert r.code == 200
            chunk = r.body["result"]
            result += chunk
            if run_manager:
                await run_manager.on_llm_new_token(chunk)
        return result

    async def _acall(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        # Stream the response; the accumulated text is returned once streaming finishes.
        return await self._post_stream(query={
            "role": "user",
            "content": prompt
        }, stream=True, run_manager=run_manager)
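

# Minimal usage sketch (illustrative, not part of the original module). It assumes
# Qianfan credentials are supplied via environment variables (e.g. QIANFAN_AK /
# QIANFAN_SK) so that ChatCompletion() can authenticate without hard-coded keys.
if __name__ == "__main__":
    import asyncio
    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

    llm = ChatERNIESerLLM(chat_completion=ChatCompletion())

    # Blocking call: the full answer is returned once generation finishes.
    print(llm("Introduce ERNIE Bot in one sentence."))

    # Streaming call: tokens are pushed to the callback handler as they arrive.
    asyncio.run(llm.agenerate(["Introduce ERNIE Bot in one sentence."],
                              callbacks=[StreamingStdOutCallbackHandler()]))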