Commit ad504103 by 陈正乐

加入知识入库功能,以及知识库配置

parent 558e898c
......@@ -68,3 +68,6 @@ deps/punkt.zip
deps/nltk.py
deps/averaged_perceptron_tagger.tar.gz
deps/punkt.tar.gz
# faiss数据库存储文件
faiss/
\ No newline at end of file
......@@ -86,7 +86,7 @@ def is_possible_title(
return True
def zh_title_enhance(docs: List[Document]) -> List[Document] | None:
def zh_title_enhance(docs: List[Document]):
title = None
if len(docs) > 0:
for doc in docs:
......
VEC_DB_HOST = 'localhost'
VEC_DB_DBNAME='lae'
VEC_DB_USER='postgres'
VEC_DB_PASSWORD='chenzl'
VEC_DB_PORT='5432'
EMBEEDING_MODEL_PATH = 'C:\\Users\\15663\\AI\\models\\bge-large-zh-v1.5'
LLM_SERVER_URL = '192.168.10.102:8002'
SIMILARITY_SHOW_NUMBER = 5
SIMILARITY_THRESHOLD = 0.8
FAISS_STORE_PATH = '../faiss'
KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\work\\llm_gjjs\\兴火燎原知识库\\兴火燎原知识库\\law\\pdf'
INDEX_NAME = 'know'
\ No newline at end of file
# -*- coding: utf-8 -*-
import sys
sys.path.append("../..")
import pandas as pd
import gradio as gr
import argparse
from llm.chatglm import ChatGLMSerLLM
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
import re
import json
from llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"],template=prompt1)
def chat(input_text):
global llmchain
if not input_text:
return ""
result = llmchain.run({"context":input_text})
return result
async def async_chat_stc(input_text):
global base_llm1, base_llm2, step
qianfanchain_stc = LLMChain(llm=base_llm2, prompt=PROMPT1,llm_kwargs={"temperature":0.9})
callback = AsyncIteratorCallbackHandler()
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
import traceback
traceback.print_exc()
print(f"Caught exception: {e}")
finally:
event.set()
task = asyncio.create_task(wrap_done(qianfanchain_stc.arun({"context":input_text},callbacks=[callback]),callback.done))
print("*"*20)
text=""
async for token in callback.aiter():
text=text+token
yield f"{text}"
await task
\ No newline at end of file
import psycopg2
from psycopg2 import OperationalError, InterfaceError
class UPostgresDB:
'''
psycopg2.connect(
dsn #指定连接参数。可以使用参数形式或 DSN 形式指定。
host #指定连接数据库的主机名。
dbname #指定数据库名。
user #指定连接数据库使用的用户名。
password #指定连接数据库使用的密码。
port #指定连接数据库的端口号。
connection_factory #指定创建连接对象的工厂类。
cursor_factory #指定创建游标对象的工厂类。
async_ #指定是否异步连接(默认False)。
sslmode #指定 SSL 模式。
sslrootcert #指定证书文件名。
sslkey #指定私钥文件名。
sslcert #指定公钥文件名。
)
'''
def __init__(self, host, database, user, password,port = 5432):
self.host = host
self.database = database
self.user = user
self.password = password
self.port = port
self.conn = None
self.cur = None
def connect(self):
try:
self.conn = psycopg2.connect(
host=self.host,
database=self.database,
user=self.user,
password=self.password,
port = self.port
)
self.cur = self.conn.cursor()
except Exception as e:
print(f"连接数据库出现错误: {e}")
def execute(self, query):
try:
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query)
self.conn.commit()
except InterfaceError as e:
print(f"数据库连接已经关闭: {e}")
except OperationalError as e:
print(f"数据库连接出现问题: {e}")
self.connect()
self.retry_execute(query)
except Exception as e:
print(f"执行sql语句出现错误: {e}")
self.conn.rollback()
def retry_execute(self, query):
try:
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"重新执行sql语句再次出现错误: {type(e).__name__}: {e}")
self.conn.rollback()
def execute_args(self, query, args):
try:
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query, args)
self.conn.commit()
except InterfaceError as e:
print(f"数据库连接已经关闭: {e}")
except OperationalError as e:
print(f"数据库操作出现问题: {e}")
self.connect()
self.retry_execute_args(query, args)
except Exception as e:
print(f"执行sql语句出现错误: {e}")
self.conn.rollback()
def retry_execute_args(self, query, args):
try:
self.cur.execute(query, args)
self.conn.commit()
except Exception as e:
print(f"重新执行sql语句再次出现错误: {type(e).__name__}: {e}")
self.conn.rollback()
def search(self, query, params=None):
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query, params)
def fetchall(self):
return self.cur.fetchall()
def fetchone(self):
return self.cur.fetchone()
def close(self):
self.cur.close()
self.conn.close()
def format(self, query):
try:
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
from .c_db import UPostgresDB
import json
TABLE_CHAT = """
create table chat (
id varchar(1000) primary key,
user_id int,
info text,
chat_type_id int,
create_time date,
history json,
is_delete int,
status int
);
"""
class Chat:
def __init__(self, db: UPostgresDB) -> None:
self.db = db
def insert(self, value):
value[5] = json.dumps(value[5])
query = f"INSERT INTO chat(id,user_id,info,chat_type_id,create_time,history,is_delete,status) VALUES (%s,%s,%s,%s,%s,%s,%s,%s)"
self.db.execute_args(query, ((value[0],value[1],value[2],value[3],value[4],value[5],value[6],value[7])))
def delete_update(self,id):
query = f"UPDATE chat SET is_delete = 1 WHERE id = %s"
self.db.execute_args(query, (id,))
def search(self, id):
query = f"SELECT chat.id,user_id,info,t.name as type,chat.create_time,history,is_delete,status FROM chat left join chat_type t on t.id=chat.chat_type_id WHERE chat.id = %s and is_delete=0 ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def manual_search(self, id):
query = f"SELECT user_id,status,history FROM chat WHERE id = %s and is_delete=0"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def qa_search(self, id):
query = f"SELECT user_id,info,t.name as type,chat.create_time,history,status FROM chat left join chat_type t on t.id=chat.chat_type_id WHERE chat.id = %s and is_delete=0 ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def detail_search(self, id):
query = f"SELECT chat.id,user_id,info,t.name as type,chat.create_time,history,status FROM chat left join chat_type t on t.id=chat.chat_type_id WHERE chat.id = %s and is_delete=0 ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def delete_search(self, id):
query = f"SELECT user_id FROM chat WHERE id = %s and is_delete=0"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def list_chat_search(self, chat_type, user_id):
query = f"SELECT chat.id,info,t.name as type,chat.create_time,history,chat.status FROM chat left join chat_type t on t.id=chat.chat_type_id WHERE t.name = %s AND is_delete=0 AND user_id = %s ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (chat_type, user_id))
answer = self.db.fetchall()
if len(answer) > 0:
return answer
else:
return None
def search_history(self,id):
query = f"SELECT history FROM chat WHERE id = %s ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
answer[0] = answer[0][0]
return answer[0]
else:
return None
def get_last_q(self):
query = f"SELECT id,create_time,info,chat_type_id FROM chat ORDER BY create_time DESC LIMIT 1 "
self.db.execute(query)
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def create_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'chat')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if not exists:
query = TABLE_CHAT
self.db.execute(query)
def drop_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'chat')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if exists:
query = "DROP TABLE chat"
self.db.format(query)
print("drop table chat ok")
def update(self, chat_id, history):
history = json.dumps(history)
query = f"UPDATE chat SET history = %s WHERE id = %s"
self.db.execute_args(query, (history,chat_id))
def history_update(self, chat_id, history):
history = json.dumps(history)
query = f"UPDATE chat SET history = %s WHERE id = %s"
self.db.execute_args(query, (history,chat_id))
def search_type_id(self, chat_id):
query = f"select chat_type_id from chat where id = %s"
self.db.execute_args(query, (chat_id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0][0]
else:
return None
def get_chat_status(self, chat_id):
query = f"Select status from chat where id = %s"
self.db.execute_args(query, (chat_id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0][0]
else:
return None
def set_chat_status(self, chat_id, status):
query = f"update chat set status = %s where id = %s"
self.db.execute_args(query, (status,chat_id))
\ No newline at end of file
import sys
sys.path.append("../../..")
import time
from llm.chatglm import ChatGLMSerLLM
from loader.load import loads_path,loads
from vector.pgsql.db import PostgresDB
from vector.similarity import VectorStore_FAISS
from scenarios.lae.common.consts import (
VEC_DB_DBNAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
SIMILARITY_SHOW_NUMBER,
KNOWLEDGE_PATH,
INDEX_NAME
)
from loader.callback import BaseCallback
# 当返回值中带有“思考题”字样的时候,默认将其忽略。
class localCallback(BaseCallback):
def filter(self,title:str,content:str) -> bool:
if len(title+content) == 0:
return True
return (len(title+content) / (len(title.splitlines())+len(content.splitlines())) < 20) or "思考题" in title
def test_faiss_from_dir():
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port":VEC_DB_PORT,"host":VEC_DB_HOST,"dbname":VEC_DB_DBNAME,"username":VEC_DB_USER,"password":VEC_DB_PASSWORD},
show_number=3,
reset=True)
docs = loads_path(KNOWLEDGE_PATH,mode="paged",sentence_size=512,callbacks=[localCallback()])
print(len(docs))
last_doc = None
docs1 = []
for doc in docs:
if not last_doc:
last_doc = doc
continue
if "font-size" not in doc.metadata or "page_number" not in doc.metadata:
continue
if doc.metadata["font-size"] == last_doc.metadata["font-size"] and doc.metadata["page_number"] == last_doc.metadata["page_number"] and len(doc.page_content)+len(last_doc.page_content) < 512/4*3:
last_doc.page_content += doc.page_content
else:
docs1.append(last_doc)
last_doc = doc
if last_doc:
docs1.append(last_doc)
docs = docs1
print(len(docs))
print(vecstore_faiss._faiss.index.ntotal)
for i in range(0, len(docs), 300):
vecstore_faiss._add_documents(docs[i:i+300 if i+300<len(docs) else len(docs)],need_split=True)
print(vecstore_faiss._faiss.index.ntotal)
vecstore_faiss._save_local()
def test_faiss_load():
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port":VEC_DB_PORT,"host":VEC_DB_HOST,"dbname":VEC_DB_DBNAME,"username":VEC_DB_USER,"password":VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("请介绍一下你理解的国际结算业务")))
if __name__ == "__main__":
test_faiss_from_dir()
test_faiss_load()
\ No newline at end of file
import sys
sys.path.append("../../..")
from vector.similarity import VectorStore_FAISS
from argparse import Namespace
from lae.common.consts import (
VEC_DB_DBNAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
SIMILARITY_SHOW_NUMBER,
SIMILARITY_THRESHOLD,
FAISS_STORE_PATH,
EMBEEDING_MODEL_PATH,
INDEX_NAME
)
cfg = Namespace()
cfg.embeddings_model = EMBEEDING_MODEL_PATH
cfg.store_path = FAISS_STORE_PATH
cfg.index_name = INDEX_NAME
cfg.port = VEC_DB_PORT
cfg.host = VEC_DB_HOST
cfg.dbname = VEC_DB_DBNAME
cfg.username = VEC_DB_USER
cfg.password = VEC_DB_PASSWORD
cfg.show_number = SIMILARITY_SHOW_NUMBER
cfg.threshold = SIMILARITY_THRESHOLD
vecstore_faiss_pk = VectorStore_FAISS(embedding_model_name=cfg.embeddings_model,
store_path=cfg.store_path,
index_name=cfg.index_name,
info={"port":cfg.port,"host":cfg.host,"dbname":cfg.dbname,"username":cfg.username,"password":cfg.password},
show_number=cfg.show_number,
threshold=cfg.threshold)
\ No newline at end of file
DROP TABLE IF EXISTS chat;
CREATE TABLE chat (
id varchar(100) primary key,
info text,
create_time date,
history json,
is_delete int
);
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment