Commit 696bd09e by 文靖昊

增加问题重写逻辑

parent e0059c27
import json
import os
import sys
sys.path.append('../')
......@@ -11,6 +12,7 @@ import uvicorn
from src.server.agent_rate import new_rate_agent, RateAgentV3
from src.server.classify import new_router_llm
from src.server.rewrite import new_re_rewriter_llm
from src.controller.request import GeoAgentRateRequest
from langchain_openai import ChatOpenAI
......@@ -79,6 +81,7 @@ class AgentManager:
self.agent = new_rate_agent(self.llm, verbose=True, tool_base_url=tool_base_url)
self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url)
self.router_llm = new_router_llm(self.llm)
self.re_rewriter_llm = new_re_rewriter_llm(self.llm)
def get_llm(self):
return self.llm
......@@ -92,6 +95,9 @@ class AgentManager:
def get_router_llm(self):
return self.router_llm
def get_re_rewriter_llm(self):
return self.re_rewriter_llm
agent_manager = AgentManager()
@app.post('/api/agent/rate')
......@@ -115,8 +121,14 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
@app.post('/api/classify')
def classify(chat_request: GeoAgentRateRequest):
llm = agent_manager.get_router_llm()
re_llm = agent_manager.get_re_rewriter_llm()
try:
res = llm.invoke(chat_request.query)
history = re_llm.extend_history(history=chat_request.history)
rewrite = re_llm.invoke(chat_request.query, history)
answer = llm.invoke(rewrite.rewriter)
res = {}
res["datasource"] =answer.datasource
res["rewrite"] = rewrite.rewriter
except Exception as e:
print(f"分类失败, 错误信息: {str(e)},请重新提问")
return {
......
PROMPT_QUERY_REWRITE = """作为一个向量检索助手,你的任务是结合历史记录,对“原问题”进行优化,从而提高向量检索的语义丰富度,提高向量检索的精度。生成的问题要求指向对象清晰明确,并与“原问题语言相同”。例如
'''
'''
原问题: 介绍下剧情。
重写后的问题: 介绍下剧情
----------------
历史记录:
'''
Q: 对话背景。
A: 当前对话是关于 Nginx 的介绍和使用等。
'''
原问题: 怎么下载
重写后的问题: Nginx 如何怎么下载?
----------------
历史记录:
'''
Q: 对话背景。
A: 当前对话是关于 Nginx 的介绍和使用等。
Q: 报错 "no connection"
A: 报错"no connection"可能是因为……
'''
原问题: 这是什么原因要怎么解决
重写后的问题: Nginx报错"no connection"如何解决?
----------------
历史记录:
'''
Q: 护产假多少天?
A: 护产假的天数根据员工所在的城市而定。请提供您所在的城市,以便我回答您的问题。
'''
原问题: 沈阳
重写后的问题: 沈阳的护产假多少天?
----------------
历史记录:
'''
Q: 作者是谁?
A: FastGPT 的作者是 labring。
'''
原问题: Tell me about him
重写后的问题: Introduce labring, the author of FastGPT
----------------
历史记录:
'''
Q: 对话背景。
A: 关于 FatGPT 的介绍和使用等问题。
'''
原问题: 你好。
重写后的问题: 你好
----------------
历史记录:
'''
'''
原问题: 北京和上海那个天气好?
重写后的问题: 北京和上海那个天气好
----------------
历史记录:
'''
Q: FastGPT 的优势
A: 1. 开源
2. 简便
3. 扩展性强
'''
原问题: 介绍下第2点。
重写后的问题: 介绍下 FastGPT 简便的优势
----------------
历史记录:
'''
Q: 什么是 FastGPT?
A: FastGPT 是一个 RAG 平台。
Q: 什么是 Laf?
A: Laf 是一个云函数开发平台。
'''
原问题: 它们有什么关系?
重写后的问题: FastGPT和Laf有什么关系?
----------------
历史记录:
'''
{histories}
'''
原问题: {query}
重写后的问题: """
from pydantic import BaseModel,Field
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
class ReWriterModel(BaseModel):
rewriter: str = Field(
description="重写后的问题",
)
class ReWriteLLM:
def __init__(self,llm):
parser = PydanticOutputParser(pydantic_object=ReWriterModel)
prompt = PromptTemplate(
template=PROMPT_QUERY_REWRITE + "\n{format_instructions}\n",
input_variables=["histories","query"],
partial_variables={"format_instructions": parser.get_format_instructions()},
)
self.router = prompt | llm | parser
def invoke(self, question, history):
return self.router.invoke({"query": question,"histories":history})
def extend_history(self, history):
history_extend = ""
for msg in history:
if msg['obj'] == 'Human':
history_extend += f"Q: {msg['value'][0]['text']['content']}\n"
elif msg['obj'] == 'AI':
history_extend += f"A: {msg['value'][0]['text']['content']}\n"
return history_extend
def new_re_rewriter_llm(llm):
re_writer = ReWriteLLM(llm)
return re_writer
......@@ -2,7 +2,7 @@ from typing import List, Dict
from datetime import datetime
from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
from langchain.tools.render import ToolsRenderer, render_text_description_and_args
from langchain.tools.render import render_text_description_and_args
from langchain_core.output_parsers import JsonOutputParser as JSONOutputParser
from langchain_core.tools import BaseTool
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment