from langchain_core.prompts import PromptTemplate
from src.config.prompts import PROMPT_LOCATION_EXTEND

class LocationExt:
    """Extend a location query with an LLM, using prior Q/A turns as context.

    A prompt built from ``PROMPT_LOCATION_EXTEND`` is piped into the supplied
    ``llm``. The template is filled with two variables: ``histories`` (the
    formatted conversation so far) and ``query`` (the current question).
    """

    # Class-level default, overwritten per instance in ``__init__``; kept so
    # existing code that reads ``LocationExt.llm`` keeps working.
    llm = None

    def __init__(self, llm) -> None:
        """Build the ``prompt | llm`` chain.

        Args:
            llm: The model to run the prompt through. Presumably a LangChain
                runnable / chat model, since it is composed with the prompt
                via ``|``. Confirm against callers.
        """
        self.llm = llm
        prompt = PromptTemplate.from_template(PROMPT_LOCATION_EXTEND)
        self.query_extend = prompt | llm

    def _invoke(self, question: str, history: str):
        """Run the chain with an already-formatted history string.

        Single place that knows the template variable names, so both public
        entry points stay consistent.
        """
        return self.query_extend.invoke(input={"histories": history, "query": question})

    def extend_query(self, question: str, messages=None):
        """Extend ``question`` using a list of previous Q/A pairs as history.

        Args:
            question: The current user query.
            messages: Optional sequence of ``(question, answer)`` pairs, e.g.
                ``[("Q1", "A1"), ("Q2", "A2")]``. ``None`` or empty means no
                history.

        Returns:
            Whatever the ``prompt | llm`` chain returns from ``invoke``. The
            exact type depends on the ``llm`` passed to ``__init__``.
        """
        # Each turn is rendered as "Q: ...\nA: ...\n". Using join avoids
        # repeated string concatenation inside a loop.
        history = "".join(f"Q: {msg[0]}\nA: {msg[1]}\n" for msg in messages or ())
        return self._invoke(question, history)

    def extend_query_str(self, question: str, history: str):
        """Extend ``question`` using a pre-formatted history string.

        Args:
            question: The current user query.
            history: History text passed to the template as-is. No
                formatting is applied here.

        Returns:
            Whatever the ``prompt | llm`` chain returns from ``invoke``.
        """
        return self._invoke(question, history)