Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
0a2be35f
Commit
0a2be35f
authored
Nov 08, 2024
by
文靖昊
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
新增分类方法
parent
c65bb953
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
82 additions
and
9 deletions
+82
-9
tool_monitor.py
src/agent/tool_monitor.py
+4
-4
tool_rate.py
src/agent/tool_rate.py
+4
-4
api.py
src/controller/api.py
+21
-0
agent_rate.py
src/server/agent_rate.py
+0
-1
classify.py
src/server/classify.py
+53
-0
No files found.
src/agent/tool_monitor.py
View file @
0a2be35f
...
@@ -2,7 +2,7 @@ from typing import Dict, Any, Optional,List
...
@@ -2,7 +2,7 @@ from typing import Dict, Any, Optional,List
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
from
langchain_core.tools
import
BaseTool
from
langchain_core.tools
import
BaseTool
from
.http_tools
import
MonitorClient
,
RateClient
from
.http_tools
import
MonitorClient
,
RateClient
from
typing
import
Type
class
MonitorPointResponse
():
class
MonitorPointResponse
():
"""监测点查询结果"""
"""监测点查询结果"""
status
:
str
=
Field
(
...
,
description
=
"状态"
)
status
:
str
=
Field
(
...
,
description
=
"状态"
)
...
@@ -15,13 +15,13 @@ class MonitorPointArgs(BaseModel):
...
@@ -15,13 +15,13 @@ class MonitorPointArgs(BaseModel):
class
MonitorPointTool
(
BaseTool
):
class
MonitorPointTool
(
BaseTool
):
"""查询监测点信息的工具"""
"""查询监测点信息的工具"""
name
=
"monitor_points_query"
name
:
str
=
"monitor_points_query"
description
=
"""查询指定行政区划的监测点信息。
description
:
str
=
"""查询指定行政区划的监测点信息。
可以查询任意省/市/区县级别的监测点数据。
可以查询任意省/市/区县级别的监测点数据。
输入参数为行政区划名称,如:湖南省、长沙市、岳麓区等。
输入参数为行政区划名称,如:湖南省、长沙市、岳麓区等。
返回该区域内的监测点列表,包含位置、经纬度等详细信息。
返回该区域内的监测点列表,包含位置、经纬度等详细信息。
"""
"""
args_schema
:
t
ype
[
BaseModel
]
=
MonitorPointArgs
args_schema
:
T
ype
[
BaseModel
]
=
MonitorPointArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
client
:
Any
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
base_url
:
str
=
"http://localhost:5001"
,
**
data
):
def
__init__
(
self
,
base_url
:
str
=
"http://localhost:5001"
,
**
data
):
...
...
src/agent/tool_rate.py
View file @
0a2be35f
...
@@ -68,8 +68,8 @@ class RegionRateArgs(BaseModel):
...
@@ -68,8 +68,8 @@ class RegionRateArgs(BaseModel):
class
RegionRateTool
(
BaseRateTool
):
class
RegionRateTool
(
BaseRateTool
):
"""查询全国或者特定地区设备在线率的工具"""
"""查询全国或者特定地区设备在线率的工具"""
name
=
"region_online_rate"
name
:
str
=
"region_online_rate"
description
=
"查询指定地区在指定时间段内的设备在线率"
description
:
str
=
"查询指定地区在指定时间段内的设备在线率"
args_schema
:
Type
[
BaseModel
]
=
RegionRateArgs
args_schema
:
Type
[
BaseModel
]
=
RegionRateArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
client
:
Any
=
Field
(
None
,
exclude
=
True
)
...
@@ -119,8 +119,8 @@ class RankingRateArgs(BaseModel):
...
@@ -119,8 +119,8 @@ class RankingRateArgs(BaseModel):
class
RankingRateTool
(
BaseRateTool
):
class
RankingRateTool
(
BaseRateTool
):
"""查询在线率排名的工具"""
"""查询在线率排名的工具"""
name
=
"online_rate_ranking"
name
:
str
=
"online_rate_ranking"
description
=
"查询设备在线率的排名数据,可查询省份排名或厂商排名"
description
:
str
=
"查询设备在线率的排名数据,可查询省份排名或厂商排名"
args_schema
:
Type
[
BaseModel
]
=
RankingRateArgs
args_schema
:
Type
[
BaseModel
]
=
RankingRateArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
client
:
Any
=
Field
(
None
,
exclude
=
True
)
...
...
src/controller/api.py
View file @
0a2be35f
...
@@ -6,6 +6,7 @@ from fastapi import FastAPI, Header
...
@@ -6,6 +6,7 @@ from fastapi import FastAPI, Header
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
import
uvicorn
import
uvicorn
from
src.server.agent_rate
import
new_rate_agent
from
src.server.agent_rate
import
new_rate_agent
from
src.server.classify
import
new_router_llm
from
src.controller.request
import
GeoAgentRateRequest
from
src.controller.request
import
GeoAgentRateRequest
from
langchain_openai
import
ChatOpenAI
from
langchain_openai
import
ChatOpenAI
...
@@ -35,6 +36,7 @@ class AgentManager:
...
@@ -35,6 +36,7 @@ class AgentManager:
verbose
=
True
verbose
=
True
)
)
self
.
agent
=
new_rate_agent
(
self
.
llm
,
verbose
=
True
,
tool_base_url
=
tool_base_url
)
self
.
agent
=
new_rate_agent
(
self
.
llm
,
verbose
=
True
,
tool_base_url
=
tool_base_url
)
self
.
router_llm
=
new_router_llm
(
self
.
llm
)
def
get_llm
(
self
):
def
get_llm
(
self
):
return
self
.
llm
return
self
.
llm
...
@@ -42,6 +44,9 @@ class AgentManager:
...
@@ -42,6 +44,9 @@ class AgentManager:
def
get_agent
(
self
):
def
get_agent
(
self
):
return
self
.
agent
return
self
.
agent
def
get_router_llm
(
self
):
return
self
.
router_llm
agent_manager
=
AgentManager
()
agent_manager
=
AgentManager
()
...
@@ -62,6 +67,22 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
...
@@ -62,6 +67,22 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
'data'
:
res
'data'
:
res
}
}
@app.post
(
'/api/classify'
)
def
classify
(
chat_request
:
GeoAgentRateRequest
):
llm
=
agent_manager
.
get_router_llm
()
try
:
res
=
llm
.
invoke
(
chat_request
.
query
)
except
Exception
as
e
:
print
(
f
"分类失败, 错误信息: {str(e)},请重新提问"
)
return
{
'code'
:
500
,
'data'
:
str
(
e
)
}
return
{
'code'
:
200
,
'data'
:
res
}
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
src/server/agent_rate.py
View file @
0a2be35f
...
@@ -5,7 +5,6 @@ from langchain.tools import BaseTool
...
@@ -5,7 +5,6 @@ from langchain.tools import BaseTool
from
langchain_core.prompts
import
PromptTemplate
,
ChatPromptTemplate
,
SystemMessagePromptTemplate
,
MessagesPlaceholder
,
HumanMessagePromptTemplate
from
langchain_core.prompts
import
PromptTemplate
,
ChatPromptTemplate
,
SystemMessagePromptTemplate
,
MessagesPlaceholder
,
HumanMessagePromptTemplate
from
langchain_core.runnables
import
RunnablePassthrough
from
langchain_core.runnables
import
RunnablePassthrough
from
langchain.agents
import
AgentExecutor
,
Agent
,
create_tool_calling_agent
,
create_openai_functions_agent
,
create_structured_chat_agent
from
langchain.agents
import
AgentExecutor
,
Agent
,
create_tool_calling_agent
,
create_openai_functions_agent
,
create_structured_chat_agent
from
langchain.tools.render
import
ToolsRenderer
,
render_text_description_and_args
from
langchain.agents.format_scratchpad
import
format_log_to_str
from
langchain.agents.format_scratchpad
import
format_log_to_str
from
langchain.agents.output_parsers.openai_tools
import
OpenAIToolsAgentOutputParser
from
langchain.agents.output_parsers.openai_tools
import
OpenAIToolsAgentOutputParser
from
langchain.agents.format_scratchpad.openai_tools
import
(
from
langchain.agents.format_scratchpad.openai_tools
import
(
...
...
src/server/classify.py
0 → 100644
View file @
0a2be35f
from
pydantic
import
BaseModel
,
Field
from
typing
import
Literal
from
langchain_core.prompts
import
PromptTemplate
from
langchain_core.output_parsers
import
PydanticOutputParser
class
RouterQuery
(
BaseModel
):
"""Route a user query to the most relevant datasource."""
datasource
:
Literal
[
"agent"
,
"rag"
,
"none"
]
=
Field
(
description
=
"给定用户的问题,选择最相关的组件来回答他们的问题。"
,
)
system1
=
"""您是将用户问题路由到不同组件的专家。
以下是分类问题:
agent:关于特定时间段内不同地区或品牌的设备在线率查询的问题。
rag: - 在某段时间范围内,不同地区的[环境指标/地理事件/地理特征]的问题
none: - 其他问题
以下是示例:
```
用户的输入样本匹配的数据源:
1. ** 查询 ** : “2024年8月1日,全国总的设备在线率是多少?”
- ** 匹配的分类问题 ** : agent组件类,因为这是定时间段内的设备在线率查询。
2. ** 查询 ** : “攸县近五天降雨量如何”
- ** 匹配的数据源 ** : rag组件,因为这是环境指标问题。
3. ** 查询 ** : “介绍一下武汉”
- ** 匹配的数据源 ** : none组件,因为这是其他问题,可以之间通过大模型获取答案。
```
您必须从关键词或问题的结构中推断用户的查询意图,并将其路由到相关分类组件
"""
class
RouterLLM
:
def
__init__
(
self
,
llm
):
parser
=
PydanticOutputParser
(
pydantic_object
=
RouterQuery
)
prompt
=
PromptTemplate
(
template
=
system1
+
"
\n
{format_instructions}
\n
{query}
\n
"
,
input_variables
=
[
"query"
],
partial_variables
=
{
"format_instructions"
:
parser
.
get_format_instructions
()},
)
self
.
router
=
prompt
|
llm
|
parser
def
invoke
(
self
,
question
):
return
self
.
router
.
invoke
({
"query"
:
question
})
def
new_router_llm
(
llm
):
router
=
RouterLLM
(
llm
)
return
router
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment