import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator

import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

import langchain
from langchain_core.language_models import BaseLLM, LLM
from langchain_openai import OpenAI
from langchain_community.cache import InMemoryCache
from langchain.callbacks.manager import Callbacks
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun


class ChatGLMSerLLM(OpenAI):
    """OpenAI-style LLM client that asks a ChatGLM server to count tokens.

    For models whose name contains ``"chatglm"``, token counting is delegated
    to the ``/num_tokens`` endpoint under ``openai_api_base``; every other
    model (and every failed request) falls back to the character count.
    """

    def get_token_ids(self, text: str) -> List[int]:
        """Return a one-element list holding the token count of *text*.

        Despite the name, the list does not contain token ids: it contains a
        single integer, either the count reported by the server or
        ``len(text)`` as a fallback.

        NOTE(review): LangChain's default ``get_num_tokens`` appears to take
        ``len()`` of this result, which would always be 1 here -- confirm how
        callers consume the value before changing the return shape.

        Args:
            text: The prompt whose tokens should be counted.

        Returns:
            ``[count]`` from the server when the model is a ChatGLM model and
            the request succeeds with a parseable body, otherwise
            ``[len(text)]``.
        """
        fallback = [len(text)]

        if "chatglm" not in self.model_name:
            return fallback

        api_key = self.openai_api_key
        # NOTE(review): newer langchain_openai releases seem to store the key
        # as a pydantic SecretStr rather than a plain str -- unwrap it if so,
        # since "str + SecretStr" raises TypeError.
        if hasattr(api_key, "get_secret_value"):
            api_key = api_key.get_secret_value()

        # Ask the server for the token count over HTTP.
        url = f"{self.openai_api_base}/num_tokens"
        query = {"prompt": text, "model": self.model_name}
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"chatglm {api_key or ''}",
        }

        resp = self._post(url=url, query=query, headers=headers)
        if resp.status_code != 200:
            return fallback

        try:
            # The count arrives as a string in choices[0].text; convert to int.
            count = int(resp.json()["choices"][0]["text"])
        except (KeyError, IndexError, TypeError, ValueError):
            # Malformed or unexpected body: degrade the same way a non-200
            # response does instead of raising out of a counting helper.
            return fallback
        return [count]

    @classmethod
    def _post(cls, url: str,
              query: Dict, headers: Dict) -> Any:
        """Send *query* as a JSON POST request and return the response.

        Args:
            url: Full endpoint URL.
            query: Payload serialised as the JSON request body.
            headers: Extra headers; they override the default
                ``Content-Type`` on key collision.

        Returns:
            The response object produced by ``requests``. The session is
            closed before returning; the request is not streamed.
        """
        merged_headers = {"Content-Type": "application/json"}
        merged_headers.update(headers)
        with requests.session() as sess:
            resp = sess.post(url,
                             json=query,
                             headers=merged_headers,
                             timeout=300)
        return resp