import json
import logging
import re
import time

from langchain.chains import LLMChain
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

from src.agent.tool_divisions import complete_administrative_division,divisions
from src.config.consts import (
    RERANK_MODEL_PATH,
    prompt1
)
from src.server.extend import LocationExt
from src.server.get_similarity import GetSimilarityWithExt
from src.server.get_similarity import QAExt
from src.server.rerank import BgeRerank


class RagQuery:
    """Retrieval-augmented question answering scoped by administrative division.

    ``query`` asks an LLM which administrative divisions the question mentions,
    resolves them to county/district names, loads candidate documents for each
    one from ``self.db``, expands the question into quoted search terms, reranks
    the candidates and answers with an ``LLMChain`` built from ``prompt1``.
    """

    _logger = logging.getLogger(__name__)

    # Marker the location-extraction LLM writes before its answer. Both the
    # ASCII ':' and the full-width '：' are accepted because LLM output is not
    # strictly formatted.
    _LOCATION_MARKER_RE = re.compile(r"提取到的行政区[:：]")
    # First bracketed list after the marker, e.g. "[北京市, 海淀区]".
    _BRACKET_LIST_RE = re.compile(r"\[([^\]]+)\]")
    # List items may be separated by ASCII, full-width or enumeration commas,
    # with or without surrounding spaces.
    _NAME_SEPARATOR_RE = re.compile(r"[,，、]")
    # Quote characters an LLM may wrap list items in; stripped from each name.
    _NAME_QUOTES = "\"'“”‘’"
    # Search terms are the double-quoted phrases in the query-expansion output.
    _QUOTED_TERM_RE = re.compile(r'"([^"]+)"')

    def __init__(self, base_llm, _faiss_db, _db):
        """Create the LLM helpers, the reranker and the answering chain.

        Args:
            base_llm: language model shared by the location extractor, the
                query expander and the final answering chain.
            _faiss_db: store handed to ``GetSimilarityWithExt`` as ``_faiss_db``.
            _db: store exposing ``find_like_doc(name)``, used to load candidate
                documents per county/district.
        """
        self.qa_ext = QAExt(base_llm)
        self.location_ext = LocationExt(base_llm)
        self.rerank_model = BgeRerank(RERANK_MODEL_PATH)
        self.faiss_db = _faiss_db
        self.db = _db
        self.llm_chain = LLMChain(
            llm=base_llm,
            prompt=PromptTemplate(
                input_variables=["history", "context", "question"],
                template=prompt1,
            ),
            # Temperature 0 keeps the final answer as deterministic as the model allows.
            llm_kwargs={"temperature": 0},
        )

    def get_similarity_with_ext_origin(self, _ext, _location):
        """Build a ``GetSimilarityWithExt`` for search terms ``_ext`` restricted to ``_location``."""
        return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db, _location=_location)

    @classmethod
    def _parse_location_names(cls, content):
        """Extract the raw place names listed in the location-extractor output.

        Text after the "提取到的行政区:" marker (ASCII or full-width colon) is
        searched for the first ``[...]`` list. Without the marker the whole
        text is searched. Items are split on ',', '，' or '、', and
        surrounding whitespace and quotes are stripped. Empty items are dropped.

        Args:
            content: text content of the location-extraction LLM response.

        Returns:
            List of name strings, in order; empty if no bracketed list is found.
        """
        marker = cls._LOCATION_MARKER_RE.search(content)
        location_str = content[marker.end():] if marker else content
        match = cls._BRACKET_LIST_RE.search(location_str)
        if match is None:
            return []
        names = []
        for part in cls._NAME_SEPARATOR_RE.split(match.group(1)):
            name = part.strip().strip(cls._NAME_QUOTES).strip()
            if name:
                names.append(name)
        return names

    @staticmethod
    def _resolve_locations(names):
        """Resolve place names to county/district names and a question prefix.

        Each name is completed with ``complete_administrative_division``. Only
        results that are not None and contain a "县(区)" entry contribute. That
        entry may be a single name (str) or several names (list).

        Args:
            names: raw place names from ``_parse_location_names``.

        Returns:
            ``(location, prompt)``. ``location`` lists the county/district names
            in discovery order. ``prompt`` is a textual description of where
            they are, meant to be prepended to the question.
        """
        location = []
        prompt_parts = []
        for name in names:
            city = complete_administrative_division(name, divisions)
            if city is None or "县(区)" not in city:
                continue
            counties = city["县(区)"]
            if isinstance(counties, str):
                location.append(counties)
                # "<county> is located in <province><city>,"
                prompt_parts.append(counties + "位于" + city["省"] + city["市"] + ",")
            elif isinstance(counties, list):
                location.extend(counties)
                # "<province><city> governs <c1>,<c2>,..."
                prompt_parts.append(city["省"] + city["市"] + "管辖" + "".join(c + "," for c in counties))
        return location, "".join(prompt_parts)

    def _load_location_docs(self, location):
        """Load candidate documents for each county/district name.

        ``self.db.find_like_doc(name)`` rows are read as ``row[0]`` = page
        content and ``row[1]`` = JSON-encoded metadata. A falsy result yields
        no documents for that name.

        Returns:
            One list of ``Document`` per entry of ``location``, aligned by index.
        """
        docs_per_location = []
        for name in location:
            start = time.perf_counter()
            rows = self.db.find_like_doc(name)
            self._logger.debug("find_like_doc(%s) time: %.3f seconds", name, time.perf_counter() - start)
            docs = [Document(page_content=row[0], metadata=json.loads(row[1])) for row in rows or []]
            self._logger.debug("find_like_doc(%s) returned %d docs", name, len(docs))
            docs_per_location.append(docs)
        return docs_per_location

    def query(self, question: str, history: str) -> dict:
        """Answer ``question`` using location-scoped retrieval and reranking.

        Args:
            question: the user's question.
            history: conversation history passed to every LLM call.

        Returns:
            ``{"answer": <llm_chain output>, "docs": <reranked context>}``.
            ``docs`` is the value returned by ``get_rerank_with_doc``.
        """
        location_result = self.location_ext.extend_query_str(question=question, history=history)
        names = self._parse_location_names(location_result.content)
        location, prompt = self._resolve_locations(names)
        self._logger.debug("extracted names: %s; resolved locations: %s", names, location)

        # Prefix the question with where the mentioned places are.
        new_question = prompt + question
        self._logger.debug("question with location context: %s", new_question)

        split_docs_list = self._load_location_docs(location)

        start = time.perf_counter()
        result = self.qa_ext.extend_query_with_str(new_question, history)
        self._logger.debug("extend_query_with_str time: %.3f seconds", time.perf_counter() - start)
        self._logger.debug("query expansion output: %s", result)
        matches = self._QUOTED_TERM_RE.findall(result.content)
        self._logger.debug("search terms: %s", matches)

        similarity = self.get_similarity_with_ext_origin(matches, _location=location)
        cur_similarity = similarity.get_rerank_with_doc(self.rerank_model, split_docs_list)

        cur_answer = self.llm_chain.run(context=cur_similarity, question=new_question, history=history)
        return {"answer": cur_answer, "docs": cur_similarity}