Commit 9c21eae7 by 文靖昊

Add hybrid ranking; add an administrative-division name parameter to the RAG tool

parent cc83b288
@@ -16,6 +16,8 @@ from src.pgdb.knowledge.txt_doc_table import TxtDoc
 from langchain_core.documents import Document
 import json
 from src.agent.tool_divisions import AdministrativeDivision
+from src.llm.ernie_with_sdk import ChatERNIESerLLM
+from qianfan import ChatCompletion
 from src.config.consts import (
     RERANK_MODEL_PATH,
     CHAT_DB_USER,
@@ -1557,6 +1559,7 @@ def jieba_split(text: str) -> str:
 class IssuanceArgs(BaseModel):
     question: str = Field(description="对话问题")
     history: str = Field(description="历史对话记录")
+    location: list = Field(description="行政区名称")
 class RAGQuery(BaseTool):
@@ -1582,22 +1585,24 @@ class RAGQuery(BaseTool):
     def get_similarity_with_ext_origin(self, _ext):
         return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db)
-    def _run(self, question: str, history: str) -> str:
-        split_str = jieba_split(question)
-        split_list = []
-        for l in split_str:
-            split_list.append(l)
-        answer = self.db.find_like_doc(split_list)
-        print(answer)
+    def _run(self, question: str, history: str, location: list) -> str:
+        print(location)
+        # split_str = jieba_split(question)
+        # split_list = []
+        # for l in split_str:
+        #     split_list.append(l)
+        answer = self.db.find_like_doc(location)
         split_docs = []
         for a in answer:
             d = Document(page_content=a[0], metadata=json.loads(a[1]))
             split_docs.append(d)
         print(split_docs)
+        print(len(split_docs))
+        if len(split_docs) > 100:
+            split_docs = split_docs[:100]
         result = self.rerank.extend_query_with_str(question, history)
-        matches = re.findall(r'"([^"]+)"', result.content)
-        if len(matches) > 3:
-            matches = matches[:3]
+        print(result)
+        matches = re.findall(r'"([^"]+)"', result)
         print(matches)
         similarity = self.get_similarity_with_ext_origin(matches)
         # cur_similarity = similarity.get_rerank(self.rerank_model)
@@ -1606,13 +1611,15 @@ class RAGQuery(BaseTool):
         return cur_question
-base_llm = ChatOpenAI(
-    openai_api_key='xxxxxxxxxxxxx',
-    openai_api_base='http://192.168.10.14:8000/v1',
-    model_name='Qwen2-7B',
-    verbose=True,
-    temperature=0
-)
+# base_llm = ChatOpenAI(
+#     openai_api_key='xxxxxxxxxxxxx',
+#     openai_api_base='http://192.168.10.14:8000/v1',
+#     model_name='Qwen2-7B',
+#     verbose=True,
+#     temperature=0
+# )
+base_llm = ChatERNIESerLLM(
+    chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
 vecstore_faiss = VectorStore_FAISS(
     embedding_model_name=EMBEEDING_MODEL_PATH,
@@ -1651,15 +1658,18 @@ agent = create_structured_chat_agent(llm=base_llm, tools=tools, prompt=prompt)
 agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
 history = []
 h1 = []
-h1.append("大通县年降雨量")
-h1.append("大通县年雨量平均为30ml")
+h1.append("攸县年降雨量")
+h1.append("攸县年雨量平均为30ml")
 history.append(h1)
+h1 = []
+h1.append("长沙县年降雨量")
+h1.append("长沙县年雨量平均为50ml")
+history.append(h1)
 prompt = ""
 for h in history:
     prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
 print(prompt)
-res = agent_executor.invoke({"input": "以下历史对话记录: " + prompt + "以下是问题:" + "西宁市年平均降雨量"})
+res = agent_executor.invoke({"input": "以下历史对话记录: " + prompt + "以下是问题:" + "攸县、长沙县、大通县和化隆县谁的年平均降雨量大"})
 print("====== result: ======")
 print(res)
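
The new location parameter reaches _run through the tool's Pydantic args schema, so the structured-chat agent can extract division names from the dialogue and pass them in as a list. A minimal sketch of that wiring, assuming langchain's BaseTool; the class and field wording below is illustrative, not the repo's exact code:

from typing import Type
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

class IssuanceArgsSketch(BaseModel):
    question: str = Field(description="the user question")
    history: str = Field(description="prior dialogue turns")
    location: list = Field(description="administrative-division names")

class RAGQuerySketch(BaseTool):
    name: str = "rag_query"
    description: str = "Answer region-specific questions from local documents"
    args_schema: Type[BaseModel] = IssuanceArgsSketch

    def _run(self, question: str, history: str, location: list) -> str:
        # The agent's JSON tool call fills `location`, which then drives the
        # LIKE-based keyword retrieval shown in the diff above.
        return f"would search {location!r} for: {question}"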
@@ -35,7 +35,7 @@ LLM_SERVER_URL = '192.168.10.102:8002'
 # =============================
 # FAISS相似性查找配置
 # =============================
-SIMILARITY_SHOW_NUMBER = 10
+SIMILARITY_SHOW_NUMBER = 30
 SIMILARITY_THRESHOLD = 0.8
 # =============================
@@ -74,7 +74,7 @@ class TxtDoc:
         print(item)
         query = "select text,matadate FROM txt_doc WHERE text like '%"+i0+"%' "
         for i in item:
-            query += "and text like '%"+i+"%' "
+            query += "or text like '%"+i+"%' "
         print(query)
         self.db.execute(query)
         answer = self.db.fetchall()
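With the and → or switch, find_like_doc now returns rows whose text mentions any of the division names instead of requiring every name at once. A minimal sketch of the same query construction, assuming the txt_doc(text, matadate) table from the diff; joining the OR clauses in one pass keeps the SQL shape obvious, and a parameterized query would avoid the injection risk of string concatenation:

def build_like_query(names: list) -> str:
    # One LIKE clause per division name, OR-joined: a row matches if its
    # text mentions any of the names.
    likes = " or ".join("text like '%" + n + "%'" for n in names)
    return "select text, matadate FROM txt_doc WHERE " + likes

# build_like_query(["攸县", "长沙县"]) produces:
# select text, matadate FROM txt_doc WHERE text like '%攸县%' or text like '%长沙县%'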
 from src.pgdb.knowledge.similarity import VectorStore_FAISS
 from src.config.prompts import PROMPT_QUERY_EXTEND, PROMPT_QA_EXTEND_QUESTION
-from .rerank import BgeRerank
+from src.server.rerank import BgeRerank, reciprocal_rank_fusion
 from langchain_core.prompts import PromptTemplate
@@ -54,20 +54,31 @@ class GetSimilarityWithExt:
     def get_rerank_with_doc(self, reranker: BgeRerank, split_doc: list, top_k=5):
         question = '\n'.join(self.question)
         print(question)
-        split_doc.extend(self.similarity_docs)
-        content_set = set()
-        unique_documents = []
-        for doc in split_doc:
-            content = hash(doc.page_content)
-            if content not in content_set:
-                unique_documents.append(doc)
-                content_set.add(content)
-        rerank_docs = reranker.compress_documents(unique_documents, question)
+        rerank_docs1 = reranker.compress_documents(split_doc, question)
+        rerank_docs2 = reranker.compress_documents(self.similarity_docs, question)
+        rerank_docs1_hash = []
+        rerank_docs2_hash = []
+        m = {}
+        for doc in rerank_docs1:
+            m[hash(doc.page_content)] = doc
+            rerank_docs1_hash.append(hash(doc.page_content))
+        for doc in rerank_docs2:
+            m[hash(doc.page_content)] = doc
+            rerank_docs2_hash.append(hash(doc.page_content))
+        result = []
+        result.append((60, rerank_docs1_hash))
+        result.append((55, rerank_docs2_hash))
+        print(len(rerank_docs1_hash))
+        print(len(rerank_docs2_hash))
+        rrf_doc = reciprocal_rank_fusion(result)
+        print(rrf_doc)
         d_list = []
-        for d in rerank_docs[:top_k]:
-            d_list.append(d)
-        self.rerank_docs = rerank_docs[:top_k]
-        return self.faiss_db.join_document(d_list)
+        for key in rrf_doc:
+            d_list.append(m[key])
+        self.rerank_docs = d_list[:top_k]
+        return self.faiss_db.join_document(d_list[:top_k])
     def get_similarity_doc(self):
         return self.similarity_doc_txt
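This is the 混排 (hybrid ranking) half of the commit: instead of deduplicating one merged pool and reranking it once, the keyword-retrieved docs and the vector-similarity docs are reranked separately and the two orderings are fused with reciprocal rank fusion, using hash(doc.page_content) as the key for mapping ids back to Document objects. A compact sketch of the same flow, assuming reciprocal_rank_fusion accepts (k, ranked_ids) pairs as the call above suggests:

def fuse_reranked(reranker, keyword_docs, vector_docs, question, top_k=5):
    by_hash = {}   # content hash -> Document, to map fused ids back
    rankings = []
    for k, docs in ((60, reranker.compress_documents(keyword_docs, question)),
                    (55, reranker.compress_documents(vector_docs, question))):
        ids = []
        for doc in docs:
            h = hash(doc.page_content)
            by_hash[h] = doc
            ids.append(h)
        rankings.append((k, ids))
    fused = reciprocal_rank_fusion(rankings)  # ids ordered by fused score
    return [by_hash[h] for h in fused][:top_k]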
@@ -115,11 +115,11 @@ def reciprocal_rank_fusion(results: list[set]):
     ]
     # for TEST (print reranked documents and scores)
-    print("Reranked documents: ", len(reranked_results))
-    for doc in reranked_results:
-        print('---')
-        print('Docs: ', ' '.join(doc[0].page_content[:100].split()))
-        print('RRF score: ', doc[1])
+    # print("Reranked documents: ", len(reranked_results))
+    # for doc in reranked_results:
+    #     print('---')
+    #     print('Docs: ', ' '.join(doc[0].page_content[:100].split()))
+    #     print('RRF score: ', doc[1])
     # return only documents
     return [x[0] for x in reranked_results[:MAX_DOCS_FOR_CONTEXT]]
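Only the tail of reciprocal_rank_fusion is visible in the diff, and its annotated signature (results: list[set]) no longer matches the (k, ranked_ids) tuples now passed in from get_rerank_with_doc. A generic sketch of the scoring rule, not the repo's exact implementation: every list contributes 1/(k + rank) for each item it ranks, so a smaller k gives that list's top positions more weight:

def rrf_sketch(weighted_rankings, limit=10):
    # weighted_rankings: [(k, [id, ...]), ...], each id list ordered best-first.
    scores = {}
    for k, ranking in weighted_rankings:
        for rank, item in enumerate(ranking, start=1):
            scores[item] = scores.get(item, 0.0) + 1.0 / (k + rank)
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    return [item for item, _ in ranked[:limit]]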