import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator

import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

import langchain
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import aiohttp
import asyncio


# Enable the LLM response cache (uncomment to turn on)
# langchain.llm_cache = InMemoryCache()

class ChatGLMLocLLM(LLM):
    model_name: str = "THUDM/chatglm-6b"
    ptuning_checkpoint: Optional[str] = None
    quantization_bit: Optional[int] = None
    pre_seq_len: Optional[int] = None
    prefix_projection: bool = False

    # populated by validate_environment
    tokenizer: Any = None
    model: Any = None

    @property
    def _llm_type(self) -> str:
        return "chatglm_local"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")

        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        # model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        if values["pre_seq_len"]:
            config.pre_seq_len = values["pre_seq_len"]
        if values["prefix_projection"]:
            config.prefix_projection = values["prefix_projection"]
        if values["ptuning_checkpoint"]:
            ptuning_checkpoint = values["ptuning_checkpoint"]
            print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
            model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
            prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
            new_prefix_state_dict = {}
            for k, v in prefix_state_dict.items():
                if k.startswith("transformer.prefix_encoder."):
                    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v

            model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        else:
            model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True).half().cuda()

        if values["pre_seq_len"]:
            # P-tuning v2
            model = model.half().cuda()
            model.transformer.prefix_encoder.float().cuda()

        if values["quantization_bit"]:
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"])

        model = model.eval()

        values["tokenizer"] = tokenizer
        values["model"] = model
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
        resp, _history = self.model.chat(self.tokenizer, prompt)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp

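# A minimal usage sketch for ChatGLMLocLLM. It assumes the THUDM/chatglm-6b
# weights can be downloaded (or are cached locally) and that a CUDA device is
# available; the p-tuning checkpoint path below is purely hypothetical.
#
#     llm = ChatGLMLocLLM(model_name="THUDM/chatglm-6b",
#                         ptuning_checkpoint="output/checkpoint-3000",  # hypothetical path
#                         pre_seq_len=128)
#     print(llm("Hello, please introduce yourself."))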

class ChatGLMSerLLM(LLM):
    # URL of the model serving endpoint
    url: str = "http://127.0.0.1:8000"
    chat_history: List = []
    out_stream: bool = False
    cache: bool = False

    @property
    def _llm_type(self) -> str:
        return "chatglm3-6b"

    def get_num_tokens(self, text: str) -> int:
        resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
        if resp.status_code == 200:
            resp_json = resp.json()
            # the service reports the token count in the "response" field
            return resp_json['response']
        # fall back to a rough character-count estimate if the request fails
        return len(text)

    @staticmethod
    def convert_data(data):
        """Convert (question, answer) history pairs into a list of dicts."""
        return [{'q': item[0], 'a': item[1]} for item in data]

    def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
        """构造请求体
        """
        # self.chat_history.append({"role": "user", "content": prompt})
        query = {
            "prompt": prompt,
            "history": self.chat_history,
            "max_length": 4096,
            "top_p": 0.7,
            "temperature": temperature
        }
        return query
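
    # For reference, the request body built for prompt="Hello" with default settings looks like:
    # {
    #     "prompt": "Hello",
    #     "history": [],
    #     "max_length": 4096,
    #     "top_p": 0.7,
    #     "temperature": 0.95
    # }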

    @classmethod
    def _post(cls, url: str,
              query: Dict) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}
        with requests.session() as sess:
            resp = sess.post(url,
                             json=query,
                             headers=_headers,
                             timeout=300)
        return resp

    @staticmethod
    async def _post_stream(url: str,
                           query: Dict,
                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, stream=False) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}

        async with aiohttp.ClientSession() as sess:
            async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
                if response.status == 200:
                    if stream and not run_manager:
                        print('streaming requested but no run_manager was provided; chunks will be discarded')
                    if run_manager:
                        for _callable in run_manager.get_sync().handlers:
                            await _callable.on_llm_start(None, None)
                    async for chunk in response.content.iter_any():
                        # forward each streamed chunk to the callback handlers
                        if chunk and run_manager:
                            for _callable in run_manager.get_sync().handlers:
                                # print(chunk.decode("utf-8"),end="")
                                await _callable.on_llm_new_token(chunk.decode("utf-8"))
                    if run_manager:
                        for _callable in run_manager.get_sync().handlers:
                            await _callable.on_llm_end(None)
                else:
                    raise ValueError(f'ChatGLM request failed, HTTP status code: {response.status}')

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs.get("temperature", 0.95))
        # display("==============================")
        # display(query)
        # post
        if stream or self.out_stream:
            async def _post_stream():
                await self._post_stream(url=self.url + "/stream",
                                        query=query, run_manager=run_manager, stream=stream or self.out_stream)

            asyncio.run(_post_stream())
            return ''
        else:
            resp = self._post(url=self.url,
                              query=query)

            if resp.status_code == 200:
                resp_json = resp.json()
                # self.chat_history.append({'q': prompt, 'a': resp_json['response']})
                predictions = resp_json['response']
                # display(self.convert_data(resp_json['history']))
                return predictions
            else:
                raise ValueError(f'ChatGLM request failed, HTTP status code: {resp.status_code}')

    async def _acall(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs.get("temperature", 0.95))
        await self._post_stream(url=self.url + "/stream",
                                query=query, run_manager=run_manager, stream=self.out_stream)
        return ''

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.
        """
        _param_dict = {
            "url": self.url
        }
        return _param_dict
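

# A minimal usage sketch for ChatGLMSerLLM, assuming a ChatGLM API server is
# running at http://127.0.0.1:8000 and exposes the "/", "/stream" and "/tokens"
# endpoints used above; adjust the URL to match your deployment.
if __name__ == "__main__":
    glm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
    # blocking, non-streaming call
    print(glm("What is LangChain?"))
    # token count as reported by the remote /tokens endpoint
    print(glm.get_num_tokens("What is LangChain?"))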