Commit 065e3f2c by tinywell

设备在线率agent 流程及工具验证

parent 98978c51
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import random
class FakeDataGenerator:
"""生成用于测试在线率分析工具的模拟数据"""
def __init__(self):
# 省份列表
self.provinces = [
'北京市', '上海市', '广东省', '江苏省', '浙江省',
'山东省', '河南省', '四川省', '湖北省', '福建省'
]
# 制造商列表
self.manufacturers = [
'华为', '中兴', '烽火', '诺基亚', '爱立信',
'思科', '新华三', '锐捷', '迈普', '东方通信'
]
# 基础在线率范围
self.base_rate_range = (0.85, 0.98)
def generate_region_data(self, region_name: str, start_time: str, end_time: str) -> pd.DataFrame:
"""生成指定地区和时间段的在线率数据"""
start_date = datetime.strptime(start_time, '%Y-%m-%d')
end_date = datetime.strptime(end_time, '%Y-%m-%d')
date_range = pd.date_range(start_date, end_date, freq='D')
data = []
base_rate = random.uniform(*self.base_rate_range)
for date in date_range:
# 添加一些随机波动
daily_rate = min(1.0, max(0.0, base_rate + random.uniform(-0.05, 0.05)))
data.append({
'date': date,
'region': region_name,
'online_rate': daily_rate,
'device_count': random.randint(1000, 5000)
})
return pd.DataFrame(data)
def generate_ranking_data(self, rank_type: int) -> pd.DataFrame:
"""生成排名数据"""
if rank_type == 1: # 省份排名
entities = self.provinces
else: # 厂商排名
entities = self.manufacturers
data = []
for entity in entities:
base_rate = random.uniform(*self.base_rate_range)
data.append({
'name': entity,
'online_rate': base_rate,
'device_count': random.randint(5000, 20000),
'offline_count': random.randint(100, 1000)
})
return pd.DataFrame(data).sort_values('online_rate', ascending=False)
def generate_national_trend(self, start_time: str, end_time: str) -> pd.DataFrame:
"""生成全国在线率趋势数据"""
start_date = datetime.strptime(start_time, '%Y-%m-%d')
end_date = datetime.strptime(end_time, '%Y-%m-%d')
date_range = pd.date_range(start_date, end_date, freq='D')
data = []
base_rate = random.uniform(*self.base_rate_range)
trend = np.linspace(-0.02, 0.02, len(date_range)) # 添加轻微的趋势
for i, date in enumerate(date_range):
# 基础在线率 + 趋势 + 随机波动
daily_rate = min(1.0, max(0.0, base_rate + trend[i] + random.uniform(-0.02, 0.02)))
data.append({
'date': date,
'online_rate': daily_rate,
'total_devices': random.randint(50000, 100000),
'online_devices': random.randint(40000, 90000)
})
return pd.DataFrame(data)
class MockDBConnection:
"""模拟数据库连接类"""
def __init__(self):
self.fake_data = FakeDataGenerator()
def query(self, sql: str, params: dict = None) -> pd.DataFrame:
"""模拟SQL查询"""
# 根据SQL语句特征返回相应的模拟数据
if 'region' in sql.lower():
return self.fake_data.generate_region_data(
params.get('region_name', '北京市'),
params.get('start_time', '2024-01-01'),
params.get('end_time', '2024-01-07')
)
elif 'rank' in sql.lower():
return self.fake_data.generate_ranking_data(
params.get('type', 1)
)
elif 'national' in sql.lower():
return self.fake_data.generate_national_trend(
params.get('start_time', '2024-01-01'),
params.get('end_time', '2024-01-07')
)
else:
return pd.DataFrame() # 默认返回空数据框
from typing import Dict, List, Tuple, Any, Optional
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from datetime import datetime
from pydantic import BaseModel, Field
from typing import Type
from langchain_core.tools import BaseTool
class BaseRateTool(BaseTool):
"""设备在线率分析基础工具类"""
db: Any = Field(None, exclude=True)
def __init__(self, db_connection, **data):
super().__init__(**data)
self.db = db_connection
def format_response(self, data: Dict[str, Any], chart: go.Figure) -> Dict[str, Any]:
"""格式化返回结果"""
return {
'data': data,
# 'chart': chart,
'summary': self._generate_summary(data)
}
def _generate_summary(self, data: Dict[str, Any]) -> str:
"""生成数据分析总结文本"""
# 保持原有的summary生成逻辑
if 'trend_data' in data: # 全国趋势数据
return (
f"在{data['date_range']['start']}至{data['date_range']['end']}期间,"
f"全国平均在线率为{data['statistics']['average_rate']:.2%},"
f"最高达到{data['statistics']['max_rate']:.2%},"
f"最低为{data['statistics']['min_rate']:.2%}。"
f"平均设备总数约{data['statistics']['average_devices']:,}台。"
)
elif 'rankings' in data: # 排名数据
if 'total_provinces' in data: # 省份排名
return (
f"共分析了{data['total_provinces']}个省份的在线率数据,"
f"平均在线率为{data['average_rate']:.2%}。"
f"{data['best_province']['name']}的表现最好,"
f"在线率达到{data['best_province']['rate']:.2%}。"
)
else: # 厂商排名
return (
f"共分析了{data['total_manufacturers']}个厂商的在线率数据,"
f"平均在线率为{data['average_rate']:.2%}。"
f"{data['best_manufacturer']['name']}的表现最好,"
f"在线率达到{data['best_manufacturer']['rate']:.2%}。"
)
else: # 地区在线率数据
return (
f"{data['region']}在{data['date_range']['start']}至{data['date_range']['end']}期间,"
f"平均在线率为{data['average_rate']:.2%},"
f"最高达到{data['max_rate']:.2%},"
f"最低为{data['min_rate']:.2%}。"
f"平均设备数约{int(data['total_devices']):,}台。"
)
class RegionRateArgs(BaseModel):
"""地区在线率查询参数"""
region_name: str = Field(..., description="地区名称")
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
class RegionRateTool(BaseRateTool):
"""查询特定地区设备在线率的工具"""
name = "region_online_rate"
description = "查询指定地区在指定时间段内的设备在线率"
args_schema: Type[BaseModel] = RegionRateArgs
def _run(self, region_name: str, start_time: str, end_time: str) -> Dict[str, Any]:
return self.get_region_online_rate(region_name, start_time, end_time)
def get_region_online_rate(self, region_name: str, start_time: str, end_time: str) -> Dict[str, Any]:
# 保持原有的实现...
# 查询数据
sql = """region"""
df = self.db.query(sql, {
'region_name': region_name,
'start_time': start_time,
'end_time': end_time
})
# 生成图表
fig = px.line(df,
x='date',
y='online_rate',
title=f'{region_name}设备在线率趋势')
fig.update_layout(
xaxis_title='日期',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
# 准备数据
data = {
'region': region_name,
'average_rate': df['online_rate'].mean(),
'max_rate': df['online_rate'].max(),
'min_rate': df['online_rate'].min(),
'total_devices': df['device_count'].mean(),
'date_range': {
'start': start_time,
'end': end_time
}
}
return self.format_response(data, fig)
class RankingRateArgs(BaseModel):
"""排名查询参数"""
rate_type: int = Field(..., description="排序类型:1-省份排名,2-厂商排名")
class RankingRateTool(BaseRateTool):
"""查询在线率排名的工具"""
name = "online_rate_ranking"
description = "查询设备在线率的排名数据,可查询省份排名或厂商排名"
args_schema: Type[BaseModel] = RankingRateArgs
def _run(self, rate_type: int) -> Dict[str, Any]:
return self.get_ranking_data(rate_type)
def get_ranking_data(self, rank_type: int) -> Dict[str, Any]:
if rank_type == 1:
return self._get_province_ranking()
else:
return self._get_manufacturer_ranking()
def _get_province_ranking(self) -> Dict[str, Any]:
"""获取省份在线率排名"""
sql = """rank"""
df = self.db.query(sql, {'type': 1})
# 生成排名图表
fig = px.bar(df,
x='name',
y='online_rate',
title='各省份设备在线率排名')
fig.update_layout(
xaxis_title='省份',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
# 准备数据
data = {
'rankings': df.to_dict('records'),
'total_provinces': len(df),
'average_rate': df['online_rate'].mean(),
'best_province': {
'name': df.iloc[0]['name'],
'rate': df.iloc[0]['online_rate']
}
}
return self.format_response(data, fig)
def _get_manufacturer_ranking(self) -> Dict[str, Any]:
"""获取厂商在线率排名"""
sql = """rank"""
df = self.db.query(sql, {'type': 2})
# 生成排名图表
fig = px.bar(df,
x='name',
y='online_rate',
title='各设备厂商在线率排名')
fig.update_layout(
xaxis_title='厂商',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
# 准备数据
data = {
'rankings': df.to_dict('records'),
'total_manufacturers': len(df),
'average_rate': df['online_rate'].mean(),
'best_manufacturer': {
'name': df.iloc[0]['name'],
'rate': df.iloc[0]['online_rate']
}
}
return self.format_response(data, fig)
class NationalTrendArgs(BaseModel):
"""全国趋势查询参数"""
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
class NationalTrendTool(BaseRateTool):
"""查询全国在线率趋势的工具"""
name = "national_online_trend"
description = "查询全国范围内设备在线率的变化趋势"
args_schema: Type[BaseModel] = NationalTrendArgs
def _run(self, start_time: str, end_time: str) -> Dict[str, Any]:
return self.get_national_trend(start_time, end_time)
def get_national_trend(self, start_time: str, end_time: str) -> Dict[str, Any]:
"""
接口三:获取全国在线率趋势
"""
sql = """national"""
df = self.db.query(sql, {
'start_time': start_time,
'end_time': end_time
})
# 生成趋势图表
fig = go.Figure()
# 添加在线率曲线
fig.add_trace(go.Scatter(
x=df['date'],
y=df['online_rate'],
name='在线率',
mode='lines+markers'
))
# 设置图表布局
fig.update_layout(
title='全国设备在线率趋势',
xaxis_title='日期',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
# 准备数据
data = {
'trend_data': df.to_dict('records'),
'statistics': {
'average_rate': df['online_rate'].mean(),
'max_rate': df['online_rate'].max(),
'min_rate': df['online_rate'].min(),
'average_devices': int(df['total_devices'].mean()),
'average_online': int(df['online_devices'].mean())
},
'date_range': {
'start': start_time,
'end': end_time
}
}
return self.format_response(data, fig)
\ No newline at end of file
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Dict, Optional
class LoginRequest(BaseModel): class LoginRequest(BaseModel):
...@@ -41,3 +42,8 @@ class ChatRequest(BaseModel): ...@@ -41,3 +42,8 @@ class ChatRequest(BaseModel):
class ReGenerateRequest(BaseModel): class ReGenerateRequest(BaseModel):
sessionID: str sessionID: str
question: str question: str
class GeoAgentRateRequest(BaseModel):
query: str
history: Optional[List[Dict]] = None
...@@ -17,10 +17,12 @@ from src.pgdb.knowledge.txt_doc_table import TxtDoc ...@@ -17,10 +17,12 @@ from src.pgdb.knowledge.txt_doc_table import TxtDoc
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_core.documents import Document from langchain_core.documents import Document
from src.server.rag_query import RagQuery from src.server.rag_query import RagQuery
from src.server.agent_rate import new_rate_agent
from src.controller.request import ( from src.controller.request import (
PhoneLoginRequest, PhoneLoginRequest,
ChatRequest, ChatRequest,
ReGenerateRequest ReGenerateRequest,
GeoAgentRateRequest
) )
from src.config.consts import ( from src.config.consts import (
CHAT_DB_USER, CHAT_DB_USER,
...@@ -49,30 +51,30 @@ app.add_middleware( ...@@ -49,30 +51,30 @@ app.add_middleware(
allow_headers=["*"], # 允许所有HTTP头 allow_headers=["*"], # 允许所有HTTP头
) )
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD, # c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, ) # port=CHAT_DB_PORT, )
c_db.connect() # c_db.connect()
k_db = PostgresDB(host=VEC_DB_HOST, database=VEC_DB_DBNAME, user=VEC_DB_USER, password=VEC_DB_PASSWORD, port=VEC_DB_PORT) # k_db = PostgresDB(host=VEC_DB_HOST, database=VEC_DB_DBNAME, user=VEC_DB_USER, password=VEC_DB_PASSWORD, port=VEC_DB_PORT)
k_db.connect() # k_db.connect()
vecstore_faiss = VectorStore_FAISS( # vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH, # embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH, # store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME, # index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER, # info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD}, # "password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER, # show_number=SIMILARITY_SHOW_NUMBER,
reset=False) # reset=False)
base_llm = ChatOpenAI( base_llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx', openai_api_key='xxxxxxxxxxxxx',
openai_api_base='http://192.168.10.14:8000/v1', openai_api_base='http://192.168.10.14:8000/v1',
model_name='Qwen2-7B', model_name='Qwen/Qwen2-7B-Instruct',
verbose=True verbose=True
) )
rag_query = RagQuery(base_llm=base_llm,_faiss_db=vecstore_faiss,_db=TxtDoc(k_db)) # rag_query = RagQuery(base_llm=base_llm,_faiss_db=vecstore_faiss,_db=TxtDoc(k_db))
...@@ -252,6 +254,16 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)): ...@@ -252,6 +254,16 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
} }
} }
@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
agent = new_rate_agent(base_llm,verbose=True)
res = agent.exec(prompt_args={"input": chat_request.query})
return {
'code': 200,
'data': res
}
def get_similarity_doc(similarity_doc_hash: str): def get_similarity_doc(similarity_doc_hash: str):
if similarity_doc_hash: if similarity_doc_hash:
hashs = similarity_doc_hash.split(",") hashs = similarity_doc_hash.split(",")
......
from typing import Any, List, Sequence, Union
import langchain_core
from langchain.tools import BaseTool
from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.agents import AgentExecutor, Agent, create_tool_calling_agent,create_openai_functions_agent,create_structured_chat_agent
from langchain.tools.render import ToolsRenderer, render_text_description_and_args
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
from langchain import hub
from src.agent.tool_rate import RegionRateTool,RankingRateTool,NationalTrendTool
from src.agent.fake_data_rate import MockDBConnection
def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
tools_renderer: ToolsRenderer = render_text_description_and_args,
verbose: bool = False,**args):
missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference(
prompt.input_variables + list(prompt.partial_variables)
)
if missing_vars:
raise ValueError(f"Prompt missing required variables: {missing_vars}")
prompt = prompt.partial(
tools=tools_renderer(list(tools)),
tool_names=", ".join([t.name for t in tools]),
)
if stop_sequence:
stop = ["\nObservation"] if stop_sequence is True else stop_sequence
llm_with_stop = llm.bind(stop=stop)
else:
llm_with_stop = llm
agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
)
| prompt
| llm_with_stop
)
return agent
class RateAgent:
def __init__(self, llm, tools: List[BaseTool],prompt: PromptTemplate = None, verbose: bool = False,**args):
# if not prompt:
# raise ValueError("PromptTemplate is required")
prompt = hub.pull("hwchase17/openai-tools-agent")
agent = create_tool_calling_agent(llm, tools, prompt)
self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose)
def exec(self, prompt_args: dict = {}, stream: bool = False):
return self.agent_executor.invoke(input=prompt_args)
def stream(self, prompt_args: dict = {}):
for step in self.agent_executor.stream(prompt_args):
yield step
ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。
你可以处理以下三类核心任务:
1. 地区在线率分析
- 分析指定地区(省/市/区县)在特定时间段的设备在线率
- 必需参数:地区名称、开始时间、结束时间
- 示例问题:"查询福建省2024年1月的设备在线率"
2. 在线率排名分析
- 分析各省份或各厂商的在线率排名情况
- type=1:查看各省份在线率排名
- type=2:查看各厂商在线率排名
- 示例问题:"显示所有省份的在线率排名" 或 "各设备厂商的在线率排名如何?"
3. 全国趋势分析
- 分析全国范围内在线率随时间的变化趋势
- 必需参数:开始时间、结束时间
- 示例问题:"展示2024年1月至2月的全国在线率趋势"
你需要:
1. 理解用户意图,将用户问题映射到合适的分析类型
2. 确保必要参数完整,如果缺少参数要主动询问
3. 调用相应的分析工具获取数据
4. 生成清晰的分析报告,包括数据解读和可视化图表
5. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议
注意事项:
- 时间格式统一使用:YYYY-MM-DD
- 地区名称需要包含行政级别(如:福建省、厦门市)
- 数据展示优先使用图表,并配合文字说明
- 百分比数据保留两位小数
您可以使用以下工具:
{tools}
使用 JSON 对象指定工具,提供一个 action 键(工具名称)和一个 action_input 键(工具输入) 。
有效的 "action" 值: "Final Answer" 或 {tool_names}
每个 $JSON_BLOB 只提供一个操作,如下所示:
```
{{
"action": $TOOL_NAME,
"action_input": $INPUT,
}}
```
按照以下格式:
Question: 输入要回答的问题
Thought: 考虑前后步骤
Action:
```
$JSON_BLOB
```
Observation: 操作结果
...(重复 Thought/Action/Observation N 次)
Thought: 我知道如何回复
Action:
```
{{
"action": "Final Answer",
"action_input": "最终回复给人类",
}}
```
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具。如果你知道答案,请直接回复。
你的回复格式为 Action:```$JSON_BLOB```然后 Observation。
"""
PROMPT_AGENT_HUMAN = """{input}\n\n {agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复。你的主要目标是帮助用户快速理解和分析设备在线率数据,提供准确、直观的分析结果)"""
PROMPT_AGENT_SYS = """请尽量帮助人类并准确回答问题。您可以使用以下工具:
{tools}
使用 JSON 对象指定工具,提供一个 action 键(工具名称)和一个 action_input 键(工具输入), 以及 action_cache 键(有些工具要求缓存其返回的中间结果) 。
有效的 "action" 值: "Final Answer" 或 {tool_names}
每个 $JSON_BLOB 只提供一个操作,如下所示:
```
{{
"action": $TOOL_NAME,
"action_input": $INPUT,
}}
```
按照以下格式:
Question: 输入要回答的问题
Thought: 考虑前后步骤
Action:
```
$JSON_BLOB
```
Observation: 操作结果
...(重复 Thought/Action/Observation N 次)
Thought: 我知道如何回复
Action:
```
{{
"action": "Final Answer",
"action_input": "最终回复给人类",
"action_cache": {{所有要求保存中间结果的工具操作结果汇总}}
}}
```
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具,并在最后一步按照工具要求将工具操作结果汇总到最后一个 Action 中的 action_cache。如果你知道答案,请直接回复。
你的回复格式为 Action:```$JSON_BLOB```然后 Observation,并在必要时将 Observation 结果更新到下一个 action_cache 中。
"""
PROMPT_AGENT_SYS_VARS = [ "tool_names", "tools"]
class RateAgentV2:
def __init__(self, llm, tools: List[BaseTool],prompt: PromptTemplate = None, verbose: bool = False,**args):
prompt = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template(ONLINE_RATE_SYSTEM_PROMPT),
MessagesPlaceholder(variable_name="chat_history", optional=True),
HumanMessagePromptTemplate.from_template(PROMPT_AGENT_HUMAN)
])
agent = create_structured_chat_agent(llm, tools, prompt)
self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose)
def exec(self, prompt_args: dict = {}, stream: bool = False):
return self.agent_executor.invoke(input=prompt_args)
def stream(self, prompt_args: dict = {}):
for step in self.agent_executor.stream(prompt_args):
yield step
def new_rate_agent(llm, verbose: bool = False,**args):
conn = MockDBConnection()
tools = [
RegionRateTool(db_connection=conn),
RankingRateTool(db_connection=conn),
NationalTrendTool(db_connection=conn)
]
# prompt = ChatPromptTemplate.from_messages([
# SystemMessagePromptTemplate.from_template(ONLINE_RATE_SYSTEM_PROMPT),
# MessagesPlaceholder(variable_name="chat_history", optional=True),
# HumanMessagePromptTemplate.from_template(PROMPT_AGENT_HUMAN)
# ])
# prompt = prompt.partial(tools=render_text_description_and_args(tools), tool_names=", ".join([t.name for t in tools]))
# 使用 LangChain 的工具调用代理
agent = RateAgentV2(llm=llm, tools=tools, verbose=verbose, **args)
return agent
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment