Commit a9027029 by 文靖昊

web流程替换为agent

parent dd12b019
......@@ -5,14 +5,21 @@ from fastapi import FastAPI, Header,Query
from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime,timedelta
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.txt_doc_table import TxtDoc
import uvicorn
import json
from src.pgdb.chat.crud import CRUD
from langchain.agents import AgentExecutor
from langchain_core.prompts.chat import ChatPromptTemplate,HumanMessagePromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder
from langchain.chains import LLMChain
import langchain_core
from typing import List,Union
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.get_similarity import QAExt
from src.server.qa import QA
from src.server.agent import create_chart_agent
from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.txt_doc_table import TxtDoc
from src.agent.tool_divisions import AdministrativeDivision
from src.agent.rag_agent import RAGQuery
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
......@@ -37,8 +44,12 @@ from src.config.consts import (
VEC_DB_USER,
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER,
prompt_enhancement_history_template
prompt_enhancement_history_template,
prompt1
)
from src.config.prompts import PROMPT_AGENT_SYS_VARS,PROMPT_AGENT_SYS,PROMPT_AGENT_CHAT_HUMAN,PROMPT_AGENT_CHAT_HUMAN_VARS
app = FastAPI()
app.add_middleware(
CORSMiddleware,
......@@ -73,8 +84,37 @@ base_llm = ChatOpenAI(
# Query-extension helper used both by the legacy QA flow and by the RAG tool.
ext = QAExt(base_llm)

# Legacy retrieval-QA chain; kept because other endpoints still call it.
my_chat = QA(PromptTemplate(input_variables=["history","context", "question"], template=prompt_enhancement_history_template), base_llm,
             {"temperature": 0.9}, ['history','context', 'question'], _db=c_db, _faiss_db=vecstore_faiss,rerank=True)

# Deterministic (temperature 0) chain the RAG tool uses to synthesize answers.
llm_chain = LLMChain(llm=base_llm, prompt=PromptTemplate(input_variables=["history","context", "question"], template=prompt1), llm_kwargs={"temperature": 0})

# Build the RAG tool once and reuse it in `tools`.  The original code built a
# second, identical RAGQuery inline in the tools list, so `tool_rag` was dead
# and the tool (with its TxtDoc handle) was instantiated twice.
tool_rag = RAGQuery(vecstore_faiss, ext,
                    PromptTemplate(input_variables=["history", "context", "question"], template=prompt_enhancement_history_template),
                    _db=TxtDoc(k_db), _llm_chain=llm_chain)
tools = [AdministrativeDivision(), tool_rag]

# Input variables for the structured-chat agent prompt: the union of the
# system-message and human-message template variables.
input_variables = []
input_variables.extend(PROMPT_AGENT_CHAT_HUMAN_VARS)
input_variables.extend(PROMPT_AGENT_SYS_VARS)

# `chat_history` may carry any LangChain message type when the placeholder is filled.
input_types = {'chat_history': List[Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]}

messages = [
    # System prompt describing the JSON-blob tool-calling protocol (see src.config.prompts).
    SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=PROMPT_AGENT_SYS_VARS, template=PROMPT_AGENT_SYS)),
    # Optional conversation history slot.
    MessagesPlaceholder(variable_name='chat_history', optional=True),
    HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=PROMPT_AGENT_CHAT_HUMAN_VARS, template=PROMPT_AGENT_CHAT_HUMAN)),
]

prompt = ChatPromptTemplate(
    input_variables=input_variables,
    input_types=input_types,
    messages=messages,
)

