Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
0a2be35f
Commit
0a2be35f
authored
4 months ago
by
文靖昊
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
新增分类方法
parent
c65bb953
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
82 additions
and
9 deletions
+82
-9
tool_monitor.py
src/agent/tool_monitor.py
+4
-4
tool_rate.py
src/agent/tool_rate.py
+4
-4
api.py
src/controller/api.py
+21
-0
agent_rate.py
src/server/agent_rate.py
+0
-1
classify.py
src/server/classify.py
+53
-0
No files found.
src/agent/tool_monitor.py
View file @
0a2be35f
...
...
@@ -2,7 +2,7 @@ from typing import Dict, Any, Optional,List
from
pydantic
import
BaseModel
,
Field
from
langchain_core.tools
import
BaseTool
from
.http_tools
import
MonitorClient
,
RateClient
from
typing
import
Type
class
MonitorPointResponse
():
"""监测点查询结果"""
status
:
str
=
Field
(
...
,
description
=
"状态"
)
...
...
@@ -15,13 +15,13 @@ class MonitorPointArgs(BaseModel):
class
MonitorPointTool
(
BaseTool
):
"""查询监测点信息的工具"""
name
=
"monitor_points_query"
description
=
"""查询指定行政区划的监测点信息。
name
:
str
=
"monitor_points_query"
description
:
str
=
"""查询指定行政区划的监测点信息。
可以查询任意省/市/区县级别的监测点数据。
输入参数为行政区划名称,如:湖南省、长沙市、岳麓区等。
返回该区域内的监测点列表,包含位置、经纬度等详细信息。
"""
args_schema
:
t
ype
[
BaseModel
]
=
MonitorPointArgs
args_schema
:
T
ype
[
BaseModel
]
=
MonitorPointArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
base_url
:
str
=
"http://localhost:5001"
,
**
data
):
...
...
This diff is collapsed.
Click to expand it.
src/agent/tool_rate.py
View file @
0a2be35f
...
...
@@ -68,8 +68,8 @@ class RegionRateArgs(BaseModel):
class
RegionRateTool
(
BaseRateTool
):
"""查询全国或者特定地区设备在线率的工具"""
name
=
"region_online_rate"
description
=
"查询指定地区在指定时间段内的设备在线率"
name
:
str
=
"region_online_rate"
description
:
str
=
"查询指定地区在指定时间段内的设备在线率"
args_schema
:
Type
[
BaseModel
]
=
RegionRateArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
...
...
@@ -119,8 +119,8 @@ class RankingRateArgs(BaseModel):
class
RankingRateTool
(
BaseRateTool
):
"""查询在线率排名的工具"""
name
=
"online_rate_ranking"
description
=
"查询设备在线率的排名数据,可查询省份排名或厂商排名"
name
:
str
=
"online_rate_ranking"
description
:
str
=
"查询设备在线率的排名数据,可查询省份排名或厂商排名"
args_schema
:
Type
[
BaseModel
]
=
RankingRateArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
...
...
This diff is collapsed.
Click to expand it.
src/controller/api.py
View file @
0a2be35f
...
...
@@ -6,6 +6,7 @@ from fastapi import FastAPI, Header
from
fastapi.middleware.cors
import
CORSMiddleware
import
uvicorn
from
src.server.agent_rate
import
new_rate_agent
from
src.server.classify
import
new_router_llm
from
src.controller.request
import
GeoAgentRateRequest
from
langchain_openai
import
ChatOpenAI
...
...
@@ -35,6 +36,7 @@ class AgentManager:
verbose
=
True
)
self
.
agent
=
new_rate_agent
(
self
.
llm
,
verbose
=
True
,
tool_base_url
=
tool_base_url
)
self
.
router_llm
=
new_router_llm
(
self
.
llm
)
def
get_llm
(
self
):
return
self
.
llm
...
...
@@ -42,6 +44,9 @@ class AgentManager:
def
get_agent
(
self
):
return
self
.
agent
def
get_router_llm
(
self
):
return
self
.
router_llm
agent_manager
=
AgentManager
()
...
...
@@ -62,6 +67,22 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
'data'
:
res
}
@app.post
(
'/api/classify'
)
def
classify
(
chat_request
:
GeoAgentRateRequest
):
llm
=
agent_manager
.
get_router_llm
()
try
:
res
=
llm
.
invoke
(
chat_request
.
query
)
except
Exception
as
e
:
print
(
f
"分类失败, 错误信息: {str(e)},请重新提问"
)
return
{
'code'
:
500
,
'data'
:
str
(
e
)
}
return
{
'code'
:
200
,
'data'
:
res
}
if
__name__
==
"__main__"
:
...
...
This diff is collapsed.
Click to expand it.
src/server/agent_rate.py
View file @
0a2be35f
...
...
@@ -5,7 +5,6 @@ from langchain.tools import BaseTool
from
langchain_core.prompts
import
PromptTemplate
,
ChatPromptTemplate
,
SystemMessagePromptTemplate
,
MessagesPlaceholder
,
HumanMessagePromptTemplate
from
langchain_core.runnables
import
RunnablePassthrough
from
langchain.agents
import
AgentExecutor
,
Agent
,
create_tool_calling_agent
,
create_openai_functions_agent
,
create_structured_chat_agent
from
langchain.tools.render
import
ToolsRenderer
,
render_text_description_and_args
from
langchain.agents.format_scratchpad
import
format_log_to_str
from
langchain.agents.output_parsers.openai_tools
import
OpenAIToolsAgentOutputParser
from
langchain.agents.format_scratchpad.openai_tools
import
(
...
...
This diff is collapsed.
Click to expand it.
src/server/classify.py
0 → 100644
View file @
0a2be35f
from
pydantic
import
BaseModel
,
Field
from
typing
import
Literal
from
langchain_core.prompts
import
PromptTemplate
from
langchain_core.output_parsers
import
PydanticOutputParser
class
RouterQuery
(
BaseModel
):
"""Route a user query to the most relevant datasource."""
datasource
:
Literal
[
"agent"
,
"rag"
,
"none"
]
=
Field
(
description
=
"给定用户的问题,选择最相关的组件来回答他们的问题。"
,
)
system1
=
"""您是将用户问题路由到不同组件的专家。
以下是分类问题:
agent:关于特定时间段内不同地区或品牌的设备在线率查询的问题。
rag: - 在某段时间范围内,不同地区的[环境指标/地理事件/地理特征]的问题
none: - 其他问题
以下是示例:
```
用户的输入样本匹配的数据源:
1. ** 查询 ** : “2024年8月1日,全国总的设备在线率是多少?”
- ** 匹配的分类问题 ** : agent组件类,因为这是定时间段内的设备在线率查询。
2. ** 查询 ** : “攸县近五天降雨量如何”
- ** 匹配的数据源 ** : rag组件,因为这是环境指标问题。
3. ** 查询 ** : “介绍一下武汉”
- ** 匹配的数据源 ** : none组件,因为这是其他问题,可以之间通过大模型获取答案。
```
您必须从关键词或问题的结构中推断用户的查询意图,并将其路由到相关分类组件
"""
class
RouterLLM
:
def
__init__
(
self
,
llm
):
parser
=
PydanticOutputParser
(
pydantic_object
=
RouterQuery
)
prompt
=
PromptTemplate
(
template
=
system1
+
"
\n
{format_instructions}
\n
{query}
\n
"
,
input_variables
=
[
"query"
],
partial_variables
=
{
"format_instructions"
:
parser
.
get_format_instructions
()},
)
self
.
router
=
prompt
|
llm
|
parser
def
invoke
(
self
,
question
):
return
self
.
router
.
invoke
({
"query"
:
question
})
def
new_router_llm
(
llm
):
router
=
RouterLLM
(
llm
)
return
router
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment