"""Manual test script for the RAG agent pipeline.

Builds a FAISS-backed retrieval tool, wires it into two structured-chat
agents (one that expands questions over administrative divisions, one
that answers via RAG), and runs a sample question end to end.
"""
import json
import sys

# Make the project root importable when running this script from its
# own directory.
sys.path.append('../')

from typing import List, Union

import langchain_core
from langchain.agents import AgentExecutor, create_structured_chat_agent
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain_openai import ChatOpenAI
from qianfan import ChatCompletion

from src.agent.rag_agent import RAGQuery, RAGAnalysisQuery
from src.agent.tool_divisions import AdministrativeDivision
from src.config.consts import (
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER,
    prompt_enhancement_history_template,
    prompt1,
)
from src.config.prompts import (
    PROMPT_AGENT_CHAT_HUMAN,
    PROMPT_AGENT_CHAT_HUMAN_VARS,
    PROMPT_AGENT_EXTEND_SYS,
    PROMPT_AGENT_SYS,
    PROMPT_AGENT_SYS_VARS,
)
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.pgdb.knowledge.txt_doc_table import TxtDoc
from src.server.agent import create_chart_agent
from src.server.get_similarity import QAExt
# LLM backing every chain and agent below: an OpenAI-compatible endpoint
# serving Qwen2-7B. temperature=0 keeps tool-selection output as
# deterministic as possible for this test.
base_llm = ChatOpenAI(
    openai_api_key='xxxxxxxxxxxxx',  # placeholder; supply a real key via env/config
    openai_api_base='http://192.168.10.14:8000/v1',
    model_name='Qwen2-7B',
    verbose=True,
    temperature=0
)

# Alternative backend (Baidu ERNIE via the qianfan SDK), kept for reference.
# SECURITY: never commit real ak/sk credentials in source; load them from
# the environment instead, e.g.:
# base_llm = ChatERNIESerLLM(
#     chat_completion=ChatCompletion(ak=os.environ["QIANFAN_AK"],
#                                    sk=os.environ["QIANFAN_SK"]))
# FAISS vector store over the knowledge base. reset=False reuses the
# on-disk index instead of rebuilding it.
vecstore_faiss = VectorStore_FAISS(
    embedding_model_name=EMBEEDING_MODEL_PATH,
    store_path=FAISS_STORE_PATH,
    index_name=INDEX_NAME,
    info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME,
          "username": VEC_DB_USER, "password": VEC_DB_PASSWORD},
    show_number=SIMILARITY_SHOW_NUMBER,
    reset=False)

# Query-enhancement helper that rewrites questions with the base LLM.
ext = QAExt(base_llm)

# Postgres connection used by the RAG tool to fetch source documents.
k_db = PostgresDB(host=VEC_DB_HOST, database=VEC_DB_DBNAME, user=VEC_DB_USER,
                  password=VEC_DB_PASSWORD, port=VEC_DB_PORT)
k_db.connect()

# Answer-generation chain fed with (history, context, question).
llm_chain = LLMChain(
    llm=base_llm,
    prompt=PromptTemplate(input_variables=["history", "context", "question"],
                          template=prompt1),
    llm_kwargs={"temperature": 0},
)
# The RAG tool combines vector retrieval, query enhancement, the Postgres
# document table and the answer-generation chain.
tools = [
    RAGQuery(
        vecstore_faiss,
        ext,
        PromptTemplate(input_variables=["history", "context", "question"],
                       template=prompt_enhancement_history_template),
        _db=TxtDoc(k_db),
        _llm_chain=llm_chain,
    )
]

# Variables the chat prompt expects (e.g. agent_scratchpad, input,
# tool_names, tools, chart_tool), collected from the configured prompts.
input_variables = []
input_variables.extend(PROMPT_AGENT_CHAT_HUMAN_VARS)
input_variables.extend(PROMPT_AGENT_SYS_VARS)

# Accept any LangChain message type in the optional chat_history slot.
input_types = {
    'chat_history': List[Union[
        langchain_core.messages.ai.AIMessage,
        langchain_core.messages.human.HumanMessage,
        langchain_core.messages.chat.ChatMessage,
        langchain_core.messages.system.SystemMessage,
        langchain_core.messages.function.FunctionMessage,
        langchain_core.messages.tool.ToolMessage,
    ]]
}
# Prompt messages for the RAG agent: system instructions, optional chat
# history, then the human turn carrying {input} and {agent_scratchpad}.
messages = [
    SystemMessagePromptTemplate(
        prompt=PromptTemplate(input_variables=PROMPT_AGENT_SYS_VARS,
                              template=PROMPT_AGENT_SYS)),
    MessagesPlaceholder(variable_name='chat_history', optional=True),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(input_variables=PROMPT_AGENT_CHAT_HUMAN_VARS,
                              template=PROMPT_AGENT_CHAT_HUMAN)),
]

# Same structure, but with the extended system prompt used by the
# administrative-division agent.
administrative_messages = [
    SystemMessagePromptTemplate(
        prompt=PromptTemplate(input_variables=PROMPT_AGENT_SYS_VARS,
                              template=PROMPT_AGENT_EXTEND_SYS)),
    MessagesPlaceholder(variable_name='chat_history', optional=True),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(input_variables=PROMPT_AGENT_CHAT_HUMAN_VARS,
                              template=PROMPT_AGENT_CHAT_HUMAN)),
]

# Chat prompt for the RAG agent.
prompt = ChatPromptTemplate(
    input_variables=input_variables,
    input_types=input_types,
    messages=messages,
)

# Chat prompt for the administrative-division agent; it shares the same
# input variables/types and differs only in the system message.
administrative_prompt = ChatPromptTemplate(
    input_variables=input_variables,
    input_types=input_types,
    messages=administrative_messages,
)

# Toolset for the administrative-division agent.
AdministrativeTools = [AdministrativeDivision()]
# Agent that answers questions through the RAG retrieval tool.
agent = create_chart_agent(base_llm, tools, prompt, chart_tool="chart")
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    handle_parsing_errors=True,
    return_intermediate_steps=True,
)

# Agent that expands a question over administrative divisions (e.g. a
# city into its counties) before the RAG agent answers it.
administrative_agent = create_chart_agent(
    base_llm, AdministrativeTools, administrative_prompt, chart_tool="chart")
administrative_agent_executor = AgentExecutor(
    agent=administrative_agent,
    tools=AdministrativeTools,
    verbose=True,
    handle_parsing_errors=True,
    return_intermediate_steps=True,
)
# Conversation history passed to both agents; empty for this run.
# Each entry, when present, is a [question, answer] pair.
history = []

# Step 1: expand the question into per-county sub-questions via the
# administrative-division agent.
res_a = administrative_agent_executor.invoke(
    {"input": "株洲市各个县的最高气温", "histories": history})
print(res_a)

new_question = res_a['output']
print(new_question)

# Step 2: answer the expanded question with the RAG agent.
res = agent_executor.invoke({"input": new_question, "histories": history})
print(res)

# Collect the reference documents returned by each rag_query tool call.
# strict=False tolerates control characters inside the JSON payload.
docs_json = []
for action, observation in res["intermediate_steps"]:
    if action.tool == "rag_query":
        docs_json.extend(json.loads(observation["参考文档"], strict=False))
print(docs_json)