import logging
from typing import List, Optional, Sequence, Union

from langchain.agents import AgentExecutor
from langchain.tools import BaseTool
# from langgraph.prebuilt import create_react_agent


from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.output_parsers.json import parse_json_markdown
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain.tools.render import ToolsRenderer, render_text_description_and_args
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.agent import AgentOutputParser

logger = logging.getLogger(__name__)

class ChartAgentOutputParser(AgentOutputParser):
    """Parses tool invocations and final answers in JSON format.

    Expects output to be in one of two formats.

    If the output signals that an action should be taken, it should
    be in the format below. This will result in an AgentAction being
    returned.

    ```
    {
      "action": "search",
      "action_input": "2+2",
      "action_cache": ""
    }
    ```

    If the output signals that a final answer should be given, it should
    be in the format below. This will result in an AgentFinish being
    returned.

    ```
    {
      "action": "Final Answer",
      "action_input": "4",
      "action_cache": ""
    }
    ```
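
    For example, parsing a final-answer payload directly (a minimal sketch;
    the values are illustrative):

    ```
    parser = ChartAgentOutputParser()
    parser.parse('{"action": "Final Answer", "action_input": "4"}')
    # -> AgentFinish(return_values={"output": "4"}, log=...)
    ```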
    """

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        try:
            response = parse_json_markdown(text)
            if isinstance(response, list):
                # gpt turbo frequently ignores the directive to emit a single
                # action, so fall back to the first response
                logger.warning("Got multiple action responses: %s", response)
                response = response[0]
            if response["action"] == "Final Answer":
                if "action_cache" in response:
                    return AgentFinish({"output": response["action_input"],"cache":response["action_cache"]}, text)
                else:
                    return AgentFinish({"output": response["action_input"]}, text)
            else:
                return AgentAction(
                    response["action"], response.get("action_input", {}), text
                )
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e

    @property
    def _type(self) -> str:
        return "chart-agent"


def create_chart_agent(
    llm: BaseLanguageModel,
    tools: Sequence[BaseTool],
    prompt: ChatPromptTemplate,
    chart_tool: str,
    tools_renderer: ToolsRenderer = render_text_description_and_args,
    *,
    stop_sequence: Union[bool, List[str]] = True,
) -> Runnable:
    """Create an agent aimed at supporting chart tools with multiple inputs.
    """
    missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference(
        prompt.input_variables + list(prompt.partial_variables)
    )
    if missing_vars:
        raise ValueError(f"Prompt missing required variables: {missing_vars}")

    prompt = prompt.partial(
        tools=tools_renderer(list(tools)),
        tool_names=", ".join([t.name for t in tools]),
        chart_tool=chart_tool,
    )
    if stop_sequence:
        stop = ["\nObservation"] if stop_sequence is True else stop_sequence
        llm_with_stop = llm.bind(stop=stop)
    else:
        llm_with_stop = llm

    agent = (
        RunnablePassthrough.assign(
            agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
        )
        | prompt
        | llm_with_stop
        | ChartAgentOutputParser()
    )
    return agent
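
# A minimal sketch of a prompt that satisfies the variable check in
# create_chart_agent (the wording and the "plot_chart" name are
# illustrative assumptions, not part of this module):
#
#   prompt = ChatPromptTemplate.from_messages([
#       ("system", "You can use these tools:\n{tools}\n"
#                  "Tool names: {tool_names}\n"
#                  "Prefer the {chart_tool} tool when a chart is requested."),
#       ("human", "{input}\n{agent_scratchpad}"),
#   ])
#   agent = create_chart_agent(llm, tools, prompt, chart_tool="plot_chart")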

class Agent:
    def __init__(
        self,
        llm: BaseLanguageModel,
        tools: List[BaseTool],
        prompt: Optional[ChatPromptTemplate] = None,
        verbose: bool = False,
        **kwargs,
    ):
        if prompt is None:
            raise ValueError("A ChatPromptTemplate is required")

        self.llm = llm
        self.tools = tools
        self.prompt = prompt

        agent = create_chart_agent(llm, tools, prompt, **kwargs)
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose)

    def exec(self, prompt_args: Optional[dict] = None):
        return self.agent_executor.invoke(input=prompt_args or {})
        
    def stream(self, prompt_args: Optional[dict] = None):
        yield from self.agent_executor.stream(prompt_args or {})
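

if __name__ == "__main__":
    # A minimal smoke test of the pieces above. Everything in this block is
    # an illustrative assumption, not part of the module's API: FakeListLLM
    # replays canned JSON strings in order, and plot_chart only pretends to
    # draw anything.
    from langchain_core.language_models import FakeListLLM
    from langchain_core.tools import tool

    @tool
    def plot_chart(data: str) -> str:
        """Pretend to render a chart from comma-separated values."""
        return f"chart rendered for {data}"

    fake_llm = FakeListLLM(
        responses=[
            '{"action": "plot_chart", "action_input": {"data": "1,2,3"}}',
            '{"action": "Final Answer", "action_input": "done", "action_cache": ""}',
        ]
    )
    demo_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Tools:\n{tools}\nNames: {tool_names}\nPrefer {chart_tool}."),
            ("human", "{input}\n{agent_scratchpad}"),
        ]
    )
    bot = Agent(fake_llm, [plot_chart], prompt=demo_prompt, chart_tool="plot_chart")
    print(bot.exec({"input": "plot 1,2,3"}))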