from typing import Any, Dict, List, Optional

import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import root_validator
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig


class BaichuanLLM(LLM):
    """LangChain LLM wrapper around a locally loaded Baichuan chat model."""

    model_name: str = "baichuan-inc/Baichuan-13B-Chat"
    quantization_bit: Optional[int] = None

    # Populated by validate_environment(); the trust_remote_code model and
    # tokenizer classes are custom, so they are typed as Any.
    tokenizer: Any = None
    model: Any = None

    @property
    def _llm_type(self) -> str:
        return "baichuan_local"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")

        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            # device_map="auto",
            trust_remote_code=True
        )
        model.generation_config = GenerationConfig.from_pretrained(
            model_name
        )
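        # The Baichuan checkpoint's remote code is expected to expose an
        # in-place quantize() helper, used below when a bit width is requested;
        # otherwise the model is kept in fp16 and moved to the GPU.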
        if values["quantization_bit"]:
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"]).cuda()
        else:
            model = model.half().cuda()

        model = model.eval()

        values["tokenizer"] = tokenizer
        values["model"] = model
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> str:
        # Baichuan's chat() helper applies the model's prompt template itself;
        # `stop` sequences are not supported by it and are ignored here.
        messages = [{"role": "user", "content": prompt}]
        response = self.model.chat(self.tokenizer, messages)
        return response
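

if __name__ == "__main__":
    # Minimal usage sketch, not part of the wrapper itself. It assumes a CUDA
    # GPU with enough memory for the chosen checkpoint; quantization_bit=8 or 4
    # reduces the memory footprint. Model weights are downloaded/loaded when
    # the object is constructed.
    llm = BaichuanLLM(quantization_bit=8)
    print(llm("Briefly explain what LangChain is."))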