# rag_query.py
import re
from src.server.get_similarity import GetSimilarityWithExt
import time
from src.server.rerank import BgeRerank
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from src.server.get_similarity import QAExt
from src.server.extend import LocationExt
import json
from src.agent.tool_divisions import complete_administrative_division,divisions
from langchain.chains import LLMChain
from src.config.consts import (
    RERANK_MODEL_PATH,
    prompt1
)


class RagQuery:
    """Location-aware retrieval-augmented question answering.

    A call to :meth:`query` runs these steps:

    1. ask ``LocationExt`` to extract administrative divisions from the
       question and chat history;
    2. complete each extracted name with ``complete_administrative_division``
       and turn the result into a textual hint that is prepended to the
       question;
    3. fetch stored documents for every resolved county/district via
       ``self.db.find_like_doc``;
    4. ask ``QAExt`` for query expansions, rerank the candidates and run the
       answer chain on the reranked context.
    """

    # Text after which the location-extraction reply lists the divisions.
    _LOCATION_MARKER = "提取到的行政区:"
    # First "[...]" group in the reply; its content is the list of names.
    _BRACKET_LIST = re.compile(r'\[([^\]]+)\]')
    # Separators tolerated between names: ASCII comma, full-width comma, "、".
    _NAME_SEPARATORS = re.compile(r'[,,、]')
    # Quote characters stripped from around each extracted name.
    _NAME_QUOTES = "\"'“”‘’"

    def __init__(self, base_llm, _faiss_db, _db):
        """Wire up the helper components used by :meth:`query`.

        Args:
            base_llm: LLM handed to ``QAExt``, ``LocationExt`` and the answer
                chain.
            _faiss_db: Vector store forwarded to ``GetSimilarityWithExt``.
            _db: Object exposing ``find_like_doc(name)``.
        """
        self.qa_ext = QAExt(base_llm)
        self.location_ext = LocationExt(base_llm)
        self.rerank_model = BgeRerank(RERANK_MODEL_PATH)
        self.faiss_db = _faiss_db
        self.db = _db
        answer_prompt = PromptTemplate(
            input_variables=["history", "context", "question"],
            template=prompt1,
        )
        # The answer chain is created with temperature 0.
        self.llm_chain = LLMChain(
            llm=base_llm,
            prompt=answer_prompt,
            llm_kwargs={"temperature": 0},
        )

    def get_similarity_with_ext_origin(self, _ext, _location):
        """Build a ``GetSimilarityWithExt`` for the given expansions and locations."""
        return GetSimilarityWithExt(
            _question=_ext,
            _faiss_db=self.faiss_db,
            _location=_location,
        )

    @classmethod
    def _text_after_marker(cls, content):
        """Return the text following the location marker, or *content* unchanged if it is absent."""
        _, marker, tail = content.partition(cls._LOCATION_MARKER)
        return tail if marker else content

    @classmethod
    def _parse_location_names(cls, location_text):
        """Return the names listed in the first ``[...]`` group of *location_text*.

        Names may be separated by ``,``, ``,`` or ``、``. Surrounding
        whitespace and quote characters are removed and empty entries are
        dropped. An empty list is returned when no bracketed group exists.
        """
        found = cls._BRACKET_LIST.search(location_text)
        if found is None:
            return []
        names = []
        for raw in cls._NAME_SEPARATORS.split(found.group(1)):
            name = raw.strip().strip(cls._NAME_QUOTES).strip()
            if name:
                names.append(name)
        return names

    @staticmethod
    def _describe_divisions(completed):
        """Collect counties and build the location hint for the question.

        Args:
            completed: Results of ``complete_administrative_division``; each
                entry is either ``None`` or a mapping. Entries without a
                "县(区)" key are ignored.
                NOTE(review): entries are assumed to be dict-like (support
                ``.get``) -- confirm against ``tool_divisions``.

        Returns:
            A ``(counties, hint)`` tuple: the flat list of county/district
            names and the text to prepend to the user's question. A missing
            or empty "省"/"市" value contributes an empty string instead of
            raising.
        """
        counties = []
        fragments = []
        for info in completed:
            if info is None or "县(区)" not in info:
                continue
            county = info["县(区)"]
            region = (info.get("省") or "") + (info.get("市") or "")
            if isinstance(county, str):
                counties.append(county)
                fragments.append(county + "位于" + region + ",")
            elif isinstance(county, list):
                counties.extend(county)
                listed = "".join(name + "," for name in county)
                fragments.append(region + "管辖" + listed)
        return counties, "".join(fragments)

    def _fetch_location_docs(self, locations):
        """Return one list of ``Document`` objects per entry of *locations*.

        Each row from ``self.db.find_like_doc`` is read as
        ``(page_content, metadata_json)``; a falsy result yields an empty list.
        """
        docs_per_location = []
        for place in locations:
            # perf_counter is monotonic, so the elapsed time cannot be skewed
            # by wall-clock adjustments.
            started = time.perf_counter()
            rows = self.db.find_like_doc(place)
            print('find_like_doc time: %s Seconds' % (time.perf_counter() - started))
            rows = rows if rows else []
            print(len(rows))
            docs = [
                Document(page_content=row[0], metadata=json.loads(row[1]))
                for row in rows
            ]
            print(len(docs))
            docs_per_location.append(docs)
        return docs_per_location

    def query(self, question: str, history: str) -> dict:
        """Answer *question* using location-aware retrieval.

        Args:
            question: The user's current question.
            history: Conversation history passed to the extraction, expansion
                and answer steps.

        Returns:
            ``{"answer": <answer chain output>, "docs": <reranked context>}``.
        """
        location_result = self.location_ext.extend_query_str(question=question, history=history)
        location_str = self._text_after_marker(location_result.content)
        print(location_str)

        completed = [
            complete_administrative_division(name, divisions)
            for name in self._parse_location_names(location_str)
        ]
        location, prompt = self._describe_divisions(completed)
        print(location)

        # Prepend the resolved division hint so later steps see the full place names.
        new_question = prompt + question
        print(new_question)

        split_docs_list = self._fetch_location_docs(location)

        start = time.perf_counter()
        result = self.qa_ext.extend_query_with_str(new_question, history)
        print('extend_query_with_str time: %s Seconds' % (time.perf_counter() - start))
        print(result)
        # The expansions are the double-quoted strings in the reply.
        matches = re.findall(r'"([^"]+)"', result.content)
        print(matches)

        similarity = self.get_similarity_with_ext_origin(matches, _location=location)
        cur_similarity = similarity.get_rerank_with_doc(self.rerank_model, split_docs_list)

        cur_answer = self.llm_chain.run(context=cur_similarity, question=new_question, history=history)
        return {"answer": cur_answer, "docs": cur_similarity}