# agent_test.py — exploratory tests for the agent, its tools and the chart helper.
import sys,os
sys.path.append("../")

from typing import List, Union, Type, Optional

from langchain import hub
import langchain_core
from langchain_core.tools import tool, BaseTool
from langchain_core.prompts import ChatPromptTemplate,PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate,HumanMessagePromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.agents import AgentExecutor, create_tool_calling_agent,create_structured_chat_agent

from pydantic import BaseModel, Field

from src.server.agent import Agent, create_chart_agent
from src.config.prompts import PROMPT_AGENT_SYS, PROMPT_AGENT_HUMAN, PROMPT_AGENT_CHART_SYS

from src.agent.tool_divisions import AdministrativeDivision, CountryInfo

# Argument schema for the Calc tool (used as Calc's `args_schema`).
# Documented with comments rather than a docstring: pydantic turns a class
# docstring into the model's JSON-schema description.
class CalcInput(BaseModel):
    # First addend; description "第一个数" means "first number". `...` makes it required.
    a: int = Field(...,description="第一个数")
    # Second addend; description "第二个数" means "second number". Also required.
    b: int = Field(...,description="第二个数")

# Minimal addition tool used to exercise tool calling in test_add.
# The description string reads: "A simple calculation tool that can compute the
# sum of two numbers".
# Every overridden BaseTool field carries a type annotation: with pydantic v2
# (langchain-core >= 0.3) a non-annotated override such as `name = "calc"` is
# rejected when the class is defined; the annotation is harmless on older versions.
class Calc(BaseTool):
    name: str = "calc"
    description: str = "一个简单的计算工具,可以计算两个数的和"
    args_schema: Type[BaseModel] = CalcInput

    def _run(
        self, a: int, b: int, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> int:
        """Print the operation and return the integer sum ``a + b``.

        ``run_manager`` is accepted to match the BaseTool hook signature but is
        not used.
        """
        print(f"Calculating {a} + {b}")
        # The result is an int; the previous `-> str` annotation did not match.
        return a + b


# Module-level tool list.
# NOTE(review): none of the tests below reads this global (each builds its own
# local `tools`); kept so import-time behaviour is unchanged.
tools = [AdministrativeDivision()]

# Shared chat model for all tests, talking to an OpenAI-compatible endpoint.
# Endpoint, key and model name can be overridden with environment variables so
# the tests can target another server without editing this file; the defaults
# are the previously hard-coded values. Project-specific variable names are used
# deliberately so a real OPENAI_API_KEY is never sent to this test server.
llm = ChatOpenAI(
    openai_api_key=os.environ.get("AGENT_TEST_API_KEY", "xxxxxxxxxxxxx"),
    openai_api_base=os.environ.get("AGENT_TEST_API_BASE", "http://192.168.10.14:8000/v1"),
    # openai_api_base='https://127.0.0.1:8000/v1',
    model_name=os.environ.get("AGENT_TEST_MODEL", "Qwen2-7B"),
    verbose=True,
    temperature=0,
)

# Alternative prompt source: hub.pull("hwchase17/openai-functions-agent")

# Required variables of the assembled chat prompt. `chat_history` is not listed:
# it comes from an optional MessagesPlaceholder below.
input_variables = ["agent_scratchpad", "input", "tool_names", "tools", "chart_tool"]

# `chat_history` may be a list mixing any of these message classes.
input_types = {
    "chat_history": List[
        Union[
            langchain_core.messages.ai.AIMessage,
            langchain_core.messages.human.HumanMessage,
            langchain_core.messages.chat.ChatMessage,
            langchain_core.messages.system.SystemMessage,
            langchain_core.messages.function.FunctionMessage,
            langchain_core.messages.tool.ToolMessage,
        ]
    ]
}

# Turn order: system prompt (rendered with tools / tool_names / chart_tool),
# optional prior history, then the human turn (input + agent_scratchpad).
messages = [
    SystemMessagePromptTemplate(
        prompt=PromptTemplate(
            template=PROMPT_AGENT_CHART_SYS,
            input_variables=["tool_names", "tools", "chart_tool"],
        )
    ),
    MessagesPlaceholder(optional=True, variable_name="chat_history"),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=PROMPT_AGENT_HUMAN,
            input_variables=["agent_scratchpad", "input"],
        )
    ),
]

prompt = ChatPromptTemplate(
    messages=messages,
    input_types=input_types,
    input_variables=input_variables,
)

def test_add():
    """Run the chart agent on "what is 1 + 1?" with only the Calc tool, and print the result.

    Nothing is asserted; the output is meant for manual inspection.
    """
    tools = [Calc()]
    # Only the chart agent is exercised. A plain `Agent(...)` used to be built
    # here too, but it was overwritten on the next line before ever being used.
    agent = create_chart_agent(llm, tools, prompt, chart_tool="chart")

    res = agent.exec(prompt_args={"input": "what is 1 + 1?"})
    # Show the result, as the other agent tests do.
    print(res)

def test_agent_division():
    """Ask the agent which county gets more rainfall and print its answer.

    The question (in Chinese) asks whether Xihe County or Wen County in Longnan
    City has more rainfall. The AdministrativeDivision and CountryInfo tools are
    available to the agent.
    """
    division_agent = Agent(
        llm=llm,
        tools=[AdministrativeDivision(), CountryInfo()],
        prompt=prompt,
        verbose=True,
    )
    answer = division_agent.exec(
        prompt_args={"input": "我想知道陇南市西和县和文县的降雨量谁的多"}
    )
    print(answer)

def test_chart_tool():
    """Render a bar chart of y = x * x for x = 1..5 via chart_image and show it.

    Calls ``.show()`` on whatever chart_image returns.
    """
    from src.agent.tool_chart import chart_image

    xs = list(range(1, 6))          # [1, 2, 3, 4, 5]
    ys = [v * v for v in xs]        # [1, 4, 9, 16, 25]
    chart_data = dict(
        name="test",
        chart_type="bar",
        x=xs,
        y=ys,
        x_label="x axis",
        y_label="y axis",
    )
    chart_image(chart_data).show()

def test_agent_chart():
    """Ask for the ten highest mountains with the Chart tool enabled and print the answer.

    The Chinese question reads "Please tell me which are the ten highest
    mountains"; the agent is told that the chart tool is named "chart".
    """
    from src.agent.tool_chart import Chart

    chart_agent = Agent(
        llm=llm,
        tools=[Chart()],
        prompt=prompt,
        verbose=True,
        chart_tool="chart",
    )
    answer = chart_agent.exec(prompt_args={"input": "请告诉我海拔前十的高山有哪些"})
    print(answer)

if __name__ == "__main__":
    # Run the test functions named on the command line, e.g.
    #   python agent_test.py test_agent_division test_chart_tool
    # With no arguments only test_agent_chart runs, so a plain invocation
    # behaves as before. Only module-level callables named test_* are accepted.
    for _test_name in sys.argv[1:] or ["test_agent_chart"]:
        _test_fn = globals().get(_test_name)
        if not (_test_name.startswith("test_") and callable(_test_fn)):
            sys.exit(f"unknown test: {_test_name!r}")
        _test_fn()