Commit 4d063424 by 文靖昊

Merge remote-tracking branch 'origin/geo' into geo

parents 824785fb 2c36a162
......@@ -58,4 +58,6 @@ exam/
aaa/
bbb/
ccc/
.env
\ No newline at end of file
.env
lae_pg_data
tmp
\ No newline at end of file
......@@ -30,6 +30,7 @@ CREATE TABLE turn_qa (
chat_id varchar(1000),
question text,
answer text,
similar_docs text,
create_time timestamp(6) DEFAULT current_timestamp,
turn_number int,
is_last int2
......@@ -38,6 +39,7 @@ COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."similar_docs" IS '该轮会话相似文档 hash 索引';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
......
......@@ -5,6 +5,8 @@ from fastapi import FastAPI, Header,Query
from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime,timedelta
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.txt_doc_table import TxtDoc
import uvicorn
import json
from src.pgdb.chat.crud import CRUD
......@@ -13,6 +15,7 @@ from src.server.get_similarity import QAExt
from src.server.qa import QA
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
import re
from src.controller.request import (
PhoneLoginRequest,
......@@ -49,6 +52,9 @@ c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER
port=CHAT_DB_PORT, )
c_db.connect()
k_db = PostgresDB(host=VEC_DB_HOST, database=VEC_DB_DBNAME, user=VEC_DB_USER, password=VEC_DB_PASSWORD, port=VEC_DB_PORT)
k_db.connect()
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
......@@ -127,7 +133,7 @@ def get_history_by_session_id(session_id:str,token: str = Header(None)):
j["Question"] = h[1]
j["Answer"] = h[2]
j["IsLast"] = h[3]
j["SimilarDocuments"] =[]
j["SimilarDocuments"] = get_similarity_doc(h[4])
history_json.append(j)
history_str = json.dumps(history_json)
return {
......@@ -172,19 +178,26 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True)
docs_json = []
doc_hash = []
for d in docs:
j ={}
j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"]
j["page_number"] = 0
if "hash" in d.metadata:
doc_hash.append(d.metadata["hash"])
docs_json.append(j)
if len(doc_hash)>0:
hash_str = ",".join(doc_hash)
else:
hash_str = ""
# answer = "test Answer"
if session_id =="":
session_id = crud.create_chat(token, '\t\t', '0')
crud.insert_turn_qa(session_id, question, answer, 0, 1)
crud.insert_turn_qa(session_id, question, answer, 0, 1, hash_str)
else:
last_turn_id = crud.get_last_turn_num(str(session_id))
crud.insert_turn_qa(session_id, question, answer, last_turn_id+1, 1)
crud.insert_turn_qa(session_id, question, answer, last_turn_id+1, 1, hash_str)
return {
'code': 200,
......@@ -217,16 +230,24 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches, history=prompt, with_similarity=True)
docs_json = []
doc_hash = []
for d in docs:
j = {}
j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"]
j["page_number"] = 0
docs_json.append(j)
if "hash" in d.metadata:
doc_hash.append(d.metadata["hash"])
if len(doc_hash)>0:
hash_str = ",".join(doc_hash)
else:
hash_str = ""
# answer = "reGenerate Answer"
crud.update_turn_last(str(session_id), last_turn_id )
crud.insert_turn_qa(session_id, question, answer, last_turn_id , 1)
crud.insert_turn_qa(session_id, question, answer, last_turn_id, 1, hash_str)
return {
'code': 200,
......@@ -238,5 +259,28 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
}
}
def get_similarity_doc(similarity_doc_hash: str):
    """Turn a stored comma-separated hash string into similar-document JSON.

    Args:
        similarity_doc_hash: Comma-joined document hashes as stored in the
            ``similar_docs`` column, or an empty string / ``None`` when the
            turn has no recorded similar documents.

    Returns:
        A list of dicts produced by ``docs_to_json`` for every hash that
        could be resolved; an empty list when there is nothing to resolve.
    """
    if not similarity_doc_hash:
        return []

    # Drop blank fragments (e.g. from "," or a trailing comma) so they are
    # never sent to the document store as lookups.
    hashs = [h for h in similarity_doc_hash.split(",") if h.strip()]
    if not hashs:
        return []

    txt_doc = TxtDoc(k_db)
    docs = []
    for h in hashs:
        doc = txt_doc.search(h)
        if not doc:
            # NOTE(review): presumably search() yields a falsy value when the
            # hash is no longer in the knowledge store -- confirm against
            # TxtDoc.search. Skipping keeps one stale hash from failing the
            # whole history response.
            continue
        docs.append(Document(page_content=doc[0], metadata=json.loads(doc[1])))
    return docs_to_json(docs)
def docs_to_json(docs):
    """Convert document objects into JSON-serialisable dicts.

    Each input item must expose ``page_content`` and a ``metadata`` mapping
    containing a ``"filename"`` key. ``page_number`` is always emitted as 0.

    Returns:
        A list with one dict per document, in input order, with the keys
        ``page_content``, ``from_file`` and ``page_number``.
    """
    return [
        {
            "page_content": doc.page_content,
            "from_file": doc.metadata["filename"],
            "page_number": 0,
        }
        for doc in docs
    ]
# Script entry point: when this module is executed directly, serve `app` with
# uvicorn, listening on all interfaces (0.0.0.0) on port 8088.
if __name__ == "__main__":
    uvicorn.run(app, host='0.0.0.0', port=8088)
......@@ -32,6 +32,7 @@ CREATE TABLE turn_qa (
chat_id varchar(1000),
question text,
answer text,
similar_docs text,
create_time timestamp(6) DEFAULT current_timestamp,
turn_number int,
is_last int2
......@@ -40,6 +41,7 @@ COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."similar_docs" IS '该轮会话相似文档 hash 索引';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
......@@ -86,26 +88,26 @@ class CRUD:
self.db.execute(TABLE_USER)
def get_history(self, _chat_id):
    """Fetch every turn of a chat, ordered by ascending turn number.

    Args:
        _chat_id: Value bound to the ``chat_id`` filter of the query.

    Returns:
        Whatever ``self.db.fetchall()`` yields; each row carries the columns
        ``turn_number, question, answer, is_last, similar_docs`` in that order.
    """
    # Single query definition: the superseded variant without `similar_docs`
    # was a dead store, and the f-prefix had no placeholders to interpolate.
    query = (
        'SELECT turn_number,question,answer,is_last,similar_docs '
        'FROM turn_qa WHERE chat_id=(%s) ORDER BY turn_number ASC'
    )
    self.db.execute_args(query, (_chat_id,))
    return self.db.fetchall()
def get_last_history(self, _chat_id):
    """Fetch the turns of a chat that are flagged ``is_last=1``.

    Args:
        _chat_id: Value bound to the ``chat_id`` filter of the query.

    Returns:
        Whatever ``self.db.fetchall()`` yields; each row carries the columns
        ``question, answer, similar_docs`` in that order, sorted by ascending
        ``turn_number``.
    """
    # Single query definition: the superseded variant without `similar_docs`
    # was a dead store, and the f-prefix had no placeholders to interpolate.
    query = (
        'SELECT question,answer,similar_docs FROM turn_qa '
        'WHERE chat_id=(%s) and is_last=1 ORDER BY turn_number ASC'
    )
    self.db.execute_args(query, (_chat_id,))
    return self.db.fetchall()
def get_last_history_before_turn_id(self, _chat_id, turn_id):
    """Fetch the ``is_last=1`` turns of a chat that precede a given turn.

    Args:
        _chat_id: Value bound to the ``chat_id`` filter of the query.
        turn_id: Exclusive upper bound compared against ``turn_number``.

    Returns:
        Whatever ``self.db.fetchall()`` yields; each row carries the columns
        ``question, answer, similar_docs`` in that order, sorted by ascending
        ``turn_number``.
    """
    # Single query definition: the superseded variant without `similar_docs`
    # was a dead store, and the f-prefix had no placeholders to interpolate.
    query = (
        'SELECT question,answer,similar_docs FROM turn_qa '
        'WHERE chat_id=(%s) and is_last=1 and turn_number<(%s) '
        'ORDER BY turn_number ASC'
    )
    self.db.execute_args(query, (_chat_id, turn_id))
    return self.db.fetchall()
def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last, similar_docs=None):
    """Insert one question/answer turn into ``turn_qa``.

    Args:
        chat_id: Chat the turn belongs to.
        question: Question text of the turn.
        answer: Answer text of the turn.
        turn_number: Turn number stored with the row.
        is_last: Value stored in the ``is_last`` column.
        similar_docs: Optional value for the ``similar_docs`` column; callers
            in this change pass a comma-joined hash string. Defaults to
            ``None`` so five-argument callers keep working.
    """
    # One definition only: the earlier five-column variant was immediately
    # shadowed by this one and could never run.
    query = (
        'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last, similar_docs) '
        'VALUES (%s,%s,%s,%s,%s,%s)'
    )
    self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last, similar_docs))
......
......@@ -5,6 +5,7 @@ CREATE TABLE turn_qa (
chat_id varchar(1000),
question text,
answer text,
similar_docs text,
create_time timestamp(6) DEFAULT current_timestamp,
turn_number int,
is_last int2
......@@ -13,6 +14,7 @@ COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."similar_docs" IS '该轮会话相似文档 hash 索引';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
......
......@@ -64,8 +64,10 @@ class PgSqlDocstore(Docstore, AddableMixin):
self.__sub_init__()
anwser = self.VEC_TXT.search(search)
content = self.TXT_DOC.search(anwser[0])
meta = json.loads(content[1])
meta.update({"hash": anwser[0]}) # paragraph_id = hash 插入到metadata中,便于后续根据段落查找
if content:
return Document(page_content=content[0], metadata=json.loads(content[1]))
return Document(page_content=content[0], metadata=meta)
else:
return Document()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment