rag_agent.py 4.72 KB
Newer Older
文靖昊 committed
1
from langchain_core.tools import BaseTool
文靖昊 committed
2
from pydantic import BaseModel, Field
3
from typing import Type, Any,List
文靖昊 committed
4 5
import re
from src.server.get_similarity import GetSimilarityWithExt
6
import time
文靖昊 committed
7
from src.server.rerank import BgeRerank
文靖昊 committed
8 9
from langchain_core.documents import Document
import json
文靖昊 committed
10 11
from src.config.consts import (
    RERANK_MODEL_PATH,
文靖昊 committed
12

文靖昊 committed
13
)
文靖昊 committed
14 15


文靖昊 committed
16 17 18

class IssuanceArgs(BaseModel):
    """Argument schema for the ``rag_query`` tool."""

    # The user's question for the current turn.
    question: str = Field(description="对话问题")
    # Previous conversation turns, passed as a single string.
    history: str = Field(description="历史对话记录")
    # Administrative-division names mentioned in ``question``. Typed as a
    # list of strings so the generated tool schema tells the model what the
    # array elements are and pydantic validates each element.
    location: List[str] = Field(description="question参数中的行政区划名称")
文靖昊 committed
21 22 23 24


class RAGQuery(BaseTool):
    """LangChain tool that looks up reference documents and answers a question.

    For each requested location the tool fetches candidate documents from
    ``db``, expands the question with ``rerank.extend_query_with_str``,
    reranks through ``GetSimilarityWithExt.get_rerank_with_doc`` and finally
    asks ``llm_chain`` to answer using the reranked context.
    """

    # ``name`` and ``description`` override BaseTool fields, so they carry
    # explicit annotations: pydantic v2 rejects un-annotated overrides, and
    # the annotated form is equally valid under pydantic v1.
    name: str = "rag_query"
    description: str = """这是一个区(县)级的水文气象地质知识库,当咨询区(县)的水文气象地质等相关信息的时候,你可以提供数据和资料。注意,这个查询只能获取一个区(县)的水文气象地质等相关信息。如果问题中有多个区县,请拆解出来,并一个区县一个区县的查询。当需要查询省市的详细信息时,需要获取改省市下的具体行政规划,并一一获取具体的区(县)的水文气象地质等相关信息。这个知识库中信息并不全面,有可能缺失。"""
    args_schema: Type[BaseModel] = IssuanceArgs
    rerank: Any  # query expander; must expose extend_query_with_str(question, history)
    rerank_model: Any  # BgeRerank instance built from RERANK_MODEL_PATH
    faiss_db: Any  # vector store handed to GetSimilarityWithExt
    prompt: Any  # stored for callers; not read inside this class
    db: Any  # document store; must expose find_like_doc(location)
    llm_chain: Any  # chain exposing run(context=..., question=...)

    def __init__(self, _faiss_db, _rerank, _prompt, _db, _llm_chain):
        """Build the tool from its collaborators.

        The values are passed through ``super().__init__`` rather than being
        assigned afterwards: the fields above have no defaults, so an
        argument-less ``super().__init__()`` fails validation on pydantic v2.
        """
        super().__init__(
            rerank=_rerank,
            rerank_model=BgeRerank(RERANK_MODEL_PATH),
            faiss_db=_faiss_db,
            prompt=_prompt,
            db=_db,
            llm_chain=_llm_chain,
        )

    def get_similarity_with_ext_origin(self, _ext, _location):
        """Return a ``GetSimilarityWithExt`` for the expanded queries and locations."""
        return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db, _location=_location)

    def _load_location_docs(self, location):
        """Fetch documents for every location; returns one list of Documents per location."""
        docs_per_location = []
        for place in location:
            start = time.time()
            rows = self.db.find_like_doc(place)
            end = time.time()
            print('find_like_doc time: %s Seconds' % (end - start))
            # A falsy result (None or empty) is treated as "no documents".
            rows = rows if rows else []
            print(len(rows))
            # Each row is indexed as (page text, JSON-encoded metadata).
            docs = [
                Document(page_content=row[0], metadata=json.loads(row[1]))
                for row in rows
            ]
            print(len(docs))
            docs_per_location.append(docs)
        return docs_per_location

    def _run(self, question: str, history: str, location: List[str]) -> dict:
        """Answer ``question`` using documents found for ``location``.

        Args:
            question: The user's question.
            history: Previous conversation turns, used to expand the query.
            location: Administrative-division names to look up.

        Returns:
            A dict with the chain's answer under ``"详细信息"`` and the
            reranked reference material under ``"参考文档"``.
        """
        print(location)
        split_docs_list = self._load_location_docs(location)

        start = time.time()
        result = self.rerank.extend_query_with_str(question, history)
        end = time.time()
        print('extend_query_with_str time: %s Seconds' % (end - start))
        print(result)

        # The expanded queries are the double-quoted strings in the reply.
        matches = re.findall(r'"([^"]+)"', result.content)
        print(matches)

        similarity = self.get_similarity_with_ext_origin(matches, _location=location)
        cur_similarity = similarity.get_rerank_with_doc(self.rerank_model, split_docs_list)
        cur_answer = self.llm_chain.run(context=cur_similarity, question=question)
        return {"详细信息": cur_answer, "参考文档": cur_similarity}

文靖昊 committed
96

97 98 99 100 101 102 103 104 105 106 107 108 109 110
class RAGAnalysisArgs(BaseModel):
    """Argument schema for the ``rag_analysis`` tool; both fields are required."""

    # The question that accompanied the rag_query call.
    question: str = Field(..., description="rag_query附带的问题")
    # Reference material returned by rag_query.
    doc: str = Field(..., description="rag_query获取的县级水文气象地质参考资料")

class RAGAnalysisQuery(BaseTool):
    """LangChain tool that runs ``llm_chain`` over material returned by ``rag_query``."""

    # Annotated overrides: pydantic v2 rejects un-annotated overrides of
    # BaseTool fields, and the annotated form is also valid under pydantic v1.
    name: str = "rag_analysis"
    description: str = """这是一个区(县)级的水文气象地质知识库解析工具,从rag_query查询到的资料,需要结合原始问题,用这个工具来解析出想要的答案,调用rag_query工具做查询之后一定要调用这个工具"""
    args_schema: Type[BaseModel] = RAGAnalysisArgs
    llm_chain: Any  # chain exposing run(context=..., question=...)

    def __init__(self, _llm_chain):
        """Store the chain by passing it to the base initializer.

        ``llm_chain`` has no default, so an argument-less
        ``super().__init__()`` fails validation on pydantic v2.
        """
        super().__init__(llm_chain=_llm_chain)

    def _run(self, question: str, doc: str):
        """Return the chain's answer to ``question`` with ``doc`` as context.

        ``doc`` is annotated ``str`` to agree with ``RAGAnalysisArgs.doc``.
        """
        return self.llm_chain.run(context=doc, question=question)