Commit 229a0e7f by tinywell

chore: Add administrative division tool to agent

parent 6bcdfb52
divisions = [
{
"province": "青海省",
"cities": [
{
"name": "西宁市",
"counties": ["城东区", "城中区", "城西区", "城北区", "大通回族土族自治县", "湟中区", "湟源县"]
},
{
"name": "海东市",
"counties": ["乐都区", "平安区", "民和回族土族自治县", "互助土族自治县", "化隆回族自治县", "循化撒拉族自治县"]
},
{
"name": "海北藏族自治州",
"counties": ["门源回族自治县", "祁连县", "海晏县", "刚察县"]
},
{
"name": "黄南藏族自治州",
"counties": ["同仁市", "尖扎县", "泽库县", "河南蒙古族自治县"]
},
{
"name": "海南藏族自治州",
"counties": ["共和县", "同德县", "贵德县", "兴海县", "贵南县"]
},
{
"name": "果洛藏族自治州",
"counties": ["玛沁县", "班玛县", "甘德县", "达日县", "久治县", "玛多县"]
},
{
"name": "玉树藏族自治州",
"counties": ["玉树市", "杂多县", "称多县", "治多县", "囊谦县", "曲麻莱县"]
},
{
"name": "海西蒙古族藏族自治州",
"counties": ["格尔木市", "德令哈市", "茫崖市", "乌兰县", "都兰县", "天峻县"]
}
]
},
{
"province": "甘肃省",
"cities": [
{
"name": "兰州市",
"counties": ["城关区", "七里河区", "西固区", "安宁区", "红古区", "永登县", "皋兰县", "榆中县"]
},
{
"name": "嘉峪关市",
"counties": ["嘉峪关市"]
},
{
"name": "金昌市",
"counties": ["金川区", "永昌县"]
},
{
"name": "白银市",
"counties": ["白银区", "平川区", "靖远县", "会宁县", "景泰县"]
},
{
"name": "天水市",
"counties": ["秦州区", "麦积区", "清水县", "秦安县", "甘谷县", "武山县", "张家川回族自治县"]
},
{
"name": "武威市",
"counties": ["凉州区", "民勤县", "古浪县", "天祝藏族自治县"]
},
{
"name": "张掖市",
"counties": ["甘州区", "肃南裕固族自治县", "民乐县", "临泽县", "高台县", "山丹县"]
},
{
"name": "平凉市",
"counties": ["崆峒区", "泾川县", "灵台县", "崇信县", "华亭市", "庄浪县", "静宁县"]
},
{
"name": "酒泉市",
"counties": ["肃州区", "金塔县", "瓜州县", "肃北蒙古族自治县", "阿克塞哈萨克族自治县", "玉门市", "敦煌市"]
},
{
"name": "庆阳市",
"counties": ["西峰区", "庆城县", "环县", "华池县", "合水县", "正宁县", "宁县", "镇原县"]
},
{
"name": "定西市",
"counties": ["安定区", "通渭县", "陇西县", "渭源县", "临洮县", "漳县", "岷县"]
},
{
"name": "陇南市",
"counties": ["武都区", "成县", "文县", "宕昌县", "康县", "西和县", "礼县", "徽县", "两当县"]
},
{
"name": "临夏回族自治州",
"counties": ["临夏市", "临夏县", "康乐县", "永靖县", "广河县", "和政县", "东乡族自治县", "积石山保安族东乡族撒拉族自治县"]
},
{
"name": "甘南藏族自治州",
"counties": ["合作市", "临潭县", "卓尼县", "舟曲县", "迭部县", "玛曲县", "碌曲县", "夏河县"]
}
]
}
]
from typing import Type
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool
# 根据输入补全行政区划信息
def complete_administrative_division(input_text, data):
for province in data:
if input_text in province['province']:
return {'province': province['province'], 'cities': [city['name'] for city in province['cities']]}
for city in province['cities']:
if input_text in city['name']:
return {
'province': province['province'],
'city': city['name'],
'counties': city['counties']
}
for county in city['counties']:
if input_text in county:
return {
'province': province['province'],
'city': city['name'],
'county': county
}
return None
class AdministrativeDivisionArgs(BaseModel):
input_text: str = Field(..., description="输入的行政区划信息,可以是省、市、县(区)三级中的最低一级的行政区划名称,例如用户输入'广东省广州市',则此参数只需要取'广州市',用户输入'武汉市洪山区',则次参数为'洪山区'")
class AdministrativeDivision(BaseTool):
name = "administrative_division"
description = "根据输入补全行政区划信息,明确具体的省、市、县信息。"
args_schema: Type[BaseModel] = AdministrativeDivisionArgs
def _run(self, input_text: str) -> str:
result = complete_administrative_division(input_text, divisions)
return result
\ No newline at end of file
......@@ -18,6 +18,8 @@ from pydantic import BaseModel, Field
from src.server.agent import Agent
from src.config.prompts import PROMPT_AGENT_SYS, PROMPT_AGENT_HUMAN
from src.agent.tool_divisions import AdministrativeDivision
class CalcInput(BaseModel):
a: int = Field(...,description="第一个数")
b: int = Field(...,description="第二个数")
......@@ -35,7 +37,8 @@ class Calc(BaseTool):
return a + b
tools = [Calc()]
tools = [AdministrativeDivision()]
llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx',
......@@ -60,18 +63,29 @@ prompt = ChatPromptTemplate(
input_types=input_types,
messages=messages
)
# for msg in prompt.messages:
# print(msg)
agent = Agent(llm=llm, tools=tools, prompt=prompt, verbose=True)
# res = agent.agent_executor.invoke(input={"input": "what is 1 + 1?"})
# res = agent.invoke(prompt_args={"input": "what is 1 + 1?"})
res = agent.exec(prompt_args={"input": "what is 1 + 1?"})
def test_add():
tools = [Calc()]
agent = Agent(llm=llm, tools=tools, prompt=prompt, verbose=True)
res = agent.exec(prompt_args={"input": "what is 1 + 1?"})
# agent = create_structured_chat_agent(llm, tools, prompt)
# agent_executor = AgentExecutor(agent=agent,tools=tools,verbose=True,handle_parsing_errors=True)
# res = agent_executor.invoke(input={"input": "what is 1 + 1?"})
# agent = create_structured_chat_agent(llm, tools, prompt)
# agent_executor = AgentExecutor(agent=agent,tools=tools,verbose=True,handle_parsing_errors=True)
# res = agent_executor.invoke(input={"input": "what is 1 + 1?"})
# print(res)
print(res)
# for step in agent.stream(prompt_args={"input": "what is 1 + 1?"}):
# print("== step ==")
# print(step)
def test_agent_division():
tools = [AdministrativeDivision()]
agent = Agent(llm=llm, tools=tools, prompt=prompt, verbose=True)
res = agent.exec(prompt_args={"input": "介绍下陇南市武都区的基本情况"})
print(res)
if __name__ == "__main__":
test_agent_division()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment