Commit 08987981 by tinywell

chore: Update database schema to include similar_docs column in turn_qa table

parent 6ab32349
...@@ -59,3 +59,5 @@ aaa/ ...@@ -59,3 +59,5 @@ aaa/
bbb/ bbb/
ccc/ ccc/
.env .env
lae_pg_data
tmp
\ No newline at end of file
...@@ -30,6 +30,7 @@ CREATE TABLE turn_qa ( ...@@ -30,6 +30,7 @@ CREATE TABLE turn_qa (
chat_id varchar(1000), chat_id varchar(1000),
question text, question text,
answer text, answer text,
similar_docs text,
create_time timestamp(6) DEFAULT current_timestamp, create_time timestamp(6) DEFAULT current_timestamp,
turn_number int, turn_number int,
is_last int2 is_last int2
...@@ -38,6 +39,7 @@ COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id'; ...@@ -38,6 +39,7 @@ COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id'; COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题'; COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案'; COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."similar_docs" IS '该轮会话相似文档 hash 索引';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间'; COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数'; COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是'; COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
......
...@@ -5,6 +5,8 @@ from fastapi import FastAPI, Header,Query ...@@ -5,6 +5,8 @@ from fastapi import FastAPI, Header,Query
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime,timedelta from datetime import datetime,timedelta
from src.pgdb.chat.c_db import UPostgresDB from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.txt_doc_table import TxtDoc
import uvicorn import uvicorn
import json import json
from src.pgdb.chat.crud import CRUD from src.pgdb.chat.crud import CRUD
...@@ -13,6 +15,7 @@ from src.server.get_similarity import QAExt ...@@ -13,6 +15,7 @@ from src.server.get_similarity import QAExt
from src.server.qa import QA from src.server.qa import QA
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
import re import re
from src.controller.request import ( from src.controller.request import (
PhoneLoginRequest, PhoneLoginRequest,
...@@ -49,6 +52,9 @@ c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER ...@@ -49,6 +52,9 @@ c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER
port=CHAT_DB_PORT, ) port=CHAT_DB_PORT, )
c_db.connect() c_db.connect()
k_db = PostgresDB(host=VEC_DB_HOST, database=VEC_DB_DBNAME, user=VEC_DB_USER, password=VEC_DB_PASSWORD, port=VEC_DB_PORT)
k_db.connect()
vecstore_faiss = VectorStore_FAISS( vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH, embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH, store_path=FAISS_STORE_PATH,
...@@ -127,7 +133,7 @@ def get_history_by_session_id(session_id:str,token: str = Header(None)): ...@@ -127,7 +133,7 @@ def get_history_by_session_id(session_id:str,token: str = Header(None)):
j["Question"] = h[1] j["Question"] = h[1]
j["Answer"] = h[2] j["Answer"] = h[2]
j["IsLast"] = h[3] j["IsLast"] = h[3]
j["SimilarDocuments"] =[] j["SimilarDocuments"] = get_similarity_doc(h[4])
history_json.append(j) history_json.append(j)
history_str = json.dumps(history_json) history_str = json.dumps(history_json)
return { return {
...@@ -172,19 +178,26 @@ def question(chat_request: ChatRequest, token: str = Header(None)): ...@@ -172,19 +178,26 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1]) prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True) answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True)
docs_json = [] docs_json = []
doc_hash = []
for d in docs: for d in docs:
j ={} j ={}
j["page_content"] = d.page_content j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"] j["from_file"] = d.metadata["filename"]
j["page_number"] = 0 j["page_number"] = 0
if "hash" in d.metadata:
doc_hash.append(d.metadata["hash"])
docs_json.append(j) docs_json.append(j)
if len(doc_hash)>0:
hash_str = ",".join(doc_hash)
else:
hash_str = ""
# answer = "test Answer" # answer = "test Answer"
if session_id =="": if session_id =="":
session_id = crud.create_chat(token, '\t\t', '0') session_id = crud.create_chat(token, '\t\t', '0')
crud.insert_turn_qa(session_id, question, answer, 0, 1) crud.insert_turn_qa(session_id, question, answer, 0, 1, hash_str)
else: else:
last_turn_id = crud.get_last_turn_num(str(session_id)) last_turn_id = crud.get_last_turn_num(str(session_id))
crud.insert_turn_qa(session_id, question, answer, last_turn_id+1, 1) crud.insert_turn_qa(session_id, question, answer, last_turn_id+1, 1, hash_str)
return { return {
'code': 200, 'code': 200,
...@@ -217,16 +230,24 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)): ...@@ -217,16 +230,24 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1]) prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches, history=prompt, with_similarity=True) answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches, history=prompt, with_similarity=True)
docs_json = [] docs_json = []
doc_hash = []
for d in docs: for d in docs:
j = {} j = {}
j["page_content"] = d.page_content j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"] j["from_file"] = d.metadata["filename"]
j["page_number"] = 0 j["page_number"] = 0
docs_json.append(j) docs_json.append(j)
if "hash" in d.metadata:
doc_hash.append(d.metadata["hash"])
if len(doc_hash)>0:
hash_str = ",".join(doc_hash)
else:
hash_str = ""
# answer = "reGenerate Answer" # answer = "reGenerate Answer"
crud.update_turn_last(str(session_id), last_turn_id ) crud.update_turn_last(str(session_id), last_turn_id )
crud.insert_turn_qa(session_id, question, answer, last_turn_id , 1) crud.insert_turn_qa(session_id, question, answer, last_turn_id, 1, hash_str)
return { return {
'code': 200, 'code': 200,
...@@ -238,5 +259,27 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)): ...@@ -238,5 +259,27 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
} }
} }
def get_similarity_doc(similarity_doc_hash: str):
    """Resolve a stored comma-separated hash list into similar-document JSON.

    The value comes from the ``similar_docs`` column of ``turn_qa``. It may be
    ``None`` (rows inserted without a value) or an empty string (turns that
    had no similar documents), in which case an empty list is returned.

    Args:
        similarity_doc_hash: Document hashes joined by ``","``, or
            ``None``/``""`` when nothing was recorded for the turn.

    Returns:
        A list of dicts as produced by ``docs_to_json``; hashes that cannot
        be resolved in the knowledge store are skipped.
    """
    if not similarity_doc_hash:
        return []
    # "".split(",") yields [""], so drop blank entries explicitly instead of
    # relying on the length of the split result.
    hashes = [h.strip() for h in similarity_doc_hash.split(",") if h.strip()]
    if not hashes:
        return []
    txt_doc = TxtDoc(k_db)
    docs = []
    for doc_hash in hashes:
        row = txt_doc.search(doc_hash)
        if not row:
            # Hash no longer present in the knowledge store: skip it rather
            # than fail the whole history request.
            continue
        docs.append(Document(page_content=row[0], metadata=json.loads(row[1])))
    return docs_to_json(docs)
def docs_to_json(docs):
    """Convert document objects into JSON-serialisable dicts.

    Args:
        docs: Iterable of objects exposing ``page_content`` and a
            ``metadata`` mapping that contains a ``"filename"`` key.

    Returns:
        A list with one dict per document holding ``page_content``,
        ``from_file`` (taken from ``metadata["filename"]``) and
        ``page_number`` (always ``0``).

    Raises:
        KeyError: If a document's metadata has no ``"filename"`` entry.
    """
    return [
        {
            "page_content": d.page_content,
            "from_file": d.metadata["filename"],
            "page_number": 0,
        }
        for d in docs
    ]
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8088) uvicorn.run(app, host='0.0.0.0', port=8088)
...@@ -32,6 +32,7 @@ CREATE TABLE turn_qa ( ...@@ -32,6 +32,7 @@ CREATE TABLE turn_qa (
chat_id varchar(1000), chat_id varchar(1000),
question text, question text,
answer text, answer text,
similar_docs text,
create_time timestamp(6) DEFAULT current_timestamp, create_time timestamp(6) DEFAULT current_timestamp,
turn_number int, turn_number int,
is_last int2 is_last int2
...@@ -40,6 +41,7 @@ COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id'; ...@@ -40,6 +41,7 @@ COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id'; COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题'; COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案'; COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."similar_docs" IS '该轮会话相似文档 hash 索引';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间'; COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数'; COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是'; COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
...@@ -86,26 +88,26 @@ class CRUD: ...@@ -86,26 +88,26 @@ class CRUD:
self.db.execute(TABLE_USER) self.db.execute(TABLE_USER)
def get_history(self, _chat_id): def get_history(self, _chat_id):
query = f'SELECT turn_number,question,answer,is_last FROM turn_qa WHERE chat_id=(%s) ORDER BY turn_number ASC' query = f'SELECT turn_number,question,answer,is_last,similar_docs FROM turn_qa WHERE chat_id=(%s) ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,)) self.db.execute_args(query, (_chat_id,))
ans = self.db.fetchall() ans = self.db.fetchall()
return ans return ans
def get_last_history(self, _chat_id): def get_last_history(self, _chat_id):
query = f'SELECT question,answer FROM turn_qa WHERE chat_id=(%s) and is_last=1 ORDER BY turn_number ASC' query = f'SELECT question,answer,similar_docs FROM turn_qa WHERE chat_id=(%s) and is_last=1 ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,)) self.db.execute_args(query, (_chat_id,))
ans = self.db.fetchall() ans = self.db.fetchall()
return ans return ans
def get_last_history_before_turn_id(self, _chat_id,turn_id): def get_last_history_before_turn_id(self, _chat_id,turn_id):
query = f'SELECT question,answer FROM turn_qa WHERE chat_id=(%s) and is_last=1 and turn_number<(%s) ORDER BY turn_number ASC' query = f'SELECT question,answer,similar_docs FROM turn_qa WHERE chat_id=(%s) and is_last=1 and turn_number<(%s) ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,turn_id)) self.db.execute_args(query, (_chat_id,turn_id))
ans = self.db.fetchall() ans = self.db.fetchall()
return ans return ans
def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last): def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last, similar_docs=None):
query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s)' query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last, similar_docs) VALUES (%s,%s,%s,%s,%s,%s)'
self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last)) self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last, similar_docs))
......
...@@ -5,6 +5,7 @@ CREATE TABLE turn_qa ( ...@@ -5,6 +5,7 @@ CREATE TABLE turn_qa (
chat_id varchar(1000), chat_id varchar(1000),
question text, question text,
answer text, answer text,
similar_docs text,
create_time timestamp(6) DEFAULT current_timestamp, create_time timestamp(6) DEFAULT current_timestamp,
turn_number int, turn_number int,
is_last int2 is_last int2
...@@ -13,6 +14,7 @@ COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id'; ...@@ -13,6 +14,7 @@ COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id'; COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题'; COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案'; COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."similar_docs" IS '该轮会话相似文档 hash 索引';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间'; COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数'; COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是'; COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
......
...@@ -64,8 +64,10 @@ class PgSqlDocstore(Docstore, AddableMixin): ...@@ -64,8 +64,10 @@ class PgSqlDocstore(Docstore, AddableMixin):
self.__sub_init__() self.__sub_init__()
anwser = self.VEC_TXT.search(search) anwser = self.VEC_TXT.search(search)
content = self.TXT_DOC.search(anwser[0]) content = self.TXT_DOC.search(anwser[0])
meta = json.loads(content[1])
meta.update({"hash": anwser[0]}) # paragraph_id = hash 插入到metadata中,便于后续根据段落查找
if content: if content:
return Document(page_content=content[0], metadata=json.loads(content[1])) return Document(page_content=content[0], metadata=meta)
else: else:
return Document() return Document()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment