Commit 45d5cb3a by 文靖昊

问题重写逻辑修改,加入正则匹配不规范的回答

parent be543be5
...@@ -123,12 +123,20 @@ def classify(chat_request: GeoAgentRateRequest): ...@@ -123,12 +123,20 @@ def classify(chat_request: GeoAgentRateRequest):
llm = agent_manager.get_router_llm() llm = agent_manager.get_router_llm()
re_llm = agent_manager.get_re_rewriter_llm() re_llm = agent_manager.get_re_rewriter_llm()
try: try:
# history = re_llm.extend_history(history=chat_request.history) if chat_request.query is None:
# rewrite = re_llm.invoke(chat_request.query, history) return {
answer = llm.invoke(chat_request.query) 'code': 500,
'data': "缺少必要的参数 query"
}
if chat_request.history is None:
history = ""
else:
history = re_llm.extend_history(history=chat_request.history)
rewrite = re_llm.invoke(chat_request.query, history)
answer = llm.invoke(rewrite)
res = { res = {
"datasource": answer.datasource, "datasource": answer.datasource,
"rewrite": chat_request.query "rewrite": rewrite
} }
except Exception as e: except Exception as e:
print(f"分类失败, 错误信息: {str(e)},请重新提问") print(f"分类失败, 错误信息: {str(e)},请重新提问")
......
...@@ -83,6 +83,7 @@ A: Laf 是一个云函数开发平台。 ...@@ -83,6 +83,7 @@ A: Laf 是一个云函数开发平台。
from pydantic import BaseModel,Field from pydantic import BaseModel,Field
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser from langchain_core.output_parsers import PydanticOutputParser
import re
class ReWriterModel(BaseModel): class ReWriterModel(BaseModel):
...@@ -99,10 +100,26 @@ class ReWriteLLM: ...@@ -99,10 +100,26 @@ class ReWriteLLM:
input_variables=["histories","query"], input_variables=["histories","query"],
partial_variables={"format_instructions": parser.get_format_instructions()}, partial_variables={"format_instructions": parser.get_format_instructions()},
) )
self.router = prompt | llm | parser self.router = prompt | llm
self.parser = parser
def invoke(self, question, history): def invoke(self, question, history):
return self.router.invoke({"query": question,"histories":history}) response = self.router.invoke({"query": question, "histories": history})
try:
result = self.parser.invoke(response)
except:
# 定义正则表达式模式
pattern = re.compile(r'重写后的问题:\s*(?:{\s*"rewriter"\s*:\s*"([^"]+)"\s*}|([^,]+))')
# 查找所有匹配
matches = pattern.findall(response.content)
for match in matches:
if match[0]: # 第一种情况
return match[0].strip()
elif match[1]: # 第二种情况
return match[1].strip()
return response.content
return result.rewriter
def extend_history(self, history): def extend_history(self, history):
history_extend = "" history_extend = ""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment