Commit a086b59c by 文靖昊

分类更新

parent 66496537
...@@ -12,6 +12,7 @@ import uvicorn ...@@ -12,6 +12,7 @@ import uvicorn
from src.server.agent_rate import new_rate_agent, RateAgentV3 from src.server.agent_rate import new_rate_agent, RateAgentV3
from src.server.classify import new_router_llm from src.server.classify import new_router_llm
from src.server.extend_classify import new_extend_classify_llm
from src.server.rewrite import new_re_rewriter_llm from src.server.rewrite import new_re_rewriter_llm
from src.controller.request import GeoAgentRateRequest from src.controller.request import GeoAgentRateRequest
from src.utils.logger import setup_logging from src.utils.logger import setup_logging
...@@ -84,6 +85,7 @@ class AgentManager: ...@@ -84,6 +85,7 @@ class AgentManager:
# self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url) # self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url)
self.router_llm = new_router_llm(self.llm) self.router_llm = new_router_llm(self.llm)
self.re_rewriter_llm = new_re_rewriter_llm(self.llm) self.re_rewriter_llm = new_re_rewriter_llm(self.llm)
self.extend_classify_llm = new_extend_classify_llm(self.llm)
def get_llm(self): def get_llm(self):
return self.llm return self.llm
...@@ -100,6 +102,9 @@ class AgentManager: ...@@ -100,6 +102,9 @@ class AgentManager:
def get_re_rewriter_llm(self): def get_re_rewriter_llm(self):
return self.re_rewriter_llm return self.re_rewriter_llm
def get_extend_classify_llm(self):
return self.extend_classify_llm
agent_manager = AgentManager() agent_manager = AgentManager()
@app.post('/api/agent/rate') @app.post('/api/agent/rate')
...@@ -124,6 +129,7 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)): ...@@ -124,6 +129,7 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
def classify(chat_request: GeoAgentRateRequest): def classify(chat_request: GeoAgentRateRequest):
llm = agent_manager.get_router_llm() llm = agent_manager.get_router_llm()
re_llm = agent_manager.get_re_rewriter_llm() re_llm = agent_manager.get_re_rewriter_llm()
extend_llm = agent_manager.get_extend_classify_llm()
try: try:
if chat_request.query is None: if chat_request.query is None:
return { return {
...@@ -135,6 +141,15 @@ def classify(chat_request: GeoAgentRateRequest): ...@@ -135,6 +141,15 @@ def classify(chat_request: GeoAgentRateRequest):
else: else:
history = re_llm.extend_history(history=chat_request.history) history = re_llm.extend_history(history=chat_request.history)
rewrite = re_llm.invoke(chat_request.query, history) rewrite = re_llm.invoke(chat_request.query, history)
extend = extend_llm.invoke(rewrite)
if extend.classify == "yes":
return {
'code': 200,
'data': {
"datasource":"none",
"rewrite": rewrite
}
}
answer = llm.invoke(rewrite) answer = llm.invoke(rewrite)
res = { res = {
"datasource": answer.datasource, "datasource": answer.datasource,
......
...@@ -6,28 +6,33 @@ from ..utils.logger import get_logger ...@@ -6,28 +6,33 @@ from ..utils.logger import get_logger
class RouterQuery(BaseModel): class RouterQuery(BaseModel):
"""Route a user query to the most relevant datasource.""" """Route a user query to the most relevant datasource."""
datasource: Literal["agent", "specRag", "rag","none"] = Field( datasource: Literal["agent", "rag","none"] = Field(
description="给定用户的问题,选择最相关的组件来回答他们的问题。", description="给定用户的问题,选择最相关的组件来回答他们的问题。",
) )
system1 = """您是将用户问题路由到不同组件的专家,现在请根据以下四个类别对输入的问题进行分类,并输出相应的类别标签。 system1 = """您是将用户问题路由到不同组件的专家,现在请根据以下三个类别对输入的问题进行分类,并输出相应的类别标签。
specRag: 地质灾害监测预警系统的构建与管理的分类,其涵盖监测预警系统的技术框架和管理实践,如监测预警的概念、目标、任务、监测对象类型、监测预警设计、仪器选择与布设、数据通信与数据库建设、工作机制建立、定期评估、监测方案设计、仪器安装与维护等。 rag: 地质灾害监测预警系统的构建与管理的分类和特定区域地质灾害监测预警体系建设的分类,其涵盖监测预警系统的技术框架和管理实践以及某一特定区域(如海南州)的地质灾害监测预警体系建设,如监测预警的概念、目标、任务、监测对象类型、监测预警设计、仪器选择与布设、数据通信与数据库建设、工作机制建立、定期评估、监测方案设计、仪器安装与维护等,还包括区域背景分析、地质灾害现状评估、监测预警项目的具体规划与实施、项目的组织管理、环境保护措施、经费预算等。
rag: 特定区域地质灾害监测预警体系建设的分类,其专注于某一特定区域(如海南州)的地质灾害监测预警体系建设,包括区域背景分析、地质灾害现状评估、监测预警项目的具体规划与实施、项目的组织管理、环境保护措施、经费预算等。
agent: 地区性设备与监测系统效能评估问题,例如:根据多个维度(如时间、灾害类型、设备类型)查询地区的监测点的数据、设备状态(如在线率)、处理效率(如处置率、虚警率)、预警等级等。 agent: 地区性设备与监测系统效能评估问题,例如:根据多个维度(如时间、灾害类型、设备类型)查询地区的监测点的数据、设备状态(如在线率)、处理效率(如处置率、虚警率)、预警等级等。
none: 其他问题 none: 其他问题
以下是示例: 以下是示例:
问题: 2024年10月15,北京市的设备在线率是多少? 问题: 问题背景:无,问题:2024年10月15,北京市的设备在线率是多少?
分类: agent,因为这是地区性设备评估问题 分类: agent,因为这是地区性设备评估问题
---------------- ----------------
问题: 自建CORS组网基准站观测墩的建造要求是什么? 问题: 问题背景:无,问题:自建CORS组网基准站观测墩的建造要求是什么?
分类: specRag,因为这属于灾害监测预警系统的构建,涉及到预警系统的要求 分类: rag,因为这属于灾害监测预警系统的构建,涉及到预警系统的要求
---------------- ----------------
问题: 海南州地质灾害发育现状? 问题: 问题背景:无,问题:海南州地质灾害发育现状?
分类: rag,因为这属于具体建设体系中的地质灾害现状评估 分类: rag,因为这属于具体建设体系中的地质灾害现状评估
---------------- ----------------
问题: 介绍一下武汉 问题: 问题背景:无,问题:介绍一下武汉
分类:none组件,因为这是其他问题, 分类: none组件,因为这是其他问题,
----------------
问题: 问题背景:之前讨论过山东省滑坡仪数量为10,雨量计数量为30,裂缝计的数量为10,总计50。,问题如下:山西的滑坡仪、雨量计及裂缝计的总数是多少?
分类:agent,问题背景和问题都涉及滑坡仪、雨量计和裂缝计,属于地区性设备与监测系统效能评估问题
----------------
问题: 问题背景:无,问题如下:中海达的在线率在设备厂商中排名第几?
分类:agent,问题都涉及设备在线率,属于地区性设备与监测系统效能评估问题
---------------- ----------------
""" """
......
from pydantic import BaseModel,Field
from typing import Literal
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
import json
from ..utils.logger import get_logger
system1 = """
您是问题分析专家,根据用户问题,判断是否能从问题背景中获取到问题的答案,如果可以回答,请选择yes,不能则选择no,
以下是示例:
问题: 问题背景:无,问题:2024年10月15,北京市的设备在线率是多少?
是否能直接回答: no
----------------
问题: 问题背景:无,问题:自建CORS组网基准站观测墩的建造要求是什么?
是否需要调用工具: no
----------------
问题: 问题背景:陕西的设备在线率是90%,问题:陕西的设备在线率是多少,
分类: yes
----------------
问题: 问题背景:山西的设备在线率是90%,问题:陕西的设备在线率是多少,
分类: no
----------------
"""
class ExtendClassifyModel(BaseModel):
"""Route a user query to the most relevant datasource."""
classify: Literal["yes", "no"] = Field(
description="给定用户的问题,是否能从问题背景中获取到问题的答案。",
)
class ExtendClassifyLLM:
def __init__(self,base_llm):
parser = PydanticOutputParser(pydantic_object=ExtendClassifyModel)
prompt = PromptTemplate(
template= system1 + "\n{format_instructions}\n{query}\n",
input_variables=["query"],
partial_variables={"format_instructions": parser.get_format_instructions()},
)
self.llm = prompt | base_llm | parser
self.logger = get_logger(self.__class__.__name__)
def invoke(self, question):
try:
result = self.llm.invoke({"query": question})
self.logger.info(f"扩展分类结果: {result}")
return result
except Exception as e:
self.logger.error(f"扩展分类结果失败: {e}", exc_info=True)
return {'classify': 'no'}
def new_extend_classify_llm(llm):
router = ExtendClassifyLLM(llm)
return router
...@@ -68,8 +68,8 @@ A: 关于 FatGPT 的介绍和使用等问题。 ...@@ -68,8 +68,8 @@ A: 关于 FatGPT 的介绍和使用等问题。
历史记录: 历史记录:
''' '''
''' '''
原问题: 北京和上海那个天气好 原问题: 武汉的天气怎么样
优化后的问题: 北京和上海那个天气好 优化后的问题: 武汉的天气怎么样?
相关背景总结: 无 相关背景总结: 无
---------------- ----------------
历史记录: 历史记录:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment