from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from typing import Type, Any,List
import re
from src.server.get_similarity import GetSimilarityWithExt
import time
from src.server.rerank import BgeRerank
from langchain_core.documents import Document
import json
from src.config.consts import (
    RERANK_MODEL_PATH,

)



class IssuanceArgs(BaseModel):
    """Argument schema for the ``rag_query`` tool (see ``RAGQuery.args_schema``)."""
    # question: the user's current question (Chinese: 对话问题)
    question: str = Field(description="对话问题")
    # history: prior conversation turns, passed as one string (历史对话记录)
    history: str = Field(description="历史对话记录")
    # location: administrative-division names extracted from `question`
    # (行政区划名称); consumed by RAGQuery._run / db.find_like_doc
    location: list = Field(description="question参数中的行政区划名称")


class RAGQuery(BaseTool):
    """County-level hydrology/meteorology/geology knowledge-base query tool.

    Looks up candidate documents for the given administrative locations,
    reranks them against a query-expanded form of the question, and runs
    the LLM chain over the reranked context. Per the tool description,
    the returned payload is expected to be parsed by the ``rag_analysis``
    tool afterwards.
    """

    # Annotated as pydantic fields: pydantic v2 raises for unannotated
    # attributes that override annotated base-class fields (BaseTool.name).
    name: str = "rag_query"
    description: str = """这是一个区(县)级的水文气象地质知识库,当咨询区(县)的水文气象地质等相关信息的时候,你可以提供数据和资料。注意,这个查询只能获取单个区(县)的水文气象地质等相关信息,当需要查询省市的详细信息时,需要获取改省市下的具体行政规划,并一一获取具体的区(县)的水文气象地质等相关信息。这个知识库中信息并不全面,有可能缺失。这个工具生成的结果需要再调用rag_analysis这个工具来进行解析,每次调用完成之后,一定要调用rag_analysis去解析结果"""
    args_schema: Type[BaseModel] = IssuanceArgs
    rerank: Any  # query-expansion helper exposing extend_query_with_str()
    rerank_model: Any  # BgeRerank instance loaded from RERANK_MODEL_PATH
    faiss_db: Any  # vector store handed to GetSimilarityWithExt
    prompt: Any  # prompt template; currently unused in _run (kept for callers)
    db: Any  # document store exposing find_like_doc(location)
    llm_chain: Any  # chain exposing run(context=..., question=...)

    def __init__(self, _faiss_db, _rerank, _prompt, _db, _llm_chain):
        """Wire in collaborators; the rerank model is loaded eagerly here."""
        super().__init__()
        self.rerank = _rerank
        self.rerank_model = BgeRerank(RERANK_MODEL_PATH)
        self.faiss_db = _faiss_db
        self.prompt = _prompt
        self.db = _db
        self.llm_chain = _llm_chain

    def get_similarity_with_ext_origin(self, _ext, _location):
        """Build a similarity searcher over the expanded query terms."""
        return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db, _location=_location)

    def _run(self, question: str, history: str, location: list):
        """Execute the query; returns {"详细信息": answer, "参考文档": context}."""
        print(location)

        start = time.time()
        answer = self.db.find_like_doc(location)
        end = time.time()
        print('find_like_doc time: %s Seconds' % (end - start))

        # find_like_doc may return None/empty; normalise once.
        rows = answer or []
        print(len(answer) if answer else 0)
        # Each row is (page_content, json_metadata) — assumed from usage; TODO confirm.
        split_docs = [
            Document(page_content=row[0], metadata=json.loads(row[1]))
            for row in rows
        ]
        print(len(split_docs))

        start = time.time()
        result = self.rerank.extend_query_with_str(question, history)
        end = time.time()
        print('extend_query_with_str time: %s Seconds' % (end - start))
        print(result)
        # Expansion output embeds the terms as double-quoted strings.
        matches = re.findall(r'"([^"]+)"', result.content)
        print(matches)

        similarity = self.get_similarity_with_ext_origin(matches, _location=location)
        cur_similarity = similarity.get_rerank_with_doc(self.rerank_model, split_docs)
        cur_answer = self.llm_chain.run(context=cur_similarity, question=question)
        return {"详细信息": cur_answer, "参考文档": cur_similarity}


class RAGAnalysisArgs(BaseModel):
    """Argument schema for the ``rag_analysis`` tool (see ``RAGAnalysisQuery``)."""
    # question: the question originally passed to rag_query
    question: str = Field(description="rag_query附带的问题")
    # doc: county-level reference material returned by rag_query
    doc: str = Field(description="rag_query获取的县级水文气象地质参考资料")

class RAGAnalysisQuery(BaseTool):
    """Tool that answers the original question from rag_query's retrieved docs.

    Thin wrapper around an LLM chain: the retrieved material is supplied as
    ``context`` and the original question as ``question``.
    """

    # Annotated as pydantic fields: pydantic v2 raises for unannotated
    # attributes that override annotated base-class fields (BaseTool.name).
    name: str = "rag_analysis"
    description: str = """这是一个区(县)级的水文气象地质知识库解析工具,从rag_query查询到的资料,需要结合原始问题,用这个工具来解析出想要的答案,调用rag_query工具做查询之后一定要调用这个工具"""
    args_schema: Type[BaseModel] = RAGAnalysisArgs
    llm_chain: Any  # chain exposing run(context=..., question=...)

    def __init__(self, _llm_chain):
        """Store the LLM chain used to analyse retrieved documents."""
        super().__init__()
        self.llm_chain = _llm_chain

    def _run(self, question: str, doc: str):
        """Run the chain over `doc` as context; returns the chain's answer.

        `doc` is annotated ``str`` to match RAGAnalysisArgs.doc (was ``list``,
        contradicting the declared schema).
        """
        return self.llm_chain.run(context=doc, question=question)