from typing import Any, List, Sequence, Union
from datetime import datetime
import langchain_core
from langchain.tools import BaseTool
from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.agents import AgentExecutor, Agent, create_tool_calling_agent,create_openai_functions_agent,create_structured_chat_agent
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents.format_scratchpad.openai_tools import (
    format_to_openai_tool_messages,
)
from langchain import hub


from src.agent.tool_rate import RegionRateTool,RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
from src.agent.tool_warn import WarningTool
from src.server.tool_picker import ToolPicker, ToolRunner

# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None, 
#     tools_renderer: ToolsRenderer = render_text_description_and_args,
#     verbose: bool = False,**args):
#     missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference(
#         prompt.input_variables + list(prompt.partial_variables)
#     )
#     if missing_vars:
#         raise ValueError(f"Prompt missing required variables: {missing_vars}")

#     prompt = prompt.partial(
#         tools=tools_renderer(list(tools)),
#         tool_names=", ".join([t.name for t in tools]),
#     )
#     if stop_sequence:
#         stop = ["\nObservation"] if stop_sequence is True else stop_sequence
#         llm_with_stop = llm.bind(stop=stop)
#     else:
#         llm_with_stop = llm

#     agent = (
#         RunnablePassthrough.assign(
#             agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
#         )
#         | prompt
#         | llm_with_stop
#     )
#     return agent


class RateAgent:
    def __init__(self, llm, tools: List[BaseTool],prompt: PromptTemplate = None, verbose: bool = False,**args):
        # if not prompt:
        #     raise ValueError("PromptTemplate is required")

        prompt = hub.pull("hwchase17/openai-tools-agent")
        agent = create_tool_calling_agent(llm, tools, prompt)

        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose)    
        
    def exec(self, prompt_args: dict = {}, stream: bool = False):
        return self.agent_executor.invoke(input=prompt_args)

    def stream(self, prompt_args: dict = {}):
        for step in self.agent_executor.stream(prompt_args):
            yield step


# 适配 structured_chat_agent 的 prompt
ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及监测设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。

核心工作流程:
1. 仔细分析用户问题,准确识别所需的分析类型和必要参数
2. 一次性收集所有必要参数 - 如参数不完整,立即要求用户补充,不进行任何工具调用
3. 确认参数完整后,仅调用一次最合适的分析工具
4. 收到工具返回数据后,必须使用 markdown 表格格式完整展示所有数据,不遗漏任何字段
5. 在表格后提供简洁的数据解读

严格遵守的规则:
- 禁止重复调用工具获取相同维度的数据
- 工具返回的数据必须且只能用 markdown 表格展示,不使用其他格式
- 表格中的数据必须完整展示,不省略任何字段
- 时间格式:YYYY-MM-DD
- 地区名称必须包含行政级别(如:福建省、厦门市)
- 百分比数据保留两位小数

异常处理:
- 参数缺失:立即提示用户补充具体参数,不进行工具调用
- 数据异常:仅使用工具返回的实际数据,不做假设或补充

您可以使用以下工具:

{tools}

使用 JSON 对象指定工具,提供一个 action 键(工具名称)和一个 action_input 键(工具输入) 。

有效的 "action" 值: "Final Answer" 或 {tool_names}

每个 $JSON_BLOB 只提供一个操作,如下所示:
    ``` 
    {{
    "action": $TOOL_NAME,
    "action_input": $INPUT,
    }}
    ```

按照以下格式:

Question: 输入要回答的问题
Thought: 考虑前后步骤
Action:
    ```
    $JSON_BLOB
    ```
Observation: 操作结果
...(重复 Thought/Action/Observation N 次)
Thought: 我知道如何回复
Action:
    ```
    {{
    "action": "Final Answer",
    "action_input": "最终回复给人类",
    }}
    ```
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具。如果你知道答案,请直接回复。
你的回复格式为 Action:```$JSON_BLOB```然后 Observation。
"""


PROMPT_AGENT_HUMAN = """{input}\n\n {agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复。工具返回的数据必须使用表格展示,包含在最终输出中,并且要保证数据的完整性)"""
PROMPT_AGENT_SYS_VARS = [ "tool_names", "tools"]

class RateAgentV2:
    def __init__(self, llm, tools: List[BaseTool],prompt: PromptTemplate = None, verbose: bool = False,**args):
        date_now = datetime.now().strftime("%Y-%m-%d")
        prompt_human = f"{PROMPT_AGENT_HUMAN}\n\n今天是{date_now}"
        prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(ONLINE_RATE_SYSTEM_PROMPT),
            MessagesPlaceholder(variable_name="chat_history", optional=True),
            HumanMessagePromptTemplate.from_template(prompt_human)
        ])

        agent = create_structured_chat_agent(llm, tools, prompt)

        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, return_intermediate_steps=True,handle_parsing_errors=True)
        
    def exec(self, prompt_args: dict = {}, stream: bool = False):
        return self.agent_executor.invoke(input=prompt_args)

    def stream(self, prompt_args: dict = {}):
        for step in self.agent_executor.stream(prompt_args):
            yield step

def new_rate_agent(llm, verbose: bool = False,**args):

    if args['tool_base_url']:
        tool_base_url = args['tool_base_url']
    else:
        tool_base_url = const_base_url

    tools = [
        RegionRateTool(base_url=tool_base_url),
        RankingRateTool(base_url=tool_base_url),
        MonitorPointTool(base_url=tool_base_url)
    ]

    # 使用 LangChain 的工具调用代理
    agent = RateAgentV2(llm=llm, tools=tools, verbose=verbose, **args)
    return agent



    
class RateAgentV3:
    def __init__(self, llm, tool_base_url: str,version: str = "v1"):
        tools = [
            RegionRateTool(base_url=tool_base_url),
            RankingRateTool(base_url=tool_base_url),
            MonitorPointTool(base_url=tool_base_url),
            WarningTool(base_url=tool_base_url)
        ]
        self.picker = ToolPicker(llm, tools)
        tools_dict = {}
        for t in tools:
            tools_dict[t.name] = t
        self.runner = ToolRunner(llm, tools_dict,version)
        self.version = version

    def run(self, input: str):
        picker_result = self.picker.pick(input)
        res = self.runner.run(input, picker_result["tool"], picker_result["params"])
        if self.version != "v1": 
            output = res['output']
        else: # v1 版本, 使用表格单独展示数据
            output = f"相关数据如下:\n{res['table']}\n\n{res['output']}"
            
        return {
            "input": input,
            "output": output
        }