from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from typing import Type, Any
import re
from src.server.get_similarity import GetSimilarityWithExt
import time
from src.server.rerank import BgeRerank
from langchain_core.documents import Document
import json
from src.config.consts import (
    RERANK_MODEL_PATH,
)


class IssuanceArgs(BaseModel):
    question: str = Field(description="The question asked in the conversation")
    history: str = Field(description="The conversation history")
    location: list = Field(description="Administrative division names mentioned in the question parameter")


class RAGQuery(BaseTool):
    name: str = "rag_query"
    description: str = """This is a district/county-level knowledge base of hydrology, meteorology and geology.
    When asked about hydrological, meteorological or geological information for a district or county, it can
    provide data and reference material. Note that a single query can only retrieve information for one district
    or county; to answer questions about a province or city, first obtain the specific administrative divisions
    under that province or city and query each district or county one by one. The knowledge base is not exhaustive
    and some information may be missing. The output of this tool must be parsed with the rag_analysis tool; after
    every call to this tool, always call rag_analysis to interpret the result."""
    args_schema: Type[BaseModel] = IssuanceArgs
    rerank: Any  # TODO: replace Any with the concrete query-expansion helper type
    rerank_model: Any  # TODO: replace Any with the concrete rerank model type
    faiss_db: Any  # TODO: replace Any with the concrete FAISS store type
    prompt: Any  # assumed to be a prompt template / string
    db: Any
    llm_chain: Any

    def __init__(self, _faiss_db, _rerank, _prompt, _db, _llm_chain):
        super().__init__()
        self.rerank = _rerank
        self.rerank_model = BgeRerank(RERANK_MODEL_PATH)
        self.faiss_db = _faiss_db
        self.prompt = _prompt
        self.db = _db
        self.llm_chain = _llm_chain

    def get_similarity_with_ext_origin(self, _ext, _location):
        return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db, _location=_location)

    def _run(self, question: str, history: str, location: list):
        print(location)
        # split_str = jieba_split(question)
        # split_list = []
        # for l in split_str:
        #     split_list.append(l)

        # Fetch candidate documents for the given administrative divisions.
        start = time.time()
        answer = self.db.find_like_doc(location)
        end = time.time()
        print('find_like_doc time: %s Seconds' % (end - start))
        print(len(answer) if answer else 0)

        split_docs = []
        for a in answer if answer else []:
            d = Document(page_content=a[0], metadata=json.loads(a[1]))
            split_docs.append(d)
        print(len(split_docs))
        # if len(split_docs) > 10:
        #     split_docs = split_docs[:10]

        # Expand the question using the conversation history; the expansion LLM
        # is expected to return sub-queries wrapped in double quotes.
        start = time.time()
        result = self.rerank.extend_query_with_str(question, history)
        end = time.time()
        print('extend_query_with_str time: %s Seconds' % (end - start))
        print(result)
        matches = re.findall(r'"([^"]+)"', result.content)
        print(matches)

        # Rerank the candidate documents against the expanded queries, then
        # generate the answer from the reranked context.
        similarity = self.get_similarity_with_ext_origin(matches, _location=location)
        # cur_similarity = similarity.get_rerank(self.rerank_model)
        cur_similarity = similarity.get_rerank_with_doc(self.rerank_model, split_docs)
        # docs = similarity.get_rerank_docs()
        # print(cur_similarity)
        # geo_result = "Detailed hydrology/meteorology/geology material:" + cur_similarity + "\nOriginal question:" + question
        # cur_question = self.prompt.format(history=history, context=cur_similarity, question=question)
        cur_answer = self.llm_chain.run(context=cur_similarity, question=question)
        # print(cur_answer)
        # return cur_answer
        # loc = location[0]
        # location = location[1:]
        # for i in location:
        #     loc += ("," + i)
        return {"detailed_information": cur_answer, "reference_documents": cur_similarity}


class RAGAnalysisArgs(BaseModel):
    question: str = Field(description="The question that accompanied the rag_query call")
    doc: str = Field(description="The county-level hydrology/meteorology/geology reference material returned by rag_query")


class RAGAnalysisQuery(BaseTool):
    name: str = "rag_analysis"
    description: str = """This is the analysis tool for the district/county-level hydrology, meteorology and
    geology knowledge base. The material retrieved by rag_query must be combined with the original question and
    parsed with this tool to extract the desired answer. Always call this tool after querying with rag_query."""
    args_schema: Type[BaseModel] = RAGAnalysisArgs
    llm_chain: Any

    def __init__(self, _llm_chain):
        super().__init__()
        self.llm_chain = _llm_chain

    def _run(self, question: str, doc: str):
        cur_answer = self.llm_chain.run(context=doc, question=question)
        return cur_answer
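

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). A minimal, hedged example of the intended
# rag_query -> rag_analysis hand-off, assuming only that the injected
# llm_chain exposes a LangChain-style run(context=..., question=...) method.
# _StubChain and the sample payload below are hypothetical stand-ins invented
# for illustration; in the real service the FAISS store, reranker, SQL helper
# and LLM chain are built elsewhere and injected through the constructors
# above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _StubChain:
        """Hypothetical stand-in for the project's LLM chain."""

        def run(self, context: str, question: str) -> str:
            return f"[stub answer] question={question!r}, context length={len(context)}"

    # RAGQuery is not constructed here because its __init__ loads the BGE
    # rerank model from RERANK_MODEL_PATH; its output dict is faked instead.
    fake_rag_query_output = {
        "detailed_information": "Annual rainfall in the sample county is roughly 1200 mm.",
        "reference_documents": "County hydrology yearbook excerpt (sample text).",
    }

    analysis_tool = RAGAnalysisQuery(_StubChain())
    # _run is called directly for brevity; in an agent the tool would be
    # invoked through the normal LangChain tool-calling interface.
    print(
        analysis_tool._run(
            question="What is the annual rainfall of the county?",
            doc=fake_rag_query_output["reference_documents"],
        )
    )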