import json
import re
import time

from langchain.chains import LLMChain
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

from src.agent.tool_divisions import complete_administrative_division, divisions
from src.config.consts import (
    RERANK_MODEL_PATH,
    prompt1,
)
from src.server.extend import LocationExt
from src.server.get_similarity import GetSimilarityWithExt
from src.server.get_similarity import QAExt
from src.server.rerank import BgeRerank


class RagQuery():
    """Location-aware retrieval-augmented question answering pipeline.

    ``query`` extracts administrative divisions from the question, loads the
    documents stored for each resolved division, expands the question, reranks
    the candidates and finally asks the LLM chain for an answer.
    """

    # Text after which the location-extraction reply is searched for the
    # bracketed list (presumably emitted by the LocationExt prompt -- verify).
    _LOCATION_MARKER = "提取到的行政区:"
    # First "[...]" group in the location reply.
    _BRACKET_LIST = re.compile(r'\[([^\]]+)\]')
    # ASCII comma, full-width comma and ideographic comma all separate names.
    _NAME_SEPARATOR = re.compile(r'[,,、]')
    # Double-quoted fragments in the expanded-query reply.
    _QUOTED = re.compile(r'"([^"]+)"')
    # Quote characters trimmed from around an extracted name.
    _QUOTE_CHARS = "'\"“”‘’"
    _COUNTY_KEY = "县(区)"

    def __init__(self, base_llm, _faiss_db, _db):
        """Wire up the helpers used by ``query``.

        Args:
            base_llm: LLM handed to ``QAExt``, ``LocationExt`` and ``LLMChain``.
            _faiss_db: Vector store forwarded to ``GetSimilarityWithExt``.
            _db: Store whose ``find_like_doc`` is called once per location.
        """
        self.qa_ext = QAExt(base_llm)
        self.location_ext = LocationExt(base_llm)
        self.rerank_model = BgeRerank(RERANK_MODEL_PATH)
        self.faiss_db = _faiss_db
        self.db = _db
        self.llm_chain = LLMChain(
            llm=base_llm,
            prompt=PromptTemplate(
                input_variables=["history", "context", "question"],
                template=prompt1,
            ),
            llm_kwargs={"temperature": 0},
        )

    def get_similarity_with_ext_origin(self, _ext, _location):
        """Build a ``GetSimilarityWithExt`` bound to this instance's FAISS store."""
        return GetSimilarityWithExt(
            _question=_ext, _faiss_db=self.faiss_db, _location=_location
        )

    @classmethod
    def _parse_cities(cls, location_str):
        """Return the names listed inside the first ``[...]`` of *location_str*.

        Names may be separated by ``,``, ``,`` or ``、`` with arbitrary
        surrounding whitespace; enclosing quotes are removed and empty entries
        are skipped. Returns an empty list when no bracketed list is present.
        """
        found = cls._BRACKET_LIST.search(location_str)
        if not found:
            return []
        cleaned = (
            piece.strip().strip(cls._QUOTE_CHARS).strip()
            for piece in cls._NAME_SEPARATOR.split(found.group(1))
        )
        return [name for name in cleaned if name]

    @staticmethod
    def _field(city, key):
        """Return ``city[key]``, or ``""`` when the key is absent or falsy."""
        value = city[key] if key in city else None
        return value or ""

    @classmethod
    def _build_location_context(cls, cities_ext):
        """Turn resolved divisions into ``(location, prompt)``.

        ``location`` collects every county/district name; ``prompt`` is the
        sentence prefix describing where those counties belong. Entries that
        are ``None`` or have no county key are ignored.
        """
        location = []
        fragments = []
        for city in cities_ext:
            if city is None or cls._COUNTY_KEY not in city:
                continue
            county = city[cls._COUNTY_KEY]
            # A record without province/city must not abort the whole query.
            province = cls._field(city, "省")
            city_name = cls._field(city, "市")
            if isinstance(county, str):
                location.append(county)
                fragments.append(county + "位于" + province + city_name + ",")
            elif isinstance(county, list):
                location.extend(county)
                fragments.append(
                    province + city_name + "管辖"
                    + "".join(name + "," for name in county)
                )
        return location, "".join(fragments)

    def _fetch_docs(self, place):
        """Load the documents stored for *place* as ``Document`` objects.

        Each row returned by ``find_like_doc`` supplies the page content in
        position 0 and JSON-encoded metadata in position 1.
        """
        started = time.perf_counter()
        rows = self.db.find_like_doc(place)
        elapsed = time.perf_counter() - started
        print('find_like_doc time: %s Seconds' % elapsed)
        rows = rows if rows else []
        print(len(rows))
        docs = [
            Document(page_content=row[0], metadata=json.loads(row[1]))
            for row in rows
        ]
        print(len(docs))
        # NOTE: no per-location cap is applied (an earlier 5-document limit
        # was disabled).
        return docs

    def query(self, question: str, history: str) -> dict:
        """Answer *question* using location-filtered, reranked documents.

        Args:
            question: The user's question.
            history: Conversation history passed to every LLM call.

        Returns:
            A dict with ``"answer"`` (output of the LLM chain) and ``"docs"``
            (the reranked context handed to the chain).
        """
        location_result = self.location_ext.extend_query_str(
            question=question, history=history
        )
        content = location_result.content
        # Use the text after the marker when present, otherwise the full reply.
        _, marker, tail = content.partition(self._LOCATION_MARKER)
        location_str = tail if marker else content
        print(location_str)

        cities_ext = [
            complete_administrative_division(name, divisions)
            for name in self._parse_cities(location_str)
        ]
        location, prompt = self._build_location_context(cities_ext)
        print(location)

        new_question = prompt + question
        print(new_question)

        split_docs_list = [self._fetch_docs(place) for place in location]

        started = time.perf_counter()
        result = self.qa_ext.extend_query_with_str(new_question, history)
        elapsed = time.perf_counter() - started
        print('extend_query_with_str time: %s Seconds' % elapsed)
        print(result)

        matches = self._QUOTED.findall(result.content)
        print(matches)

        similarity = self.get_similarity_with_ext_origin(
            matches, _location=location
        )
        cur_similarity = similarity.get_rerank_with_doc(
            self.rerank_model, split_docs_list
        )
        cur_answer = self.llm_chain.run(
            context=cur_similarity, question=new_question, history=history
        )
        return {"answer": cur_answer, "docs": cur_similarity}