# Agent with chart support; intermediate steps are returned so callers can
# extract the reference documents each tool observation carries.
agent = create_chart_agent(base_llm, tools, prompt, chart_tool="chart")
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True, return_intermediate_steps=True)
@app.post('/api/login')
def login(phone_request: PhoneLoginRequest):
......@@ -170,26 +210,29 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
if session_id !="":
history = crud.get_last_history(str(session_id))
# answer = my_chat.chat(question)
result = ext.extend_query(question, history)
matches = re.findall(r'"([^"]+)"', result.content)
print(matches)
if len(matches)>3:
matches = matches[:3]
print(matches)
prompt = ""
for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True)
# result = ext.extend_query(question, history)
# matches = re.findall(r'"([^"]+)"', result.content)
# print(matches)
# if len(matches)>3:
# matches = matches[:3]
# print(matches)
# prompt = ""
# for h in history:
# prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
res = agent_executor.invoke({"input": question, "histories": history})
answer = res["output"]
# answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True)
docs_json = []
for step in res["intermediate_steps"]:
j = json.loads(step[1]["参考文档"], strict=False)
docs_json.extend(j)
print(len(docs_json))
doc_hash = []
for d in docs:
j ={}
j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"]
j["page_number"] = 0
if "hash" in d.metadata:
doc_hash.append(d.metadata["hash"])
docs_json.append(j)
for d in docs_json:
if "hash" in d:
doc_hash.append(d["hash"])
if len(doc_hash)>0:
hash_str = ",".join(doc_hash)
else:
......@@ -226,26 +269,28 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
crud = CRUD(_db=c_db)
last_turn_id = crud.get_last_turn_num(str(session_id))
history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id)
result = ext.extend_query(question, history)
matches = re.findall(r'"([^"]+)"', result.content)
if len(matches)>3:
matches = matches[:3]
print(matches)
prompt = ""
for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches, history=prompt, with_similarity=True)
# result = ext.extend_query(question, history)
# matches = re.findall(r'"([^"]+)"', result.content)
# if len(matches)>3:
# matches = matches[:3]
# print(matches)
# prompt = ""
# for h in history:
# prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
# answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches, history=prompt, with_similarity=True)
# docs_json = []
res = agent_executor.invoke({"input": question, "histories": history})
answer = res["output"]
docs_json = []
for step in res["intermediate_steps"]:
j = json.loads(step[1]["参考文档"], strict=False)
docs_json.extend(j)
doc_hash = []
for d in docs:
j = {}
j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"]
j["page_number"] = 0
docs_json.append(j)
if "hash" in d.metadata:
doc_hash.append(d.metadata["hash"])
for d in docs_json:
if "hash" in d:
doc_hash.append(d["hash"])
if len(doc_hash)>0:
hash_str = ",".join(doc_hash)
else:
......
......@@ -61,13 +61,20 @@ class GetSimilarityWithExt:
page_number = "0"
if "page_number" in d1.metadata:
page_number = d1.metadata["page_number"]
result += ("{\"page_content\": \"" + d1.page_content + "\",\"filename\":\"" + d1.metadata["filename"] + "\",\"page_number\":\"" + page_number + "\"}")
hash_str = ""
if "hash" in d1.metadata:
hash_str = d1.metadata["hash"]
result += ("{\"page_content\": \"" + d1.page_content + "\",\"filename\":\"" + d1.metadata["filename"] + "\",\"hash\":\"" + hash_str + "\",\"page_number\":\"" + page_number + "\"}")
docs = docs[1:]
for doc in docs:
page_number = "0"
if "page_number" in doc.metadata:
page_number = doc.metadata["page_number"]
result += ("{\"page_content\": \"" + doc.page_content + "\",\"filename\":\"" + doc.metadata["filename"] + "\",\"page_number\":\"" + page_number + "\"}")
hash_str = ""
if "hash" in doc.metadata:
hash_str = doc.metadata["hash"]
result += (",{\"page_content\": \"" + doc.page_content + "\",\"filename\":\"" + doc.metadata["filename"] + "\",\"hash\":\"" + hash_str + "\",\"page_number\":\"" + page_number + "\"}")
result += "]"
return result
def get_rerank_with_doc(self, reranker: BgeRerank,split_doc:list, top_k=5):
......
......@@ -104,5 +104,12 @@ print("====== result: ======")
print(res)
print(type(res))
print(res["output"])
# Gather the reference documents ("参考文档") that each intermediate tool
# observation emitted as a JSON array, flattening them into one list.
docs_json = []
for step in res["intermediate_steps"]:
    observation = step[1]["参考文档"]
    print(type(observation))
    print(observation)
    docs_json.extend(json.loads(observation, strict=False))
print(docs_json)
print(len(docs_json))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment