extend_classify.py 2.21 KB
Newer Older
文靖昊 committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
from pydantic import BaseModel,Field
from typing import Literal
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
import json
from ..utils.logger import get_logger


system1 = """
您是问题分析专家,根据用户问题,判断是否能从问题背景中获取到问题的答案,如果可以回答,请选择yes,不能则选择no,
以下是示例:
问题: 问题背景:无,问题:2024年10月15,北京市的设备在线率是多少?
是否能直接回答: no
----------------
问题: 问题背景:无,问题:自建CORS组网基准站观测墩的建造要求是什么?
是否需要调用工具: no
----------------
问题: 问题背景:陕西的设备在线率是90%,问题:陕西的设备在线率是多少,
分类: yes
----------------
问题: 问题背景:山西的设备在线率是90%,问题:陕西的设备在线率是多少,
分类: no
----------------
"""


class ExtendClassifyModel(BaseModel):
    """Route a user query to the most relevant datasource."""
    classify: Literal["yes", "no"] = Field(
        description="给定用户的问题,是否能从问题背景中获取到问题的答案。",
    )

class ExtendClassifyLLM:
    def __init__(self,base_llm):
        parser = PydanticOutputParser(pydantic_object=ExtendClassifyModel)
        prompt = PromptTemplate(
            template= system1 + "\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )
        self.llm = prompt | base_llm | parser
        self.logger = get_logger(self.__class__.__name__)

    def invoke(self, question):
        try:
            result = self.llm.invoke({"query": question})
            self.logger.info(f"扩展分类结果: {result}")
            return result
        except Exception as e:
            self.logger.error(f"扩展分类结果失败: {e}", exc_info=True)
            return {'classify': 'no'}


def new_extend_classify_llm(llm):
    router = ExtendClassifyLLM(llm)
    return router