import sys,os
sys.path.append("../")

from typing import List, Union, Type, Optional

from langchain import hub
import langchain_core
from langchain_core.tools import tool, BaseTool
from langchain_core.prompts import ChatPromptTemplate,PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate,HumanMessagePromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.agents import AgentExecutor, create_tool_calling_agent,create_structured_chat_agent

from pydantic import BaseModel, Field

from src.server.agent import Agent, create_chart_agent
from src.config.prompts import PROMPT_AGENT_SYS, PROMPT_AGENT_HUMAN, PROMPT_AGENT_CHART_SYS

from src.agent.tool_divisions import AdministrativeDivision, CountryInfo

class CalcInput(BaseModel):
    # Argument schema for the ``Calc`` tool: two required integer operands.
    # (Comments rather than a docstring, so the model's schema description is unchanged.)
    a: int = Field(default=..., description="第一个数")  # "the first number"
    b: int = Field(default=..., description="第二个数")  # "the second number"

class Calc(BaseTool):
    # Minimal demo tool that adds two integers (used by ``test_add``).
    # ``name`` and ``description`` carry explicit ``str`` annotations: they override
    # pydantic fields declared on ``BaseTool``, and pydantic v2 (langchain-core >= 0.3)
    # rejects un-annotated overrides. The annotation is harmless on older versions.
    name: str = "calc"
    # "A simple calculator tool that computes the sum of two numbers."
    description: str = "一个简单的计算工具,可以计算两个数的和"
    args_schema: Type[BaseModel] = CalcInput

    def _run(
        self, a: int, b: int, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Add ``a`` and ``b`` and return the sum.

        Args:
            a: First operand.
            b: Second operand.
            run_manager: Optional LangChain callback manager (unused here).

        Returns:
            The sum as a string, matching the declared ``-> str`` return type
            (the original returned a bare ``int``).
        """
        print(f"Calculating {a} + {b}")
        return str(a + b)




# Module-level default tool list; the individual checks below build their own.
tools = [AdministrativeDivision()]

# OpenAI-compatible chat model. Key, endpoint and model can be overridden with
# project-specific environment variables. The defaults are the previously hard-coded
# values. Dedicated names are used (not OPENAI_API_KEY) so a real OpenAI key in the
# environment is not silently sent to this server.
llm = ChatOpenAI(
    openai_api_key=os.environ.get("AGENT_LLM_API_KEY", 'xxxxxxxxxxxxx'),
    openai_api_base=os.environ.get("AGENT_LLM_API_BASE", 'http://192.168.10.14:8000/v1'),
    model_name=os.environ.get("AGENT_LLM_MODEL", 'Qwen2-7B'),
    verbose=True,
    temperature=0,
)

# Chat prompt built locally (instead of hub.pull): the system message takes the
# tool descriptions plus an extra ``chart_tool`` variable.
input_variables = ['agent_scratchpad', 'input', 'tool_names', 'tools', "chart_tool"]
# Message types accepted by the optional ``chat_history`` placeholder.
input_types = {
    'chat_history': List[
        Union[
            langchain_core.messages.ai.AIMessage,
            langchain_core.messages.human.HumanMessage,
            langchain_core.messages.chat.ChatMessage,
            langchain_core.messages.system.SystemMessage,
            langchain_core.messages.function.FunctionMessage,
            langchain_core.messages.tool.ToolMessage,
        ]
    ]
}
messages = [
    SystemMessagePromptTemplate(
        prompt=PromptTemplate(
            input_variables=['tool_names', 'tools', "chart_tool"],
            template=PROMPT_AGENT_CHART_SYS,
        )
    ),
    MessagesPlaceholder(variable_name='chat_history', optional=True),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            input_variables=['agent_scratchpad', 'input'],
            template=PROMPT_AGENT_HUMAN,
        )
    ),
]

prompt = ChatPromptTemplate(
    input_variables=input_variables,
    input_types=input_types,
    messages=messages,
)

def test_add():
    """Manual check: ask the chart agent "what is 1 + 1?" with only the ``Calc`` tool.

    The original also built a plain ``Agent`` and then immediately discarded it
    (a dead store). The result was also never shown. Only the agent that is actually
    used is built now, and its result is printed like in the other checks.
    """
    tools = [Calc()]
    agent = create_chart_agent(llm, tools, prompt, chart_tool="chart")
    res = agent.exec(prompt_args={"input": "what is 1 + 1?"})
    print(res)
     
def test_agent_division():
    """Manual check: run the agent with the AdministrativeDivision and CountryInfo tools.

    The query asks which of Xihe County and Wen County (Longnan City) gets more
    rainfall. The agent's result is printed.
    """
    division_agent = Agent(
        llm=llm,
        tools=[AdministrativeDivision(), CountryInfo()],
        prompt=prompt,
        verbose=True,
    )
    question = "我想知道陇南市西和县和文县的降雨量谁的多"
    print(division_agent.exec(prompt_args={"input": question}))

def test_chart_tool():
    """Manual check: build a bar chart of y = x*x for x = 1..5 with ``chart_image``.

    ``.show()`` is then called on the returned image object.
    """
    from src.agent.tool_chart import chart_image

    xs = list(range(1, 6))
    squares = [v * v for v in xs]
    chart_image(
        {
            "name": "test",
            "chart_type": "bar",
            "x": xs,
            "y": squares,
            "x_label": "x axis",
            "y_label": "y axis",
        }
    ).show()

def test_agent_chart():
    """Manual check: run the agent with the ``Chart`` tool, configured as its chart tool.

    The query asks for the ten highest mountains. The agent's result is printed.
    """
    from src.agent.tool_chart import Chart

    chart_agent = Agent(
        llm=llm,
        tools=[Chart()],
        prompt=prompt,
        verbose=True,
        chart_tool="chart",
    )
    result = chart_agent.exec(prompt_args={"input": "请告诉我海拔前十的高山有哪些"})
    print(result)


if __name__ == "__main__":
    # Run one manual check chosen by name on the command line,
    # e.g. ``python <this file> division``. Without an argument the chart-agent
    # check runs, as before. This replaces toggling commented-out calls.
    _checks = {
        "add": test_add,
        "division": test_agent_division,
        "chart_tool": test_chart_tool,
        "chart": test_agent_chart,
    }
    _name = sys.argv[1] if len(sys.argv) > 1 else "chart"
    if _name not in _checks:
        sys.exit(f"unknown check {_name!r}; choose from: {', '.join(_checks)}")
    _checks[_name]()