from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from typing import Type, Any,List
import re
from src.server.get_similarity import GetSimilarityWithExt
import time
from src.server.rerank import BgeRerank
from langchain_core.documents import Document
import json
from src.config.consts import (
RERANK_MODEL_PATH,
)
class IssuanceArgs(BaseModel):
    # Input schema for the ``rag_query`` tool (RAGQuery).

    # The user's question for this turn.
    question: str = Field(description="对话问题")
    # Prior conversation, forwarded to the query-expansion step.
    history: str = Field(description="历史对话记录")
    # Administrative-division names mentioned in ``question``.
    # Typed as a list of strings so the generated tool schema tells the LLM
    # what each element is; RAGQuery._run passes each item to db.find_like_doc.
    location: List[str] = Field(description="question参数中的行政区划名称")
class RAGQuery(BaseTool):
    """Retrieval-augmented lookup over the county-level knowledge base.

    For each requested location, candidate documents are pulled from ``db``.
    The question is expanded with ``rerank.extend_query_with_str``, a
    similarity search plus rerank is run over ``faiss_db``, and ``llm_chain``
    answers the question from the reranked context.
    """

    # Pydantic v2 requires a type annotation on every field override,
    # including the ``name``/``description`` fields inherited from BaseTool.
    name: str = "rag_query"
    description: str = """这是一个区(县)级的水文气象地质知识库,当咨询区(县)的水文气象地质等相关信息的时候,你可以提供数据和资料。注意,这个查询只能获取一个区(县)的水文气象地质等相关信息。如果问题中有多个区县,请拆解出来,并一个区县一个区县的查询。当需要查询省市的详细信息时,需要获取该省市下的具体行政区划,并一一获取具体的区(县)的水文气象地质等相关信息。这个知识库中信息并不全面,有可能缺失。"""
    args_schema: Type[BaseModel] = IssuanceArgs
    rerank: Any  # query expander; must provide extend_query_with_str(question, history)
    rerank_model: Any  # reranker handed to get_rerank_with_doc (BgeRerank by default)
    faiss_db: Any  # vector store handed to GetSimilarityWithExt
    prompt: Any  # stored for callers; not used by _run
    db: Any  # must provide find_like_doc(name) -> rows of (text, metadata JSON string)
    llm_chain: Any  # must provide run(context=..., question=...)

    def __init__(self, _faiss_db, _rerank, _prompt, _db, _llm_chain, _rerank_model=None):
        """Build the tool.

        Args:
            _faiss_db: vector store used for the similarity search.
            _rerank: object providing ``extend_query_with_str``.
            _prompt: prompt object kept on the instance.
            _db: document store providing ``find_like_doc``.
            _llm_chain: chain used to produce the final answer.
            _rerank_model: optional preloaded reranker, so several tools can
                share one model. When None, ``BgeRerank(RERANK_MODEL_PATH)``
                is loaded.
        """
        if _rerank_model is None:
            _rerank_model = BgeRerank(RERANK_MODEL_PATH)
        # Pass everything through the pydantic constructor: under pydantic v2
        # the ``Any`` fields above have no default and are therefore required,
        # so a bare ``super().__init__()`` would fail validation.
        super().__init__(
            rerank=_rerank,
            rerank_model=_rerank_model,
            faiss_db=_faiss_db,
            prompt=_prompt,
            db=_db,
            llm_chain=_llm_chain,
        )

    def get_similarity_with_ext_origin(self, _ext, _location):
        """Create a GetSimilarityWithExt search over ``faiss_db`` for the given queries/locations."""
        return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db, _location=_location)

    def _load_location_docs(self, location):
        """Return one list of Documents per location name, built from ``db.find_like_doc``."""
        split_docs_list = []
        for loc in location:
            start = time.perf_counter()
            rows = self.db.find_like_doc(loc)
            print('find_like_doc time: %s Seconds' % (time.perf_counter() - start))
            rows = rows or []  # find_like_doc may return a falsy value when nothing matches
            print(len(rows))
            # Each row carries the page text at [0] and a JSON-encoded metadata dict at [1].
            split_docs = [
                Document(page_content=row[0], metadata=json.loads(row[1]))
                for row in rows
            ]
            print(len(split_docs))
            split_docs_list.append(split_docs)
        return split_docs_list

    def _run(self, question: str, history: str, location: list):
        """Answer ``question`` for the given locations.

        Returns:
            dict with "详细信息" (the chain's answer) and "参考文档" (the
            reranked context the answer was produced from).
        """
        print(location)
        split_docs_list = self._load_location_docs(location)

        start = time.perf_counter()
        result = self.rerank.extend_query_with_str(question, history)
        print('extend_query_with_str time: %s Seconds' % (time.perf_counter() - start))
        print(result)
        # The expansion is expected to contain the extra query phrases in double quotes.
        matches = re.findall(r'"([^"]+)"', result.content)
        print(matches)
        # If the expansion yielded no quoted phrases, search with the original
        # question instead of an empty query list.
        queries = matches or [question]

        similarity = self.get_similarity_with_ext_origin(queries, _location=location)
        cur_similarity = similarity.get_rerank_with_doc(self.rerank_model, split_docs_list)
        cur_answer = self.llm_chain.run(context=cur_similarity, question=question)
        return {"详细信息": cur_answer, "参考文档": cur_similarity}
class RAGAnalysisArgs(BaseModel):
    # Input schema for the ``rag_analysis`` tool; both fields are required.

    # The question that accompanied the preceding rag_query call.
    question: str = Field(..., description="rag_query附带的问题")
    # Reference material returned by rag_query.
    doc: str = Field(..., description="rag_query获取的县级水文气象地质参考资料")
class RAGAnalysisQuery(BaseTool):
    """Answer a question from reference material previously returned by ``rag_query``."""

    # Pydantic v2 requires a type annotation on every field override,
    # including the ``name``/``description`` fields inherited from BaseTool.
    name: str = "rag_analysis"
    description: str = """这是一个区(县)级的水文气象地质知识库解析工具,从rag_query查询到的资料,需要结合原始问题,用这个工具来解析出想要的答案,调用rag_query工具做查询之后一定要调用这个工具"""
    args_schema: Type[BaseModel] = RAGAnalysisArgs
    llm_chain: Any  # must provide run(context=..., question=...)

    def __init__(self, _llm_chain):
        # Under pydantic v2 ``llm_chain`` has no default and is therefore
        # required, so it must be supplied to the pydantic constructor.
        super().__init__(llm_chain=_llm_chain)

    def _run(self, question: str, doc: str):
        """Run ``llm_chain`` with ``doc`` as context and return its answer.

        ``doc`` is annotated as ``str`` to match ``RAGAnalysisArgs``.
        """
        return self.llm_chain.run(context=doc, question=question)