Commit 0a2be35f by 文靖昊

新增分类方法

parent c65bb953
...@@ -2,7 +2,7 @@ from typing import Dict, Any, Optional,List ...@@ -2,7 +2,7 @@ from typing import Dict, Any, Optional,List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from .http_tools import MonitorClient, RateClient from .http_tools import MonitorClient, RateClient
from typing import Type
class MonitorPointResponse(): class MonitorPointResponse():
"""监测点查询结果""" """监测点查询结果"""
status: str = Field(..., description="状态") status: str = Field(..., description="状态")
...@@ -15,13 +15,13 @@ class MonitorPointArgs(BaseModel): ...@@ -15,13 +15,13 @@ class MonitorPointArgs(BaseModel):
class MonitorPointTool(BaseTool): class MonitorPointTool(BaseTool):
"""查询监测点信息的工具""" """查询监测点信息的工具"""
name = "monitor_points_query" name:str = "monitor_points_query"
description = """查询指定行政区划的监测点信息。 description:str = """查询指定行政区划的监测点信息。
可以查询任意省/市/区县级别的监测点数据。 可以查询任意省/市/区县级别的监测点数据。
输入参数为行政区划名称,如:湖南省、长沙市、岳麓区等。 输入参数为行政区划名称,如:湖南省、长沙市、岳麓区等。
返回该区域内的监测点列表,包含位置、经纬度等详细信息。 返回该区域内的监测点列表,包含位置、经纬度等详细信息。
""" """
args_schema: type[BaseModel] = MonitorPointArgs args_schema: Type[BaseModel] = MonitorPointArgs
client: Any = Field(None, exclude=True) client: Any = Field(None, exclude=True)
def __init__(self, base_url: str = "http://localhost:5001", **data): def __init__(self, base_url: str = "http://localhost:5001", **data):
......
...@@ -68,8 +68,8 @@ class RegionRateArgs(BaseModel): ...@@ -68,8 +68,8 @@ class RegionRateArgs(BaseModel):
class RegionRateTool(BaseRateTool): class RegionRateTool(BaseRateTool):
"""查询全国或者特定地区设备在线率的工具""" """查询全国或者特定地区设备在线率的工具"""
name = "region_online_rate" name:str = "region_online_rate"
description = "查询指定地区在指定时间段内的设备在线率" description:str = "查询指定地区在指定时间段内的设备在线率"
args_schema: Type[BaseModel] = RegionRateArgs args_schema: Type[BaseModel] = RegionRateArgs
client: Any = Field(None, exclude=True) client: Any = Field(None, exclude=True)
...@@ -119,8 +119,8 @@ class RankingRateArgs(BaseModel): ...@@ -119,8 +119,8 @@ class RankingRateArgs(BaseModel):
class RankingRateTool(BaseRateTool): class RankingRateTool(BaseRateTool):
"""查询在线率排名的工具""" """查询在线率排名的工具"""
name = "online_rate_ranking" name:str = "online_rate_ranking"
description = "查询设备在线率的排名数据,可查询省份排名或厂商排名" description:str = "查询设备在线率的排名数据,可查询省份排名或厂商排名"
args_schema: Type[BaseModel] = RankingRateArgs args_schema: Type[BaseModel] = RankingRateArgs
client: Any = Field(None, exclude=True) client: Any = Field(None, exclude=True)
......
...@@ -6,6 +6,7 @@ from fastapi import FastAPI, Header ...@@ -6,6 +6,7 @@ from fastapi import FastAPI, Header
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import uvicorn import uvicorn
from src.server.agent_rate import new_rate_agent from src.server.agent_rate import new_rate_agent
from src.server.classify import new_router_llm
from src.controller.request import GeoAgentRateRequest from src.controller.request import GeoAgentRateRequest
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
...@@ -35,6 +36,7 @@ class AgentManager: ...@@ -35,6 +36,7 @@ class AgentManager:
verbose=True verbose=True
) )
self.agent = new_rate_agent(self.llm,verbose=True,tool_base_url=tool_base_url) self.agent = new_rate_agent(self.llm,verbose=True,tool_base_url=tool_base_url)
self.router_llm = new_router_llm(self.llm)
def get_llm(self): def get_llm(self):
return self.llm return self.llm
...@@ -42,6 +44,9 @@ class AgentManager: ...@@ -42,6 +44,9 @@ class AgentManager:
def get_agent(self): def get_agent(self):
return self.agent return self.agent
def get_router_llm(self):
return self.router_llm
agent_manager = AgentManager() agent_manager = AgentManager()
...@@ -62,6 +67,22 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)): ...@@ -62,6 +67,22 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
'data': res 'data': res
} }
@app.post('/api/classify')
def classify(chat_request: GeoAgentRateRequest):
llm = agent_manager.get_router_llm()
try:
res = llm.invoke(chat_request.query)
except Exception as e:
print(f"分类失败, 错误信息: {str(e)},请重新提问")
return {
'code': 500,
'data': str(e)
}
return {
'code': 200,
'data': res
}
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -5,7 +5,6 @@ from langchain.tools import BaseTool ...@@ -5,7 +5,6 @@ from langchain.tools import BaseTool
from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
from langchain_core.runnables import RunnablePassthrough from langchain_core.runnables import RunnablePassthrough
from langchain.agents import AgentExecutor, Agent, create_tool_calling_agent,create_openai_functions_agent,create_structured_chat_agent from langchain.agents import AgentExecutor, Agent, create_tool_calling_agent,create_openai_functions_agent,create_structured_chat_agent
from langchain.tools.render import ToolsRenderer, render_text_description_and_args
from langchain.agents.format_scratchpad import format_log_to_str from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents.format_scratchpad.openai_tools import ( from langchain.agents.format_scratchpad.openai_tools import (
......
from pydantic import BaseModel,Field
from typing import Literal
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
class RouterQuery(BaseModel):
"""Route a user query to the most relevant datasource."""
datasource: Literal["agent", "rag", "none"] = Field(
description="给定用户的问题,选择最相关的组件来回答他们的问题。",
)
system1 = """您是将用户问题路由到不同组件的专家。
以下是分类问题:
agent:关于特定时间段内不同地区或品牌的设备在线率查询的问题。
rag: - 在某段时间范围内,不同地区的[环境指标/地理事件/地理特征]的问题
none: - 其他问题
以下是示例:
```
用户的输入样本匹配的数据源:
1. ** 查询 ** : “2024年8月1日,全国总的设备在线率是多少?”
- ** 匹配的分类问题 ** : agent组件类,因为这是定时间段内的设备在线率查询。
2. ** 查询 ** : “攸县近五天降雨量如何”
- ** 匹配的数据源 ** : rag组件,因为这是环境指标问题。
3. ** 查询 ** : “介绍一下武汉”
- ** 匹配的数据源 ** : none组件,因为这是其他问题,可以之间通过大模型获取答案。
```
您必须从关键词或问题的结构中推断用户的查询意图,并将其路由到相关分类组件
"""
class RouterLLM:
def __init__(self,llm):
parser = PydanticOutputParser(pydantic_object=RouterQuery)
prompt = PromptTemplate(
template=system1 + "\n{format_instructions}\n{query}\n",
input_variables=["query"],
partial_variables={"format_instructions": parser.get_format_instructions()},
)
self.router = prompt | llm | parser
def invoke(self, question):
return self.router.invoke({"query": question})
def new_router_llm(llm):
router = RouterLLM(llm)
return router
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment