Commit 2c152b87 by 文靖昊

Reworked the request pipeline into an agent-less flow and added an administrative-region extraction tool.

parent 480f1773
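
In short: the old flow pushed every request through two AgentExecutor round-trips (one agent rewrote the question with administrative-region context, a second answered it through a RAG tool); the new flow calls a plain RagQuery pipeline directly. A minimal before/after sketch, using only names that appear in this diff (history_str stands for the "Q: ...\nA: ..." transcript both endpoints build):

    # Before: two agent invocations per request
    new_question = administrative_agent_executor.invoke({"input": question, "histories": history})["output"]
    answer = agent_executor.invoke({"input": new_question, "histories": history})["output"]

    # After: one direct call into the RAG pipeline
    res = rag_query.query(question=question, history=history_str)
    answer, docs = res["answer"], res["docs"]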
@@ -9,7 +9,7 @@ divisions = [
     },
     {
         "name": "海东市",
-        "counties": ["乐都区", "平安区", "民和回族土族自治县", "互助土族自治县", "化隆回族自治县", "循化撒拉族自治县"]
+        "counties": ["乐都区", "平安区", "民和回族土族自治县", "互助土族自治县", "化隆县", "循化撒拉族自治县"]
     },
     {
         "name": "海北藏族自治州",
...
@@ -306,4 +306,41 @@ A: 你可以在 nginx 官网上下载 nginx。
 {histories}
 '''
 新问题:
+"""
+
+# Combine the historical Q&A dialogue and extract the locations referenced by the question
+PROMPT_LOCATION_EXTEND = """
+作为一个行政区提取助手,你的任务是结合历史对话,提取出问题所涉及到的行政区。例如:
+历史记录:
+'''
+'''
+原问题: 西宁市各区县谁的年平均降雨量大。
+提取到的行政区: [西宁市]
+----------------
+历史记录:
+'''
+Q: 对话背景。
+A: 当前对话是关于化隆县的介绍。
+'''
+原问题: 其年最高温是多少
+提取到的行政区: [化隆县]
+----------------
+历史记录:
+'''
+'''
+原问题: 请问 nginx 如何下载?
+提取到的行政区: []
+----------------
+历史记录:
+'''
+'''
+原问题: 大通县和湟源县谁的降雨量高?
+提取到的行政区:[大通县,湟源县]
+----------------
+'''
+{histories}
+'''
+原问题: {query}
+提取到的行政区:
+"""
 """
\ No newline at end of file
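
This few-shot prompt is filled in by LocationExt (the new src/server/extend.py, below) and its reply is parsed back out in RagQuery.query. A minimal sketch of the parsing half, with an illustrative reply string rather than real model output:

    import re

    reply = "提取到的行政区: [大通县,湟源县]"  # illustrative, not captured output
    match = re.search(r'\[([^\]]+)\]', reply)
    # The few-shot examples mix ASCII and full-width commas, so accept both.
    regions = re.split(r'\s*[,,]\s*', match.group(1)) if match else []
    print(regions)  # ['大通县', '湟源县']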
@@ -8,22 +8,15 @@ from src.pgdb.chat.c_db import UPostgresDB
 import uvicorn
 import json
 from src.pgdb.chat.crud import CRUD
-from langchain.agents import AgentExecutor
-from langchain_core.prompts.chat import ChatPromptTemplate,HumanMessagePromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder
-from langchain.chains import LLMChain
-import langchain_core
-from typing import List,Union
 from src.pgdb.knowledge.similarity import VectorStore_FAISS
-from src.server.get_similarity import QAExt
-from src.server.agent import create_chart_agent
 from src.pgdb.knowledge.k_db import PostgresDB
 from src.pgdb.knowledge.txt_doc_table import TxtDoc
-from src.agent.tool_divisions import AdministrativeDivision
-from src.agent.rag_agent import RAGQuery
-from langchain_core.prompts import PromptTemplate
 from langchain_openai import ChatOpenAI
 from langchain_core.documents import Document
-import re
+from src.server.rag_query import RagQuery
 from src.controller.request import (
     PhoneLoginRequest,
     ChatRequest,
@@ -44,10 +37,7 @@ from src.config.consts import (
     VEC_DB_USER,
     VEC_DB_DBNAME,
     SIMILARITY_SHOW_NUMBER,
-    prompt_enhancement_history_template,
-    prompt1
 )
-from src.config.prompts import PROMPT_AGENT_SYS_VARS,PROMPT_AGENT_SYS,PROMPT_AGENT_CHAT_HUMAN,PROMPT_AGENT_CHAT_HUMAN_VARS,PROMPT_AGENT_EXTEND_SYS
 app = FastAPI()
@@ -82,54 +72,10 @@ base_llm = ChatOpenAI(
     verbose=True
 )
-ext = QAExt(base_llm)
-llm_chain = LLMChain(llm=base_llm, prompt=PromptTemplate(input_variables=["history","context", "question"], template=prompt1), llm_kwargs= {"temperature": 0})
-tools = [RAGQuery(vecstore_faiss,ext,PromptTemplate(input_variables=["history","context", "question"], template=prompt_enhancement_history_template),_db=TxtDoc(k_db),_llm_chain=llm_chain)]
-# input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools','chart_tool']
-input_variables=[]
-input_variables.extend(PROMPT_AGENT_CHAT_HUMAN_VARS)
-input_variables.extend(PROMPT_AGENT_SYS_VARS)
-input_types={'chat_history': List[Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]}
-messages=[
-    # SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['tool_names', 'tools'], template='Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n{tools}\n\nUse a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\nValid "action" values: "Final Answer" or {tool_names}\n\nProvide only ONE action per $JSON_BLOB, as shown:\n\n```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\nFollow this format:\n\nQuestion: input question to answer\nThought: consider previous and subsequent steps\nAction:\n```\n$JSON_BLOB\n```\nObservation: action result\n... (repeat Thought/Action/Observation N times)\nThought: I know what to respond\nAction:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}\n\nBegin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation')),
-    SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=PROMPT_AGENT_SYS_VARS, template=PROMPT_AGENT_SYS)),
-    MessagesPlaceholder(variable_name='chat_history', optional=True),
-    # HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['agent_scratchpad', 'input'], template='{input}\n\n{agent_scratchpad}\n (reminder to respond in a JSON blob no matter what)'))
-    HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=PROMPT_AGENT_CHAT_HUMAN_VARS, template=PROMPT_AGENT_CHAT_HUMAN))
-]
-prompt = ChatPromptTemplate(
-    input_variables=input_variables,
-    input_types=input_types,
-    # metadata=metadata,
-    messages=messages
-)
-# agent = create_structured_chat_agent(llm=base_llm, tools=tools, prompt=prompt)
-agent = create_chart_agent(base_llm, tools, prompt, chart_tool="chart")
-agent_executor = AgentExecutor(agent=agent, tools=tools,verbose=True,handle_parsing_errors=True,return_intermediate_steps=True)
-administrative_messages=[
-    # SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['tool_names', 'tools'], template='Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n{tools}\n\nUse a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\nValid "action" values: "Final Answer" or {tool_names}\n\nProvide only ONE action per $JSON_BLOB, as shown:\n\n```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\nFollow this format:\n\nQuestion: input question to answer\nThought: consider previous and subsequent steps\nAction:\n```\n$JSON_BLOB\n```\nObservation: action result\n... (repeat Thought/Action/Observation N times)\nThought: I know what to respond\nAction:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}\n\nBegin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation')),
-    SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=PROMPT_AGENT_SYS_VARS, template=PROMPT_AGENT_EXTEND_SYS)),
-    MessagesPlaceholder(variable_name='chat_history', optional=True),
-    # HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['agent_scratchpad', 'input'], template='{input}\n\n{agent_scratchpad}\n (reminder to respond in a JSON blob no matter what)'))
-    HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=PROMPT_AGENT_CHAT_HUMAN_VARS, template=PROMPT_AGENT_CHAT_HUMAN))
-]
-administrative_prompt = ChatPromptTemplate(
-    input_variables=input_variables,
-    input_types=input_types,
-    # metadata=metadata,
-    messages=administrative_messages
-)
-AdministrativeTools =[AdministrativeDivision()]
-administrative_agent = create_chart_agent(base_llm, AdministrativeTools, administrative_prompt, chart_tool="chart")
-administrative_agent_executor = AgentExecutor(agent=administrative_agent, tools=AdministrativeTools,verbose=True,handle_parsing_errors=True,return_intermediate_steps=True)
-# my_chat = QA(PromptTemplate(input_variables=["history","context", "question"], template=prompt_enhancement_history_template), base_llm,
-#              {"temperature": 0.9}, ['history','context', 'question'], _db=c_db, _faiss_db=vecstore_faiss,rerank=True)
+rag_query = RagQuery(base_llm=base_llm, _faiss_db=vecstore_faiss, _db=TxtDoc(k_db))
 @app.post('/api/login')
 def login(phone_request: PhoneLoginRequest):
@@ -155,7 +101,7 @@ def login(phone_request: PhoneLoginRequest):
 @app.get('/api/sessions/chat/')
-def get_sessions(timestamp: int = Query(None, alias="_"),token: str = Header(None)):
+def get_sessions(token: str = Header(None)):
     if not token:
         return {
             'code': 404,
@@ -224,28 +170,14 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
     history = []
     if session_id !="":
         history = crud.get_last_history(str(session_id))
-    # answer = my_chat.chat(question)
-    # result = ext.extend_query(question, history)
-    # matches = re.findall(r'"([^"]+)"', result.content)
-    # print(matches)
-    # if len(matches)>3:
-    #     matches = matches[:3]
-    # print(matches)
-    # prompt = ""
-    # for h in history:
-    #     prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
-    res_a = administrative_agent_executor.invoke({"input": question, "histories": history})
-    new_question = res_a['output']
-    res = agent_executor.invoke({"input": new_question, "histories": history})
-    answer = res["output"]
-    # answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True)
-    docs_json = []
-    for step in res["intermediate_steps"]:
-        if "rag_query" == step[0].tool:
-            j = json.loads(step[1]["参考文档"], strict=False)
-            docs_json.extend(j)
+    prompt = ""
+    for h in history:
+        prompt += "Q: {}\nA:{}\n".format(h[0], h[1])
+    res = rag_query.query(question=question, history=prompt)
+    answer = res["answer"]
+    docs = res["docs"]
+    docs_json = json.loads(docs, strict=False)
     print(len(docs_json))
     doc_hash = []
     for d in docs_json:
@@ -287,26 +219,14 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
     crud = CRUD(_db=c_db)
     last_turn_id = crud.get_last_turn_num(str(session_id))
     history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id)
-    # result = ext.extend_query(question, history)
-    # matches = re.findall(r'"([^"]+)"', result.content)
-    # if len(matches)>3:
-    #     matches = matches[:3]
-    # print(matches)
-    # prompt = ""
-    # for h in history:
-    #     prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
-    # answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches, history=prompt, with_similarity=True)
-    # docs_json = []
-    res_a = administrative_agent_executor.invoke({"input": question, "histories": history})
-    new_question = res_a['output']
-    res = agent_executor.invoke({"input": new_question, "histories": history})
-    answer = res["output"]
-    docs_json = []
-    for step in res["intermediate_steps"]:
-        if "rag_query" == step[0].tool:
-            j = json.loads(step[1]["参考文档"], strict=False)
-            docs_json.extend(j)
+    prompt = ""
+    for h in history:
+        prompt += "Q: {}\nA:{}\n".format(h[0], h[1])
+    res = rag_query.query(question=question, history=prompt)
+    answer = res["answer"]
+    docs = res["docs"]
+    docs_json = json.loads(docs, strict=False)
     doc_hash = []
     for d in docs_json:
...
from langchain_core.prompts import PromptTemplate
from src.config.prompts import PROMPT_LOCATION_EXTEND


class LocationExt:
    llm = None

    def __init__(self, llm) -> None:
        self.llm = llm
        prompt = PromptTemplate.from_template(PROMPT_LOCATION_EXTEND)
        # parser = ListOutputParser()
        self.query_extend = prompt | llm

    def extend_query(self, question, messages=None):
        """
        question: str
        messages: list of tuple (str, str), e.g.:
            [
                ("Q1", "A1"),
                ("Q2", "A2"),
                ...
            ]
        """
        if not messages:
            messages = []
        history = ""
        for msg in messages:
            history += f"Q: {msg[0]}\nA: {msg[1]}\n"
        return self.query_extend.invoke(input={"histories": history, "query": question})

    def extend_query_str(self, question, history):
        return self.query_extend.invoke(input={"histories": history, "query": question})
\ No newline at end of file
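
A quick usage sketch for LocationExt, assuming a configured chat model such as the ChatOpenAI instance in the test scripts below; the expected reply shape follows the few-shot examples in PROMPT_LOCATION_EXTEND, so the comment shows an illustrative output, not a captured one:

    ext = LocationExt(base_llm)
    result = ext.extend_query(
        question="其年最高温是多少",
        messages=[("对话背景。", "当前对话是关于化隆县的介绍。")],
    )
    print(result.content)  # expected to end with something like: 提取到的行政区: [化隆县]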
import re
import time
import json
from langchain.chains import LLMChain
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from src.server.get_similarity import GetSimilarityWithExt, QAExt
from src.server.rerank import BgeRerank
from src.server.extend import LocationExt
from src.agent.tool_divisions import complete_administrative_division, divisions
from src.config.consts import (
    RERANK_MODEL_PATH,
    prompt1
)

class RagQuery:
    def __init__(self, base_llm, _faiss_db, _db):
        self.qa_ext = QAExt(base_llm)
        self.location_ext = LocationExt(base_llm)
        self.rerank_model = BgeRerank(RERANK_MODEL_PATH)
        self.faiss_db = _faiss_db
        self.db = _db
        self.llm_chain = LLMChain(llm=base_llm, prompt=PromptTemplate(input_variables=["history", "context", "question"], template=prompt1), llm_kwargs={"temperature": 0})

    def get_similarity_with_ext_origin(self, _ext, _location):
        return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db, _location=_location)

    def query(self, question: str, history: str):
        # Extract the administrative regions the question refers to.
        location_result = self.location_ext.extend_query_str(question=question, history=history)
        index = location_result.content.find("提取到的行政区:")
        if index == -1:
            location_str = location_result.content
        else:
            location_str = location_result.content[index + len("提取到的行政区:"):]
        pattern = r'\[([^\]]+)\]'
        match = re.search(pattern, location_str)
        cities = []
        if match:
            # Split on both ASCII and full-width commas; the few-shot examples mix them.
            cities = re.split(r'\s*[,,]\s*', match.group(1))
        cities_ext = []
        for m in cities:
            city_ext = complete_administrative_division(m, divisions)
            cities_ext.append(city_ext)
        # Expand each hit into county names plus a prefix describing where they sit.
        location = []
        prompt = ""
        for city in cities_ext:
            if city is not None and "县(区)" in city:
                if isinstance(city["县(区)"], str):
                    location.append(city["县(区)"])
                    prompt += city["县(区)"] + "位于" + city["省"] + city["市"] + ","
                if isinstance(city["县(区)"], list):
                    location.extend(city["县(区)"])
                    prompt += city["省"] + city["市"] + "管辖"
                    for x in city["县(区)"]:
                        prompt += x + ","
        new_question = prompt + question
        print(new_question)
        # Pull candidate documents for each county via a LIKE lookup.
        split_docs_list = []
        for loc in location:
            start = time.time()
            answer = self.db.find_like_doc(loc)
            end = time.time()
            print('find_like_doc time: %s Seconds' % (end - start))
            print(len(answer) if answer else 0)
            split_docs = []
            for a in answer if answer else []:
                d = Document(page_content=a[0], metadata=json.loads(a[1]))
                split_docs.append(d)
            print(len(split_docs))
            # if len(split_docs) > 5:
            #     split_docs = split_docs[:5]
            split_docs_list.append(split_docs)
        # Expand the question into multiple query variants, then rerank against them.
        start = time.time()
        result = self.qa_ext.extend_query_with_str(new_question, history)
        end = time.time()
        print('extend_query_with_str time: %s Seconds' % (end - start))
        print(result)
        matches = re.findall(r'"([^"]+)"', result.content)
        print(matches)
        similarity = self.get_similarity_with_ext_origin(matches, _location=location)
        cur_similarity = similarity.get_rerank_with_doc(self.rerank_model, split_docs_list)
        cur_answer = self.llm_chain.run(context=cur_similarity, question=new_question, history=history)
        return {"answer": cur_answer, "docs": cur_similarity}
\ No newline at end of file
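
query() depends on complete_administrative_division from src.agent.tool_divisions, whose body is outside this diff. The call sites imply it maps a region name to a dict with keys 省, 市 and 县(区) (a string for a county hit, a list for a city hit). A minimal sketch consistent with those call sites and the divisions data above; the hardcoded 青海省 and the lookup logic are assumptions, not the committed implementation:

    def complete_administrative_division(name, divisions):
        # Sketch only: resolve a name against the divisions list (assumed Qinghai-only).
        for city in divisions:
            if city["name"] == name:
                # City-level hit: return every county the city administers.
                return {"省": "青海省", "市": city["name"], "县(区)": city["counties"]}
            if name in city["counties"]:
                # County-level hit: return the single county with its parents.
                return {"省": "青海省", "市": city["name"], "县(区)": name}
        return None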
from src.server.extend import LocationExt
from langchain_openai import ChatOpenAI
import re
from src.agent.tool_divisions import complete_administrative_division, divisions

base_llm = ChatOpenAI(
    openai_api_key='xxxxxxxxxxxxx',
    openai_api_base='http://192.168.10.14:8000/v1',
    model_name='Qwen2-7B',
    verbose=True
)
ext = LocationExt(base_llm)
message = []
question = "大通县和化隆县最大降雨量是多少"
result = ext.extend_query(question=question, messages=message)
print(result.content)

index = result.content.find("提取到的行政区:")
if index == -1:
    location = result.content
else:
    # Keep only the substring after the marker
    location = result.content[index + len("提取到的行政区:"):]
print(location)

pattern = r'\[([^\]]+)\]'
match = re.search(pattern, location)
cities = []
if match:
    # The match is everything inside the brackets, commas and spaces included;
    # split it into individual region names (ASCII or full-width commas).
    cities = re.split(r'\s*[,,]\s*', match.group(1))
print(cities)

location_exts = []
for m in cities:
    print(m)
    location_ext = complete_administrative_division(m, divisions)
    print(location_ext)
    location_exts.append(location_ext)

prompt1 = ""
locations = []
for l in location_exts:
    if l is not None and "县(区)" in l:
        if isinstance(l["县(区)"], str):
            locations.append(l["县(区)"])
            prompt1 += l["县(区)"] + "位于" + l["省"] + l["市"] + ","
        if isinstance(l["县(区)"], list):
            locations.extend(l["县(区)"])
            prompt1 += l["省"] + l["市"] + "拥有"
            for x in l["县(区)"]:
                prompt1 += x + ","
print(locations)

question = prompt1 + question
print(question)
\ No newline at end of file
from langchain_openai import ChatOpenAI
from src.server.rag_query import RagQuery
from src.pgdb.knowledge.txt_doc_table import TxtDoc
from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.similarity import VectorStore_FAISS
import json
from src.config.consts import (
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER,
)

base_llm = ChatOpenAI(
    openai_api_key='xxxxxxxxxxxxx',
    openai_api_base='http://192.168.10.14:8000/v1',
    model_name='Qwen2-7B',
    verbose=True
)
vecstore_faiss = VectorStore_FAISS(
    embedding_model_name=EMBEEDING_MODEL_PATH,
    store_path=FAISS_STORE_PATH,
    index_name=INDEX_NAME,
    info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
          "password": VEC_DB_PASSWORD},
    show_number=SIMILARITY_SHOW_NUMBER,
    reset=False)
k_db = PostgresDB(host=VEC_DB_HOST, database=VEC_DB_DBNAME, user=VEC_DB_USER, password=VEC_DB_PASSWORD, port=VEC_DB_PORT)
k_db.connect()
db = TxtDoc(k_db)

rag_query = RagQuery(base_llm=base_llm, _faiss_db=vecstore_faiss, _db=db)
question = "大通县和化隆县最大降雨量是多少"
history = ""
result = rag_query.query(question=question, history=history)
print(result)

# result["docs"] is a JSON string; parse it to inspect the reranked documents
j = json.loads(result["docs"], strict=False)
print(type(j))
print(j)
\ No newline at end of file