Commit 6a19aee9 by tinywell

Initial pruning pass: keep only the rate agent related code

parent 5c16c92c
......@@ -7,7 +7,7 @@ services:
environment:
POSTGRES_PASSWORD: 111111
ports:
- "5433:5432"
- "5434:5432"
volumes:
- ./lae_pg_data:/var/lib/postgresql/data
- ./init_db/create_db.sql:/docker-entrypoint-initdb.d/create_db.sql
......
......@@ -9,7 +9,6 @@ psycopg2==2.9.7
pydantic==1.10.12
requests==2.31.0
sentence-transformers==2.2.2
# torch==2.0.1
transformers==4.31.0
uvicorn==0.23.1
unstructured==0.8.1
......
# =============================
# Knowledge-store database settings
# =============================
# VEC_DB_HOST = '192.168.10.189'
VEC_DB_HOST = 'localhost'
VEC_DB_DBNAME = 'lae'
VEC_DB_USER = 'postgres'
VEC_DB_PASSWORD = '111111'
# VEC_DB_PORT = '5433'
VEC_DB_PORT = '5434'
# =============================
# Chat database settings
# =============================
# CHAT_DB_HOST = '192.168.10.189'
CHAT_DB_HOST = "localhost"
CHAT_DB_DBNAME = 'lae'
CHAT_DB_USER = 'postgres'
CHAT_DB_PASSWORD = '111111'
# CHAT_DB_PORT = '5433'
CHAT_DB_PORT = '5434'
# =============================
# Embedding model path settings
# =============================
# EMBEEDING_MODEL_PATH = 'D:\\work\\py\\LAE\\bge-large-zh-v1.5'
EMBEEDING_MODEL_PATH="BAAI/bge-large-zh-v1.5"
# =============================
# Reranker model path settings
# =============================
# RERANK_MODEL_PATH = 'D:\\work\\py\\LAE\\bge-reranker-large'
RERANK_MODEL_PATH = 'BAAI/bge-reranker-large'
# =============================
# Model service URL settings
......@@ -41,7 +44,8 @@ SIMILARITY_THRESHOLD = 0.8
# =============================
# FAISS vector store file path settings
# =============================
# FAISS_STORE_PATH = 'D:\\work\\py\\LAE\\faiss'
FAISS_STORE_PATH = './tmp/vecstore'
INDEX_NAME = 'know'
# =============================
......
import sys
sys.path.append('../')
from fastapi import FastAPI, Header
from src.pgdb.chat.c_db import UPostgresDB
from fastapi.middleware.cors import CORSMiddleware
from src.controller.response import Response
from src.controller.return_data import ReturnData
from src.pgdb.chat.crud import CRUD
import uvicorn
from src.server.qa import QA
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from langchain_core.prompts import PromptTemplate
from src.controller.request import (
RegisterRequest,
LoginRequest,
ChatQaRequest,
ChatDetailRequest,
ChatDeleteRequest,
ChatReQA
)
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD,
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
INDEX_NAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER,
prompt1
)
from src.server.agent_rate import new_rate_agent
from src.controller.request import GeoAgentRateRequest
from langchain_openai import ChatOpenAI
app = FastAPI()
# Database connection
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
c_db.connect()
# Add the CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有域访问,也可以指定特定域名
......@@ -50,162 +17,29 @@ app.add_middleware(
allow_headers=["*"], # 允许所有HTTP头
)
base_llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx',
openai_api_base='http://192.168.10.14:8000/v1',
model_name='Qwen2-7B',
verbose=True
)
@app.post("/lae/user/register", response_model=Response)
async def register(request: RegisterRequest):
print(request)
account = request.account
password = request.password
crud = CRUD(_db=c_db)
if crud.user_exist_account(_account=account):
return ReturnData(40000, "当前用户已注册", {}).get_data()
crud.insert_c_user(account=account, password=password)
data = {
"account": account,
}
return ReturnData(200, "用户注册成功", dict(data)).get_data()
@app.get("/lae/user/login", response_model=Response)
async def login(request: LoginRequest):
print(request)
account = request.account
password = request.password
crud = CRUD(_db=c_db)
user_id = crud.user_exist_account_password(_account=account, _password=password)
if not user_id:
return ReturnData(40000, "用户未注册或密码错误", {}).get_data()
token = user_id + '******'
data = {
"account": account,
"token": token
}
return ReturnData(200, "用户登录成功", dict(data)).get_data()
@app.post("/lae/chat/create", response_model=Response)
async def create(token: str = Header(None)):
print(token)
user_id = token.replace('*', '')
crud = CRUD(_db=c_db)
if not crud.user_exist_id(_user_id=user_id):
return ReturnData(40000, "当前用户暂未注册", {}).get_data()
chat_info = '这是该chat的info'
crud.insert_chat(user_id=user_id, info='这是该chat的info', deleted=0)
data = {
"user_id": user_id,
"chat_info": chat_info,
}
return ReturnData(200, '会话创建成功', dict(data)).get_data()
@app.post("/lae/chat/delete", response_model=Response)
async def delete(request: ChatDeleteRequest, token: str = Header(None)):
print(request)
print(token)
user_id = token.replace('*', '')
chat_id = request.chat_id
crud = CRUD(_db=c_db)
if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
chat_info = crud.get_chatinfo_from_chatid(chat_id)
crud.delete_chat(chat_id)
data = {
"user_id": user_id,
"chat_info": chat_info,
}
return ReturnData(200, '会话删除成功', dict(data)).get_data()
@app.post("/lae/chat/qa", response_model=Response)
async def qa(request: ChatQaRequest, token: str = Header(None)):
print(request)
print(token)
user_id = token.replace('*', '')
chat_id = request.chat_id
question = request.question
crud = CRUD(_db=c_db)
if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
{"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id=chat_id,
_faiss_db=vecstore_faiss)
answer = my_chat.chat(question)
my_chat.update_history()
data = {
"answer": answer
}
return ReturnData(200, '模型问答成功', dict(data)).get_data()
@app.get("/lae/chat/detail", response_model=Response)
async def detail(request: ChatDetailRequest, token: str = Header(None)):
print(request)
print(token)
user_id = token.replace('*', '')
chat_id = request.chat_id
crud = CRUD(_db=c_db)
if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
history = crud.get_history(chat_id)
data = {
"chat_id": chat_id,
"history": history
}
return ReturnData(200, "会话详情获取成功", dict(data)).get_data()
@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
agent = new_rate_agent(base_llm, verbose=True)
try:
res = agent.exec(prompt_args={"input": chat_request.query})
except Exception as e:
print(f"处理请求失败, 错误信息: {str(e)},请重新提问")
return {
'code': 500,
'data': str(e)
}
return {
'code': 200,
'data': res
}
@app.post("/lae/chat/clist", response_model=Response)
async def clist(token: str = Header(None)):
print(token)
user_id = token.replace('*', '')
crud = CRUD(_db=c_db)
chat_list = crud.get_chat_list_userid(user_id)
data = {
"chat_list": chat_list
}
return ReturnData(200, "会话列表获取成功", dict(data)).get_data()
@app.post("/lae/chat/reqa", response_model=Response)
async def reqa(request: ChatReQA, token: str = Header(None)):
print(request)
print(token)
chat_id = request.chat_id
user_id = token.replace('*', '')
crud = CRUD(_db=c_db)
if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
question = crud.get_last_question(_chat_id=chat_id)
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
{"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id=chat_id,
_faiss_db=vecstore_faiss)
answer = my_chat.chat(question)
my_chat.update_history()
data = {
"answer": answer
}
return ReturnData(200, '模型重新问答成功', dict(data)).get_data()
if __name__ == "__main__":
# uvicorn.run(app, host='localhost', port=8889)
uvicorn.run(app, host='0.0.0.0', port=8088)
import sys
sys.path.append('../')
from fastapi import FastAPI, Header,Query
from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime,timedelta
from src.pgdb.chat.c_db import UPostgresDB
import uvicorn
import json
from src.pgdb.chat.crud import CRUD
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.txt_doc_table import TxtDoc
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from src.server.rag_query import RagQuery
from src.server.agent_rate import new_rate_agent
from src.controller.request import (
PhoneLoginRequest,
ChatRequest,
ReGenerateRequest,
GeoAgentRateRequest
)
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD,
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
INDEX_NAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER,
)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],  # allow all origins; a specific domain list can be used instead
allow_credentials=True,
allow_methods=["*"],  # allow all HTTP methods
allow_headers=["*"],  # allow all HTTP headers
)
# c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
# port=CHAT_DB_PORT, )
# c_db.connect()
# k_db = PostgresDB(host=VEC_DB_HOST, database=VEC_DB_DBNAME, user=VEC_DB_USER, password=VEC_DB_PASSWORD, port=VEC_DB_PORT)
# k_db.connect()
# vecstore_faiss = VectorStore_FAISS(
# embedding_model_name=EMBEEDING_MODEL_PATH,
# store_path=FAISS_STORE_PATH,
# index_name=INDEX_NAME,
# info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
# "password": VEC_DB_PASSWORD},
# show_number=SIMILARITY_SHOW_NUMBER,
# reset=False)
base_llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx',
openai_api_base='http://192.168.10.14:8000/v1',
model_name='Qwen2-7B',
verbose=True
)
# rag_query = RagQuery(base_llm=base_llm,_faiss_db=vecstore_faiss,_db=TxtDoc(k_db))
@app.post('/api/login')
def login(phone_request: PhoneLoginRequest):
phone = phone_request.phone
crud = CRUD(_db=c_db)
user = crud.user_exist_account(phone)
if not user:
crud.insert_c_user(phone,"123456")
user = crud.user_exist_account(phone)
userid = user[0]
expire = (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
return {
'code': 200,
'data': {
'accessToken': userid,
'refreshToken': userid,
'accessExpire': expire,
'refreshExpire': expire,
'accessUUID': 'accessUUID',
'refreshUUID': 'refreshUUID',
}
}
@app.get('/api/sessions/chat/')
def get_sessions(token: str = Header(None)):
if not token:
return {
'code': 404,
'data': '验证失败'
}
crud = CRUD(_db=c_db)
chat_list = crud.get_chat_list_userid(token)
chat_list_str = []
for chat in chat_list:
chat_list_str.append(str(chat[0]))
return {
'code': 200,
'data': chat_list_str
}
@app.get('/api/session/{session_id}')
def get_history_by_session_id(session_id:str,token: str = Header(None)):
if not token:
return {
'code': 404,
'data': '验证失败'
}
crud = CRUD(_db=c_db)
history = crud.get_history(session_id)
history_json = []
for h in history:
j ={}
j["TurnID"] = h[0]
j["Question"] = h[1]
j["Answer"] = h[2]
j["IsLast"] = h[3]
j["SimilarDocuments"] = get_similarity_doc(h[4])
history_json.append(j)
history_str = json.dumps(history_json)
return {
'code': 200,
'data': history_str
}
@app.delete('/api/session/{session_id}')
def delete_session_by_session_id(session_id:str,token: str = Header(None)):
if not token:
return {
'code': 404,
'data': '验证失败'
}
crud = CRUD(_db=c_db)
crud.delete_chat(session_id)
return {
'code': 200,
'data': 'success'
}
@app.post('/api/general/chat')
def question(chat_request: ChatRequest, token: str = Header(None)):
if not token:
return {
'code': 404,
'data': '验证失败'
}
session_id = chat_request.sessionID
question = chat_request.question
crud = CRUD(_db=c_db)
history = []
if session_id !="":
history = crud.get_last_history(str(session_id))
prompt = ""
for h in history:
prompt += "Q: {}\nA:{}\n".format(h[0], h[1])
res = rag_query.query(question=question,history=prompt)
answer = res["answer"]
docs = res["docs"]
docs_json = json.loads(docs, strict=False)
print(len(docs_json))
doc_hash = []
for d in docs_json:
if "hash" in d:
doc_hash.append(d["hash"])
if len(doc_hash)>0:
hash_str = ",".join(doc_hash)
else:
hash_str = ""
# answer = "test Answer"
if session_id =="":
session_id = crud.create_chat(token, '\t\t', '0')
crud.insert_turn_qa(session_id, question, answer, 0, 1, hash_str)
else:
last_turn_id = crud.get_last_turn_num(str(session_id))
crud.insert_turn_qa(session_id, question, answer, last_turn_id+1, 1, hash_str)
return {
'code': 200,
'data': {
'question': question,
'answer': answer,
'sessionID': session_id,
'similarity': docs_json
}
}
@app.post('/api/general/regenerate')
def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
if not token:
return {
'code': 404,
'data': '验证失败'
}
session_id = chat_request.sessionID
question = chat_request.question
crud = CRUD(_db=c_db)
last_turn_id = crud.get_last_turn_num(str(session_id))
history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id)
prompt = ""
for h in history:
prompt += "Q: {}\nA:{}\n".format(h[0], h[1])
res = rag_query.query(question=question, history=prompt)
answer = res["answer"]
docs = res["docs"]
docs_json = json.loads(docs, strict=False)
doc_hash = []
for d in docs_json:
if "hash" in d:
doc_hash.append(d["hash"])
if len(doc_hash)>0:
hash_str = ",".join(doc_hash)
else:
hash_str = ""
# answer = "reGenerate Answer"
crud.update_turn_last(str(session_id), last_turn_id )
crud.insert_turn_qa(session_id, question, answer, last_turn_id, 1, hash_str)
return {
'code': 200,
'data': {
'question': question,
'answer': answer,
'sessionID': session_id,
'similarity': docs_json
}
}
@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
agent = new_rate_agent(base_llm,verbose=True)
try:
res = agent.exec(prompt_args={"input": chat_request.query})
except Exception as e:
print(f"处理请求失败, 错误信息: {str(e)},请重新提问")
return {
'code': 500,
'data': str(e)
}
return {
'code': 200,
'data': res
}
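# A minimal client sketch for the rate endpoint above, assuming the service runs on the
# port used in __main__ (8088) and that GeoAgentRateRequest exposes a single `query`
# field, as the handler's use of chat_request.query suggests.
def call_rate_agent(query: str, token: str = "test-token") -> dict:
    import requests
    resp = requests.post(
        "http://localhost:8088/api/agent/rate",
        json={"query": query},
        headers={"token": token},  # the handler reads the token from this header
        timeout=300,
    )
    resp.raise_for_status()
    return resp.json()  # {'code': 200, 'data': ...} on success, {'code': 500, ...} on failure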
def get_similarity_doc(similarity_doc_hash: str):
if not similarity_doc_hash:
return []
hashs = similarity_doc_hash.split(",")
docs = []
txt_doc = TxtDoc(k_db)
for h in hashs:
doc = txt_doc.search(h)
if doc is None:
continue
d = Document(page_content=doc[0],metadata=json.loads(doc[1]))
docs.append(d)
return docs_to_json(docs)
def docs_to_json(docs):
docs_json = []
for d in docs:
j = {}
j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"]
j["page_number"] = 0
docs_json.append(j)
return docs_json
if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8088)
# -*- coding: utf-8 -*-
import sys
sys.path.append('../')
import gradio as gr
from langchain_core.prompts import PromptTemplate
from src.llm.chatglm import ChatGLMSerLLM
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
import os
import asyncio
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD,
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
INDEX_NAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER,
GR_PORT,
GR_SERVER_NAME,
ICON_PATH
)
from src.server.qa import QA
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
def main():
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
# base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088')
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db,
_faiss_db=vecstore_faiss)
def clear():  # clear the input box
return ''
def show_history():  # show the chat history
return my_chat.get_history()
def stop_btn():
return gr.Button(interactive=False)
def restart_btn():
return gr.Button(interactive=True)
def get_users():
o_users = my_chat.get_users()
users_l = [item[0] for item in o_users]
return gr.components.Radio(choices=users_l, label="选择一个用户", value=users_l[0], interactive=True), users_l[
0]
def create_chat(user_account):
my_chat.create_chat(user_account)
def get_chats(user_account):
o_chats = my_chat.get_chats(user_account)
if len(o_chats) >= 13:
o_chats = o_chats[:13]
chats_l = [item[0] + ':' + item[1] for item in o_chats]
if my_chat.chat_id:
result = [item for item in chats_l if item.split(":")[0].strip() == my_chat.chat_id][0]
return gr.components.Radio(choices=chats_l, label="历史对话", value=result, interactive=True,
show_label=True)
else:
return gr.components.Radio(choices=chats_l, label="历史对话", value=chats_l[0], interactive=True,
show_label=True)
def set_info(question):
my_chat.set_info(question)
def set_chat_id(chat_id_info):
chat_id = chat_id_info.split(':')[0]
my_chat.set_chat_id(chat_id)
def load():
l_users, t_user = get_users()
l_chats = get_chats(t_user)
return l_users, l_chats
def clear_text(t):
if t == "请输入您的问题":
return ""
else:
return t
def reset_text():
return ""
def blur(t):
if t == "":
return "请输入您的问题"
else:
return t
def text_his(text, history):
history = history + [[text, None]]
return history
def clear_tip(his):
if not his[0][0] and his[0][1] == "你好,我是新晨科技股份的低空经济人工智能助手小晨,如果您有低空经济相关问题,欢迎随时向我咨询。":
return his[1:]
else:
return his
with gr.Blocks(css='index.css', title="低空经济知识问答") as demo:
gr.HTML("""<h1 align="center">低空经济知识问答</h1>""", visible=False)
with gr.Row():
with gr.Column(scale=2, visible=False):
users = gr.components.Radio(choices=[], label="选择一个用户", interactive=True,
visible=False, show_label=False)
chats = gr.components.Radio(choices=[], label="历史对话", interactive=True,
show_label=True, visible=False)
new_chat_btn = gr.Button("新建对话", visible=False)
with gr.Column(scale=8):
chatbot = gr.Chatbot(bubble_full_width=False,
avatar_images=(ICON_PATH + '\\user2.png', ICON_PATH + "\\bot2.png"),
value=[[None,
"你好,我是新晨科技股份的低空经济人工智能助手小晨,如果您有低空经济相关问题,欢迎随时向我咨询。"]],
height=400, show_copy_button=True,
show_label=False, line_breaks=True)
with gr.Row():
input_text = gr.Textbox(show_label=False, lines=1, label="文本输入", scale=9, container=False,
placeholder="请输入您的问题", max_lines=1)
sub_btn = gr.Button("提交", scale=1)
sub_btn.click(
stop_btn, [], sub_btn
).success(
clear_tip, [chatbot], [chatbot]
).success(
text_his, [input_text, chatbot], [chatbot]
).success(
reset_text, [], input_text
).success(
my_chat.async_chat2, [chatbot], [chatbot]
).success(
restart_btn, [], sub_btn
)
# input_text.submit(
# stop_btn, [], sub_btn
# ).then(
# my_chat.async_chat2, [input_text, chatbot], [chatbot]
# ).then(
# restart_btn, [], sub_btn
# )
demo.load(load, [], [users, chats])
# input_text.submit(my_chat.async_chat, [input_text], [chatbot]
# ).then(
# stop_btn, None, sub_btn
# ).then(
# set_info, [input_text], []
# ).then(
# get_chats, [users], [chats]
# ).then(
# my_chat.update_history, None, None
# ).then(
# show_history, None, chatbot
# ).then(
# clear, None, [input_text]
# ).then(
# restart_btn, None, sub_btn
# ).then(
# reset_text, [], input_text
# )
# new_chat_btn.click(create_chat, [users], []).then(
# get_chats, [users], [chats]
# )
# users.change(get_chats, [users], [chats]).then(
# set_chat_id, [chats], None
# ).then(
# show_history, None, chatbot
# )
# chats.change(set_chat_id, [chats], None).then(
# show_history, None, chatbot
# )
# sub_btn.click(my_chat.async_chat, [input_text], [chatbot]
# ).then(
# stop_btn, None, sub_btn
# ).then(
# set_info, [input_text], []
# ).then(
# get_chats, [users], [chats]
# ).then(
# my_chat.update_history, None, None
# ).then(
# show_history, None, chatbot
# ).then(
# clear, None, [input_text]
# ).then(
# restart_btn, None, sub_btn
# ).then(
# reset_text, [], input_text
# )
demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=GR_PORT)
if __name__ == "__main__":
main()
#component-5 {
height: 76vh !important;
overflow: auto !important;
}
.wrap.svelte-vm32wk.svelte-vm32wk.svelte-vm32wk {
display: inline !important;
}
.wrap .svelte-vm32wk label {
margin-bottom: 10px !important;
}
#component-9 {
height: 88vh !important;
border: #f6faff;
box-shadow: none;
background: #f6faff;
}
footer {
visibility: hidden;
}
.app.svelte-1kyws56.svelte-1kyws56 {
height: 100vh;
}
@media screen and (max-width: 768px) {
input[type="text"], textarea {
-webkit-user-modify: read-write-plaintext-only;
-webkit-text-size-adjust: none;
}
#component-3 {
display: none;
}
#component-5 {
display: none !important;
}
#component-9 {
height: 89vh !important;
}
.app.svelte-1kyws56.svelte-1kyws56 {
height: 100vh;
}
#component-10 {
position: fixed;
bottom: 15px;
right: 0px;
padding: 0 20px;
}
#component-12 {
min-width: min(74px, 100%);
}
}
#component-3 button {
background: #4999ff !important;
color: white !important;
}
#component-3 button:hover {
background: #4999ff !important;
color: white !important;
}
span.svelte-1gfkn6j {
margin: 55px 0;
}
gradio-app {
background-color: #f6faff !important;
}
.bot.svelte-1pjfiar.svelte-1pjfiar.svelte-1pjfiar {
background: white;
border: none;
box-shadow: 0px 0px 9px 0px rgba(0, 0, 0, 0.1);
border-radius: 4px;
}
.user.svelte-1pjfiar.svelte-1pjfiar.svelte-1pjfiar {
background: white;
border: none;
box-shadow: 0px 0px 9px 0px rgba(0, 0, 0, 0.1);
border-radius: 4px;
}
.message-buttons-bubble.svelte-1pjfiar.svelte-1pjfiar.svelte-1pjfiar {
/* background: white; */
border: none;
box-shadow: 0px 0px 9px 0px rgba(0, 0, 0, 0.1);
}
label.svelte-vm32wk > .svelte-vm32wk + .svelte-vm32wk {
color: #989898;
}
p {
color: #26415f;
}
.wrapper.svelte-nab2ao {
background: #f6faff;
}
#component-12 {
background: #4999ff;
color: white;
}
.svelte-11hlfrc svg {
color: gray;
}
#component-6 {
position: fixed;
width: 219px;
height: 36px;
border-radius: 18px;
top: 4%;
z-index: 4;
left: 8%;
}
div.svelte-sfqy0y {
background-color: white;
}
"""各种大模型提供的服务"""
import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain_core.language_models import BaseLLM, LLM
from langchain_community.cache import InMemoryCache
from langchain.callbacks.manager import Callbacks
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
import aiohttp
import asyncio
# enable the LLM cache
# langchain.llm_cache = InMemoryCache()
class ChatGLMLocLLM(LLM):
model_name: str = "THUDM/chatglm-6b"
ptuning_checkpoint: str = None
quantization_bit: Optional[int] = None
pre_seq_len: Optional[int] = None
prefix_projection: bool = False
tokenizer: AutoTokenizer = None
model: AutoModel = None
def _llm_type(self) -> str:
return "chatglm_local"
# @root_validator()
@staticmethod
def validate_environment(values: Dict) -> Dict:
if not values["model_name"]:
raise ValueError("No model name provided.")
model_name = values["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
if values["pre_seq_len"]:
config.pre_seq_len = values["pre_seq_len"]
if values["prefix_projection"]:
config.prefix_projection = values["prefix_projection"]
if values["ptuning_checkpoint"]:
ptuning_checkpoint = values["ptuning_checkpoint"]
print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True).half().cuda()
if values["pre_seq_len"]:
# P-tuning v2
model = model.half().cuda()
model.transformer.prefix_encoder.float().cuda()
if values["quantization_bit"]:
print(f"Quantized to {values['quantization_bit']} bit")
model = model.quantize(values["quantization_bit"])
model = model.eval()
values["tokenizer"] = tokenizer
values["model"] = model
return values
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
resp, his = self.model.chat(self.tokenizer, prompt)
# print(f"prompt:{prompt}\nresponse:{resp}\n")
return resp
class ChatGLMSerLLM(LLM):
# model service URL
url: str = "http://127.0.0.1:8000"
chat_history: List = []
out_stream: bool = False
cache: bool = False
@property
def _llm_type(self) -> str:
return "chatglm3-6b"
def get_num_tokens(self, text: str) -> int:
resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
if resp.status_code == 200:
resp_json = resp.json()
predictions = resp_json['response']
# display(self.convert_data(resp_json['history']))
return predictions
else:
return len(text)
@staticmethod
def convert_data(data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
"""构造请求体
"""
# self.chat_history.append({"role": "user", "content": prompt})
query = {
"prompt": prompt,
"history": self.chat_history,
"max_length": 4096,
"top_p": 0.7,
"temperature": temperature
}
return query
@classmethod
def _post(cls, url: str,
query: Dict) -> Any:
"""POST请求
"""
_headers = {"Content_Type": "application/json"}
with requests.session() as sess:
resp = sess.post(url,
json=query,
headers=_headers,
timeout=300)
return resp
@staticmethod
async def _post_stream(url: str,
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, stream=False) -> Any:
"""POST请求
"""
_headers = {"Content_Type": "application/json"}
async with aiohttp.ClientSession() as sess:
async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
if response.status == 200:
if stream and not run_manager:
print('not callable')
if run_manager:
for _callable in run_manager.get_sync().handlers:
await _callable.on_llm_start(None, None)
async for chunk in response.content.iter_any():
# handle each streamed chunk
if chunk and run_manager:
for _callable in run_manager.get_sync().handlers:
# print(chunk.decode("utf-8"),end="")
await _callable.on_llm_new_token(chunk.decode("utf-8"))
if run_manager:
for _callable in run_manager.get_sync().handlers:
await _callable.on_llm_end(None)
else:
raise ValueError(f'glm 请求异常,http code:{response.status}')
def _call(self, prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream=False,
**kwargs: Any) -> str:
query = self._construct_query(prompt=prompt,
temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
# display("==============================")
# display(query)
# post
if stream or self.out_stream:
async def _post_stream():
await self._post_stream(url=self.url + "/stream",
query=query, run_manager=run_manager, stream=stream or self.out_stream)
asyncio.run(_post_stream())
return ''
else:
resp = self._post(url=self.url,
query=query)
if resp.status_code == 200:
resp_json = resp.json()
# self.chat_history.append({'q': prompt, 'a': resp_json['response']})
predictions = resp_json['response']
# display(self.convert_data(resp_json['history']))
return predictions
else:
raise ValueError(f'glm 请求异常,http code:{resp.status_code}')
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
query = self._construct_query(prompt=prompt,
temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
await self._post_stream(url=self.url + "/stream",
query=query, run_manager=run_manager, stream=self.out_stream)
return ''
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters.
"""
_param_dict = {
"url": self.url
}
return _param_dict
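# A minimal usage sketch for ChatGLMSerLLM, assuming a ChatGLM HTTP service is already
# listening at the given url (the wrapper only posts prompts to it).
if __name__ == "__main__":
    llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
    # invoke() ends up in _call(), which posts the prompt and returns the 'response' field
    print(llm.invoke("你好", temperature=0.8))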
import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain_core.language_models import BaseLLM, LLM
from langchain_openai import OpenAI
from langchain_community.cache import InMemoryCache
from langchain.callbacks.manager import Callbacks
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
class ChatGLMSerLLM(OpenAI):
def get_token_ids(self, text: str) -> List[int]:
if self.model_name.__contains__("chatglm"):
## issue an HTTP request to fetch the token ids
url = f"{self.openai_api_base}/num_tokens"
query = {"prompt": text, "model": self.model_name}
_headers = {"Content_Type": "application/json", "Authorization": "chatglm " + self.openai_api_key}
resp = self._post(url=url, query=query, headers=_headers)
if resp.status_code == 200:
resp_json = resp.json()
print(resp_json)
predictions = resp_json['choices'][0]['text']
## convert the predictions string to an int
return [int(predictions)]
return [len(text)]
@classmethod
def _post(cls, url: str,
query: Dict, headers: Dict) -> Any:
"""POST请求
"""
_headers = {"Content_Type": "application/json"}
_headers.update(headers)
with requests.session() as sess:
resp = sess.post(url,
json=query,
headers=_headers,
timeout=300)
return resp
import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.language_models import BaseLLM, LLM
from langchain_core.outputs import LLMResult
from langchain_core.utils import get_from_dict_or_env
from langchain.callbacks.manager import Callbacks
from langchain_core.callbacks import CallbackManagerForLLMRun
from enum import Enum
from pydantic import root_validator, Field
from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_message
logger = logging.getLogger(__name__)
class ModelType(Enum):
ERNIE = "ernie"
ERNIE_LITE = "ernie-lite"
SHEETS1 = "sheets1"
SHEETS2 = "sheets2"
SHEET_COMB = "sheet-comb"
LLAMA2_7B = "llama2-7b"
LLAMA2_13B = "llama2-13b"
LLAMA2_70B = "llama2-70b"
QFCN_LLAMA2_7B = "qfcn-llama2-7b"
BLOOMZ_7B = "bloomz-7b"
MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"
MODEL_SERVICE_Suffix = {
ModelType.ERNIE: "ai_custom/v1/wenxinworkshop/chat/completions",
ModelType.ERNIE_LITE: "ai_custom/v1/wenxinworkshop/chat/eb-instant",
ModelType.SHEETS1: "ai_custom/v1/wenxinworkshop/chat/besheet",
ModelType.SHEETS2: "ai_custom/v1/wenxinworkshop/chat/besheets2",
ModelType.SHEET_COMB: "ai_custom/v1/wenxinworkshop/chat/sheet_comb1",
ModelType.LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/llama_2_7b",
ModelType.LLAMA2_13B: "ai_custom/v1/wenxinworkshop/chat/llama_2_13b",
ModelType.LLAMA2_70B: "ai_custom/v1/wenxinworkshop/chat/llama_2_70b",
ModelType.QFCN_LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/qianfan_chinese_llama_2_7b",
ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
}
class ErnieLLM(LLM):
"""
ErnieLLM is a LLM that uses Ernie to generate text.
"""
model_name: Optional[ModelType] = None
access_token: Optional[str] = ""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
model_name = ModelType(get_from_dict_or_env(values, "model_name", "model_name", ModelType.ERNIE.value))
access_token = get_from_dict_or_env(values, "access_token", "ERNIE_ACCESS_TOKEN", "")
if not access_token:
raise ValueError("No access token provided.")
values["model_name"] = model_name
values["access_token"] = access_token
return values
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
request = CompletionRequest(messages=[Message("user", prompt)])
bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
try:
response = bot.get_response().result
# print("response: ",response)
return response
except Exception as e:
# handle the exception
print("exception:", e)
return e.__str__()
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "ernie"
# def _identifying_params(self) -> Mapping[str, Any]:
# return {
# "name": "ernie",
# }
def _get_model_service_url(model_name) -> str:
# print("_get_model_service_url model_name: ",model_name)
return MODEL_SERVICE_BASE_URL + MODEL_SERVICE_Suffix[model_name]
class ErnieChat(LLM):
model_name: ModelType
access_token: str
prefix_messages: List = Field(default_factory=list)
id: str = ""
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
msg = user_message(prompt)
request = CompletionRequest(messages=self.prefix_messages + [msg])
bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token, request)
try:
response = bot.get_response().result
if self.id == "":
self.id = bot.get_response().id
self.prefix_messages.append(msg)
self.prefix_messages.append(bot_message(response))
return response
except Exception as e:
# handle the exception
raise e
def _get_id(self) -> str:
return self.id
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "ernie"
from dataclasses import asdict, dataclass
from typing import List
from pydantic import BaseModel, Field
from enum import Enum
class MessageRole(str, Enum):
USER = "user"
BOT = "assistant"
@dataclass
class Message:
role: str
content: str
@dataclass
class CompletionRequest:
messages: List[Message]
stream: bool = False
user: str = ""
@dataclass
class Usage:
prompt_tokens: int
completion_tokens: int
total_tokens: int
@dataclass
class CompletionResponse:
id: str
object: str
created: int
result: str
need_clear_history: bool
ban_round: int = 0
sentence_id: int = 0
is_end: bool = False
usage: Usage = None
is_safe: bool = False
is_truncated: bool = False
class ErrorResponse(BaseModel):
error_code: int = Field(...)
error_msg: str = Field(...)
id: str = Field(...)
class ErnieBot:
url: str
access_token: str
request: CompletionRequest
def __init__(self, url: str, access_token: str, request: CompletionRequest):
self.url = url
self.access_token = access_token
self.request = request
def get_response(self) -> CompletionResponse:
import requests
import json
headers = {'Content-Type': 'application/json'}
params = {'access_token': self.access_token}
request_dict = asdict(self.request)
response = requests.post(self.url, params=params, data=json.dumps(request_dict), headers=headers)
# print(response.json())
try:
return CompletionResponse(**response.json())
except Exception as e:
print(e)
raise Exception(response.json())
def user_message(prompt: str) -> Message:
return Message(MessageRole.USER, prompt)
def bot_message(prompt: str) -> Message:
return Message(MessageRole.BOT, prompt)
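# A minimal sketch of the SDK types above; the access token is a placeholder and the
# URL is the ERNIE completions endpoint assembled from the mapping in ernie.py.
if __name__ == "__main__":
    req = CompletionRequest(messages=[user_message("介绍一下低空经济")])
    bot = ErnieBot(
        "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions",
        "YOUR_ACCESS_TOKEN",  # placeholder; obtain a real token via Baidu OAuth
        req,
    )
    print(bot.get_response().result)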
import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
from langchain_core.language_models import LLM
from langchain_community.cache import InMemoryCache
from langchain.callbacks.manager import Callbacks
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
import qianfan
from qianfan import ChatCompletion
# enable the LLM cache
# langchain.llm_cache = InMemoryCache()
class ChatERNIESerLLM(LLM):
# model service URL
chat_completion: ChatCompletion = None
# url: str = "http://127.0.0.1:8000"
chat_history: List = []
out_stream: bool = False
cache: bool = False
model_name: str = "ERNIE-Bot"
# def __init__(self):
# self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")
@property
def _llm_type(self) -> str:
return self.model_name
def get_num_tokens(self, text: str) -> int:
return len(text)
@staticmethod
def convert_data(data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
def _call(self, prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream=False,
**kwargs: Any) -> str:
resp = self.chat_completion.do(model=self.model_name, messages=[{
"role": "user",
"content": prompt
}])
print(resp)
assert resp.code == 200
return resp.body["result"]
async def _post_stream(self,
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream=False) -> Any:
"""POST请求
"""
async for r in await self.chat_completion.ado(model=self.model_name, messages=[query], stream=stream):
assert r.code == 200
if run_manager:
for _callable in run_manager.get_sync().handlers:
await _callable.on_llm_new_token(r.body["result"])
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
await self._post_stream(query={
"role": "user",
"content": prompt
}, stream=True, run_manager=run_manager)
return ''
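# A minimal usage sketch for ChatERNIESerLLM; the qianfan ak/sk values are placeholders.
if __name__ == "__main__":
    llm = ChatERNIESerLLM(chat_completion=ChatCompletion(ak="YOUR_AK", sk="YOUR_SK"))
    print(llm.invoke("你好"))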
import os
import transformers
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq
from peft import PeftModel
class ModelLoader:
def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if pre_seq_len is not None and pre_seq_len > 0:
self.config.pre_seq_len = pre_seq_len
self.config.prefix_projection = prefix_projection
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
self.model = AutoModel.from_pretrained(model_name_or_path, config=self.config, trust_remote_code=True).half()
# self.model = self.model.cuda()
self.base_model = self.model
def quantize(self, quantization_bit):
if quantization_bit is not None:
print(f"Quantized to {quantization_bit} bit")
self.model = self.model.quantize(quantization_bit)
return self.model
def models(self):
return self.model, self.tokenizer
def collator(self):
return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)
def load_lora(self, ckpt_path, name="default"):
# saves GPU memory during training
_peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
self.model = _peft_loaded.merge_and_unload()
print(f"Load LoRA model successfully!")
def load_loras(self, ckpt_paths, name="default"):
global peft_loaded
if len(ckpt_paths) == 0:
return
first = True
for name, path in ckpt_paths.items():
print(f"Load {name} from {path}")
if first:
peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=name)
first = False
else:
peft_loaded.load_adapter(path, adapter_name=name)
peft_loaded.set_adapter(name)
self.model = peft_loaded
def load_prefix(self, ckpt_path):
prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float()
print(f"Load prefix model successfully!")
"""资料分割"""
\ No newline at end of file
from abc import ABC, abstractmethod
class BaseCallback(ABC):
@abstractmethod
def filter(self, title: str, content: str) -> bool:  # return True to discard the current paragraph
pass
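# A hypothetical BaseCallback implementation: drop paragraphs whose body is too short
# to be useful as retrieval context (the class name and threshold are illustrative).
class ShortParagraphFilter(BaseCallback):
    def __init__(self, min_length: int = 10):
        self.min_length = min_length

    def filter(self, title: str, content: str) -> bool:
        # return True to discard the paragraph, per the contract documented above
        return len(content.strip()) < self.min_length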
import re
from typing import List
from src.loader.config import SENTENCE_SIZE
from langchain_text_splitters import CharacterTextSplitter
class ChineseTextSplitter(CharacterTextSplitter):
def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
super().__init__(**kwargs)
self.pdf = pdf
self.sentence_size = sentence_size
def split_text1(self, text: str) -> List[str]:
if self.pdf:
text = re.sub(r"\n{3,}", "\n", text)
text = re.sub(r'\s', ' ', text)
text = text.replace("\n\n", "")
sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :;
sent_list = []
for ele in sent_sep_pattern.split(text):
if sent_sep_pattern.match(ele) and sent_list:
sent_list[-1] += ele
elif ele:
sent_list.append(ele)
return sent_list
def split_text(self, text: str) -> List[str]:  # the logic here still needs refinement
if self.pdf:
text = re.sub(r"\n{3,}", r"\n", text)
text = re.sub(r'\s', " ", text)
text = re.sub("\n\n", "", text)
text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)  # single-character sentence terminators
text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
# A closing quote only ends a sentence when a terminator precedes it, so the \n goes after
# the quote; note the rules above deliberately keep the quote characters attached.
text = text.rstrip()  # strip any trailing \n at the end of the paragraph
# Many rule sets also break on semicolons; they are ignored here, as are dashes and English
# double quotes. Adjust the patterns above if those cases matter.
ls = [i for i in text.split("\n") if i]
for ele in ls:
if len(ele) > self.sentence_size:
ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
ele1_ls = ele1.split("\n")
for ele_ele1 in ele1_ls:
if len(ele_ele1) > self.sentence_size:
ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
ele2_ls = ele_ele2.split("\n")
for ele_ele2 in ele2_ls:
if len(ele_ele2) > self.sentence_size:
ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
ele2_id = ele2_ls.index(ele_ele2)
ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
ele2_id + 1:]
ele_id = ele1_ls.index(ele_ele1)
ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
_id = ls.index(ele)
ls = ls[:_id] + [i for i in ele1_ls if i] + ls[_id + 1:]
return ls
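# A minimal usage sketch for ChineseTextSplitter: split a Chinese passage into
# sentence-level chunks of at most sentence_size characters.
if __name__ == "__main__":
    splitter = ChineseTextSplitter(pdf=False, sentence_size=50)
    for chunk in splitter.split_text("低空经济是指以低空空域为依托的综合经济形态。它涵盖无人机物流、城市空中交通等场景!前景广阔?"):
        print(chunk)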
# maximum sentence-chunk length for text splitting
SENTENCE_SIZE = 100
ZH_TITLE_ENHANCE = False
import os, copy
from langchain_community.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader, UnstructuredPDFLoader, \
UnstructuredWordDocumentLoader, PDFMinerPDFasHTMLLoader
from .config import SENTENCE_SIZE, ZH_TITLE_ENHANCE
from .chinese_text_splitter import ChineseTextSplitter
from .zh_title_enhance import zh_title_enhance
from langchain_core.documents import Document
from typing import List
from src.loader.callback import BaseCallback
import re
from bs4 import BeautifulSoup
def load(filepath, mode: str = None, sentence_size: int = 0, metadata=None, callbacks=None, **kwargs):
r"""
Load a document.
mode: how the document is split: "single", "elements", or "paged"
sentence_size: documents longer than this are split again into smaller chunks
kwargs: passed through to the underlying loader
"""
if filepath.lower().endswith(".md"):
loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
elif filepath.lower().endswith(".txt"):
loader = TextLoader(filepath, autodetect_encoding=True, **kwargs)
elif filepath.lower().endswith(".csv"):
loader = CSVLoader(filepath, **kwargs)
elif filepath.lower().endswith(".pdf"):
# loader = UnstructuredPDFLoader(filepath, mode=mode or "elements",**kwargs)
# use the custom pdf loader
return __pdf_loader(filepath, sentence_size=sentence_size, metadata=metadata, callbacks=callbacks)
elif filepath.lower().endswith(".docx") or filepath.lower().endswith(".doc"):
loader = UnstructuredWordDocumentLoader(filepath, mode=mode or "elements", **kwargs)
else:
loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
if sentence_size > 0:
try:
return split(loader.load(), sentence_size)
except Exception:
print(filepath, "failed to load")
return []
return loader.load()
def loads_path(path: str, **kwargs):
return loads(get_files_in_directory(path), **kwargs)
def loads(filepaths, **kwargs):
default_kwargs = {"mode": "paged"}
default_kwargs.update(**kwargs)
documents = [load(filepath=file, **default_kwargs) for file in filepaths]
return [item for sublist in documents for item in sublist]
def append(documents=None, sentence_size: int = SENTENCE_SIZE):  # keeps document structure info; mind the hash handling
if documents is None:
documents = []
effect_documents = []
last_doc = documents[0]
for doc in documents[1:]:
last_hash = "" if "next_hash" not in last_doc.metadata else last_doc.metadata["next_hash"]
doc_hash = "" if "next_hash" not in doc.metadata else doc.metadata["next_hash"]
if len(last_doc.page_content) + len(doc.page_content) <= sentence_size and last_hash == doc_hash:
last_doc.page_content = last_doc.page_content + doc.page_content
continue
else:
effect_documents.append(last_doc)
last_doc = doc
effect_documents.append(last_doc)
return effect_documents
def split(documents=None, sentence_size: int = SENTENCE_SIZE):  # keeps document structure info; mind the hash handling
if documents is None:
documents = []
effect_documents = []
for doc in documents:
if len(doc.page_content) > sentence_size:
words_list = re.split(r'·-·', doc.page_content.replace("。", "。·-·").replace("\n", "\n·-·"))  # insert separator marks, then split on them
document = Document(page_content="", metadata=copy.deepcopy(doc.metadata))
first = True
for word in words_list:
if len(document.page_content) + len(word) < sentence_size:
document.page_content += word
else:
if len(document.page_content.replace(" ", "").replace("\n", "")) > 0:
if first:
first = False
else:
effect_documents[-1].metadata["next_doc"] = document.page_content
effect_documents.append(document)
document = Document(page_content=word, metadata=copy.deepcopy(doc.metadata))
if len(document.page_content.replace(" ", "").replace("\n", "")) > 0:
if first:
pass
else:
effect_documents[-1].metadata["next_doc"] = document.page_content
effect_documents.append(document)
else:
effect_documents.append(doc)
return effect_documents
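# A minimal usage sketch for the loaders above, assuming a local ./docs directory;
# pdfs go through the custom __pdf_loader, everything else through unstructured/langchain.
def _loads_path_demo():
    documents = loads_path("./docs", mode="paged", sentence_size=SENTENCE_SIZE)
    for doc in documents[:3]:
        print(doc.metadata.get("filename"), len(doc.page_content))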
def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE, mode: str = None,
**kwargs):
print("load_file", filepath)
if filepath.lower().endswith(".md"):
loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
docs = loader.load()
elif filepath.lower().endswith(".txt"):
loader = TextLoader(filepath, autodetect_encoding=True, **kwargs)
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(textsplitter)
elif filepath.lower().endswith(".csv"):
loader = CSVLoader(filepath, **kwargs)
docs = loader.load()
elif filepath.lower().endswith(".pdf"):
loader = UnstructuredPDFLoader(filepath, mode=mode or "elements", **kwargs)
textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
docs = loader.load_and_split(textsplitter)
elif filepath.lower().endswith(".docx"):
loader = UnstructuredWordDocumentLoader(filepath, mode=mode or "elements", **kwargs)
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(textsplitter)
else:
loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(text_splitter=textsplitter)
if using_zh_title_enhance:
docs = zh_title_enhance(docs)
write_check_file(filepath, docs)
return docs
def write_check_file(filepath, docs):
folder_path = os.path.join(os.path.dirname(filepath), "tmp_files")
if not os.path.exists(folder_path):
os.makedirs(folder_path)
fp = os.path.join(folder_path, 'load_file.txt')
with open(fp, 'a+', encoding='utf-8') as fout:
fout.write("filepath=%s,len=%s" % (filepath, len(docs)))
fout.write('\n')
for i in docs:
fout.write(str(i))
fout.write('\n')
fout.close()
def get_files_in_directory(directory):
file_paths = []
for root, dirs, files in os.walk(directory):
for file in files:
file_path = os.path.join(root, file)
file_paths.append(file_path)
return file_paths
# custom pdf loading
def __checkV(strings: str):
lines = len(strings.splitlines())
if lines > 3 and len(strings.replace(" ", "")) / lines < 15:
return False
return True
def __isTitle(strings: str):
return len(strings.splitlines()) == 1 and len(strings) > 0 and strings.endswith("\n")
def __appendPara(strings: str):
return strings.replace(".\n", "^_^").replace("。\n", "^-^").replace("?\n", "?^-^").replace("?\n", "?^-^").replace(
"\n", "").replace("^_^", ".\n").replace("^-^", "。\n").replace("?^-^", "?\n").replace("?^-^", "?\n")
def __check_fs_ff(line_ff_fs_s, fs, ff):  # if the line has text in the same font and size as the previous line, return those; otherwise default to the font and size of the longest text run
re_fs = line_ff_fs_s[-1][0][-1]
re_ff = line_ff_fs_s[-1][1][-1] if line_ff_fs_s[-1][1] else None
max_len = 0
for ff_fs in line_ff_fs_s:  # find the font and size of the longest text run
c_max = max(list(map(int, ff_fs[0])))
if max_len < ff_fs[2] or (max_len == ff_fs[2] and c_max > int(re_fs)):
max_len = ff_fs[2]
re_fs = c_max
re_ff = ff_fs[1][-1] if ff_fs[1] else None
if fs:
for ff_fs in line_ff_fs_s:
if str(fs) in ff_fs[0] and ff in ff_fs[1]:
re_fs = fs
re_ff = ff
break
return int(re_fs), re_ff
def append_document(snippets1: List[Document], title: str, content: str, callbacks, font_size, page_num, metadate,
need_append: bool = False):
if callbacks:
for cb in callbacks:
if isinstance(cb, BaseCallback):
if cb.filter(title, content):
return
if need_append and len(snippets1) > 0:
ps = snippets1.pop()
snippets1.append(Document(page_content=ps.page_content + title, metadata=ps.metadata))
else:
doc_metadata = {"font-size": font_size, "page_number": page_num}
doc_metadata.update(metadate)
snippets1.append(Document(page_content=title + content, metadata=doc_metadata))
'''
Extract a pdf document, splitting on titles and their content; a chunk's page number is the page its title appears on.
The split text is divided again by sentence_size; child chunks keep the parent chunk's page number.
'''
def __pdf_loader(filepath: str, sentence_size: int = 0, metadata=None, callbacks=None):
if not filepath.lower().endswith(".pdf"):
raise ValueError("file is not pdf document")
loader = PDFMinerPDFasHTMLLoader(filepath)
documents = loader.load()
soup = BeautifulSoup(documents[0].page_content, 'html.parser')
content = soup.find_all('div')
cur_fs = None  # current font-size
last_fs = None  # previous font-size
cur_ff = None  # current font-family
cur_text = ''
fs_increasing = False  # the next line's font grows: treat it as a title and split here
last_text = ''
last_page_num = 1  # previous page number; page_split decides which page the current text belongs to
page_num = 1  # initial page number
page_change = False  # the page switched
page_split = False  # whether the text was split on this page
last_is_title = False  # whether the previous text block was a title
snippets: List[Document] = []
filename = os.path.basename(filepath)
if metadata:
metadata.update({'source': filepath, 'filename': filename, 'filetype': 'application/pdf'})
else:
metadata = {'source': filepath, 'filename': filename, 'filetype': 'application/pdf'}
for c in content:
divs = c.get('style')
if re.match(r"^(Page|page)", c.text): # 检测当前页的页码
match = re.match(r"^(page|Page)\s+(\d+)", c.text)
if match:
if page_split:  # if the text was split, advance the page; otherwise keep the chunk's starting page
last_page_num = page_num
page_num = match.group(2)
if len(last_text) + len(cur_text) == 0:  # page turned while the text is empty: carry the current page number back
last_page_num = page_num
page_change = True
page_split = False
continue
if re.findall('writing-mode:(.*?);', divs) == ['False'] or re.match(r'^[0-9\s\n]+$', c.text) or re.match(
r"^第\s+\d+\s+页$", c.text):  # skip invisible text and pure page-number lines
continue
if len(c.text.replace("\n", "").replace(" ", "")) <= 1:  # drop lines with at most one effective character
continue
sps = c.find_all('span')
if not sps:
continue
line_ff_fs_s = []  # spans with more than one effective character
line_ff_fs_s2 = []  # spans with exactly one effective character
for sp in sps:  # a line may contain spans with different styles
sp_len = len(sp.text.replace("\n", "").replace(" ", ""))
if sp_len > 0:
st = sp.get('style')
if st:
ff_fs = (re.findall('font-size:(\d+)px', st), re.findall('font-family:(.*?);', st),
len(sp.text.replace("\n", "").replace(" ", "")))
if sp_len == 1:  # filter out single-character spans
line_ff_fs_s2.append(ff_fs)
else:
line_ff_fs_s.append(ff_fs)
if len(line_ff_fs_s) == 0:  # if empty, fall back to the single-character spans
if len(line_ff_fs_s2) > 0:
line_ff_fs_s = line_ff_fs_s2
else:
if len(c.text) > 0:
page_change = False
continue
fs, ff = __check_fs_ff(line_ff_fs_s, cur_fs, cur_ff)
if not cur_ff:
cur_ff = ff
if not cur_fs:
cur_fs = fs
if abs(fs - cur_fs) <= 1 and ff == cur_ff:  # neither font family nor size changed
cur_text += c.text
cur_fs = fs
page_change = False
if len(cur_text.splitlines()) > 3:  # after several consecutive lines, fs_increasing no longer applies
fs_increasing = False
else:
if page_change and cur_fs > fs + 1:  # page turned and the font shrank: most likely a page header, skip c.text (may occasionally drop a real line)
page_change = False
continue
if last_is_title:  # the previous block was a title
if __isTitle(cur_text) or fs_increasing:  # consecutive titles, or the font-growing flag is set
last_text = last_text + cur_text
last_is_title = True
fs_increasing = False
else:
append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
page_num if page_split else last_page_num, metadata)
page_split = True
last_text = ''
last_is_title = False
fs_increasing = int(fs) > int(cur_fs)  # the font grew
else:
if len(last_text) > 0 and __checkV(last_text):  # filter out some text
# merge paragraphs that span pages or have only a few lines
append_document(snippets, __appendPara(last_text), "", callbacks, last_fs,
page_num if page_split else last_page_num, metadata,
need_append=len(last_text.splitlines()) <= 2 or page_change)
page_split = True
last_text = cur_text
last_is_title = __isTitle(last_text) or fs_increasing
fs_increasing = int(fs) > int(cur_fs)
if page_split:
last_page_num = page_num
last_fs = cur_fs
cur_fs = fs
cur_ff = ff
cur_text = c.text
page_change = False
append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
page_num if page_split else last_page_num, metadata)
if sentence_size > 0:
return split(snippets, sentence_size)
return snippets
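# A minimal sketch of the custom pdf path, assuming a local report.pdf; the snippets
# carry the page_number and font-size metadata set in append_document above.
if __name__ == "__main__":
    snippets = load("report.pdf", sentence_size=SENTENCE_SIZE, metadata={"tag": "demo"})
    for s in snippets[:3]:
        print(s.metadata.get("page_number"), s.page_content[:40])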
from typing import List
from langchain_core.documents import Document
import re
def under_non_alpha_ratio(text: str, threshold: float = 0.5):
"""Checks if the proportion of non-alpha characters in the text snippet exceeds a given
threshold. This helps prevent text like "-----------BREAK---------" from being tagged
as a title or narrative text. The ratio does not count spaces.
Parameters
----------
text
The input string to test
threshold
If the proportion of alphabetic characters falls below this threshold, the
function returns True
"""
if len(text) == 0:
return False
alpha_count = len([char for char in text if char.strip() and char.isalpha()])
total_count = len([char for char in text if char.strip()])
try:
ratio = alpha_count / total_count
return ratio < threshold
except ZeroDivisionError:
return False
def is_possible_title(
text: str,
title_max_word_length: int = 20,
non_alpha_threshold: float = 0.5,
) -> bool:
"""Checks to see if the text passes all the checks for a valid title.
Parameters
----------
text
The input text to check
title_max_word_length
The maximum number of words a title can contain
non_alpha_threshold
The minimum alpha-character ratio the text needs to be considered a title
"""
# empty text is never a title
if len(text) == 0:
print("Not a title. Text is empty.")
return False
# text that ends with punctuation is not a title
ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z"
ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN)
if ENDS_IN_PUNCT_RE.search(text) is not None:
return False
# the text must not exceed the configured length, 20 by default
# NOTE(robinson) - splitting on spaces here instead of word tokenizing because it
# is less expensive and actual tokenization doesn't add much value for the length check
if len(text) > title_max_word_length:
return False
# text whose non-alpha ratio is too high is not a title
if under_non_alpha_ratio(text, threshold=non_alpha_threshold):
return False
# NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles
if text.endswith((",", ".", ",", "。")):
return False
if text.isnumeric():
print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore
return False
# the leading characters (5 by default) should contain a numeral
if len(text) < 5:
text_5 = text
else:
text_5 = text[:5]
alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5))))
if not alpha_in_text_5:
return False
return True
def zh_title_enhance(docs: List[Document]):
title = None
if len(docs) > 0:
for doc in docs:
if is_possible_title(doc.page_content):
doc.metadata['category'] = 'cn_Title'
title = doc.page_content
elif title:
doc.page_content = f"下文与({title})有关。{doc.page_content}"
return docs
else:
print("文件不存在")
#!/bin/bash
# Set the path to the server.py script
SERVER_PATH=server.py
# Set the default values for the arguments
MODEL_NAME_OR_PATH="../../../model/chatglm2-6b"
CHECKPOINT=None
CHECKPOINT_PATH="../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000"
PRE_SEQ_LEN=128
QUANTIZATION_BIT=8
PORT=8002
# Call the server.py script with the parsed arguments
python $SERVER_PATH \
--model_name_or_path $MODEL_NAME_OR_PATH \
--checkpoint $CHECKPOINT \
--checkpoint_path $CHECKPOINT_PATH \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit $QUANTIZATION_BIT \
--port $PORT
#!/bin/bash
# Set the path to the server.py script
SERVER_PATH=server.py
# Set the default values for the arguments
MODEL_NAME_OR_PATH="../../../model/chatglm2-6b"
CHECKPOINT=lora
CHECKPOINT_PATH="../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000"
QUANTIZATION_BIT=8
PORT=8001
# Call the server.py script with the parsed arguments
python $SERVER_PATH \
--model_name_or_path $MODEL_NAME_OR_PATH \
--checkpoint $CHECKPOINT \
--checkpoint_path $CHECKPOINT_PATH \
--quantization_bit $QUANTIZATION_BIT \
--port $PORT
import argparse
import time
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import json
import datetime
import torch
from typing import AsyncIterable
from pydantic import BaseModel
import uvicorn
import signal
from src.llm.loader import ModelLoader
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(CUDA_DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def build_history(history):
result = []
for item in history if history else []:
result.append((item['q'], item['a']))
return result
def convert_data(data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
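# Example: build_history([{'q': 'hi', 'a': 'hello'}]) == [('hi', 'hello')], and
# convert_data([('hi', 'hello')]) == [{'q': 'hi', 'a': 'hello'}]; the two are inverses.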
class StreamRequest(BaseModel):
"""Request body for streaming."""
message: str
stop_stream = False
def signal_handler(signal, frame):
global stop_stream
stop_stream = True
signal.signal(signal.SIGINT, signal_handler)  # register once, not on every streamed chunk
async def send_message(message: str, history=None, max_length=2048, top_p=0.7, temperature=0.95) -> AsyncIterable[str]:
global model, tokenizer, stop_stream
history = history or []  # avoid a shared mutable default argument
count = 0
old_len = 0
print(message)
output = ''
for response, history in model.stream_chat(tokenizer, message, history=history,
max_length=max_length,
top_p=top_p,
temperature=temperature):
# print(old_len,count)
if stop_stream:
stop_stream = False
break
else:
output = response[old_len:]
print(output, end='', flush=True)
# print(output)
old_len = len(response)
yield f"{output}"
print("")
# yield f"\n"
# print()
app = FastAPI()
@app.post("/stream")
async def stream(request: Request):
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
history = build_history(json_post_list.get('history'))
max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature')
return StreamingResponse(send_message(prompt, history=history, max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95), media_type="text/plain")
@app.post("/")
async def create_item(request: Request):
global model, tokenizer
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
history = build_history(json_post_list.get('history'))
max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature')
response, history = model.chat(tokenizer,
prompt,
history=history,
max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95)
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"history": history,
"status": 200,
"time": time
}
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log)
torch_gc()
return answer
@app.post("/tokens")
async def get_num_tokens(request: Request):
global model, tokenizer
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
tokens = tokenizer.encode(prompt, add_special_tokens=False)
print("=======================================")
print("=======================================")
print(len(tokens), prompt)
print("=======================================")
print("=======================================")
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": len(tokens),
"status": 200,
"time": time
}
return answer
def parse_args():
parser = argparse.ArgumentParser(description='ChatGLM2-6B Server')
parser.add_argument('--model_name_or_path', type=str, default='THUDM/chatglm2-6b', help='model id or local path')
parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint type (None, ptuning, lora)')
parser.add_argument('--checkpoint_path', type=str,
default='../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000',
help='checkpoint path')
parser.add_argument('--pre_seq_len', type=int, default=128, help='prefix length')
parser.add_argument('--quantization_bit', type=int, default=None, help='quantization bits (4 or 8)')
parser.add_argument('--port', type=int, default=8000, help='port')
parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
# parser.add_argument('--max_input_length', type=int, default=512, help='max length of instruction + input')
# parser.add_argument('--max_output_length', type=int, default=1536, help='max output length')
return parser.parse_args()
if __name__ == '__main__':
cfg = parse_args()
## ----------- load model --------------
start = time.time()
if cfg.checkpoint == "lora":
# load model with a LoRA fine-tuned checkpoint
loader = ModelLoader(cfg.model_name_or_path)
loader.load_lora(cfg.checkpoint_path)
elif cfg.checkpoint == "ptuning":
# load model with a P-Tuning v2 fine-tuned checkpoint
loader = ModelLoader(cfg.model_name_or_path, cfg.pre_seq_len, False)
loader.load_prefix(cfg.checkpoint_path)
else:
loader = ModelLoader(cfg.model_name_or_path)
model, tokenizer = loader.models()
if cfg.quantization_bit is not None:
model = loader.quantize(cfg.quantization_bit)
model.cuda().eval()
uvicorn.run(app, host=cfg.host, port=cfg.port, workers=1)
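# Client-side sketch for the /stream endpoint above (separate script; host/port are
# illustrative and the `requests` package is assumed):
import requests

def _stream_demo(url: str = "http://localhost:8002/stream"):
    payload = {"prompt": "你好", "history": [], "max_length": 2048, "top_p": 0.7, "temperature": 0.95}
    with requests.post(url, json=payload, stream=True) as resp:
        # the endpoint streams plain-text chunks; print them as they arrive
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)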
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
from src.llm.loader import ModelLoader
import uvicorn
import json,os,datetime
from typing import List, Optional, Any
from fastapi import FastAPI, HTTPException, Request, status, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
tokens = ["token1"]
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=['*'],
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
class Message(BaseModel):
role: str
content: str
class ChatBody(BaseModel):
messages: List[Message]
model: str
stream: Optional[bool] = False
max_tokens: Optional[int] = 4096
temperature: Optional[float] = 0.9
top_p: Optional[float] = 0.7
class CompletionBody(BaseModel):
prompt: Any
model: str
stream: Optional[bool] = False
max_tokens: Optional[int] = 4096
temperature: Optional[float] = 0.9
top_p: Optional[float] = 0.7
class EmbeddingsBody(BaseModel):
# Python 3.8 does not support str | List[str]
input: Any
model: Optional[str]
@app.get("/")
def read_root():
return {"Hello": "World!"}
@app.get("/v1/models")
def get_models():
global model
ret = {"data": [], "object": "list"}
if model:
ret['data'].append({
"created": 1677610602,
"id": "gpt-3.5-turbo",
"object": "model",
"owned_by": "openai",
"permission": [
{
"created": 1680818747,
"id": "modelperm-fTUZTbzFp7uLLTeMSo9ks6oT",
"object": "model_permission",
"allow_create_engine": False,
"allow_sampling": True,
"allow_logprobs": True,
"allow_search_indices": False,
"allow_view": True,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False
}
],
"root": "gpt-3.5-turbo",
"parent": None,
})
return ret
def generate_response(content: str, chat: bool = True):
global model_name
if chat:
return {
"id": "chatcmpl-77PZm95TtxE0oYLRx3cxa6HtIDI7s",
"object": "chat.completion",
"created": 1682000966,
"model": model_name,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
"choices": [{
"message": {"role": "assistant", "content": content},
"finish_reason": "stop", "index": 0}
]
}
else:
return {
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "text-davinci-003",
"choices": [
{
"text": content,
"index": 0,
"logprobs": None,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
def generate_stream_response_start():
return {
"id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
"object": "chat.completion.chunk", "created": 1682004627,
"model": "gpt-3.5-turbo-0301",
"choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": None}]
}
def generate_stream_response(content: str, chat: bool = True):
global model_name
if chat:
return {
"id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
"object": "chat.completion.chunk",
"created": 1682004627,
"model": model_name,
"choices": [{"delta": {"content": content}, "index": 0, "finish_reason": None}
]}
else:
return {
"id":"cmpl-7GfnvmcsDmmTVbPHmTBcNqlMtaEVj",
"object":"text_completion",
"created":1684208299,
"choices":[
{
"text": content,
"index": 0,
"logprobs": None,
"finish_reason": None,
}
],
"model": "text-davinci-003"
}
def generate_stream_response_stop(chat: bool = True):
if chat:
return {"id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
"object": "chat.completion.chunk", "created": 1682004627,
"model": "gpt-3.5-turbo-0301",
"choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}]
}
else:
return {
"id":"cmpl-7GfnvmcsDmmTVbPHmTBcNqlMtaEVj",
"object":"text_completion",
"created":1684208299,
"choices":[
{"text":"","index":0,"logprobs":None,"finish_reason":"stop"}],
"model":"text-davinci-003",
}
# @app.post("/v1/embeddings")
# async def embeddings(body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks):
# return do_embeddings(body, request, background_tasks)
# def do_embeddings(body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks):
# background_tasks.add_task(torch_gc)
# if request.headers.get("Authorization").split(" ")[1] not in context.tokens:
# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
# if not context.embeddings_model:
# raise HTTPException(status.HTTP_404_NOT_FOUND, "Embeddings model not found!")
# embeddings = context.embeddings_model.encode(body.input)
# data = []
# if isinstance(body.input, str):
# data.append({
# "object": "embedding",
# "index": 0,
# "embedding": embeddings.tolist(),
# })
# else:
# for i, embed in enumerate(embeddings):
# data.append({
# "object": "embedding",
# "index": i,
# "embedding": embed.tolist(),
# })
# content = {
# "object": "list",
# "data": data,
# "model": "text-embedding-ada-002-v2",
# "usage": {
# "prompt_tokens": 0,
# "total_tokens": 0
# }
# }
# return JSONResponse(status_code=200, content=content)
# @app.post("/v1/engines/{engine}/embeddings")
# async def engines_embeddings(engine: str, body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks):
# return do_embeddings(body, request, background_tasks)
def init_model_args(model_args=None):
if model_args is None:
model_args = {}
model_args['temperature'] = model_args['temperature'] if model_args.get('temperature') is not None else 0.95
if model_args['temperature'] <= 0:
model_args['temperature'] = 0.1
if model_args['temperature'] > 1:
model_args['temperature'] = 1
model_args['top_p'] = model_args['top_p'] if model_args.get('top_p') else 0.7
model_args['max_tokens'] = model_args['max_tokens'] if model_args.get('max_tokens') is not None else 512
return model_args
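# Illustration: init_model_args({"temperature": -1}) returns temperature 0.1,
# top_p 0.7 and max_tokens 512 (out-of-range values are clamped, missing ones defaulted).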
@app.post("/v1/num_tokens")
async def get_num_tokens(body: CompletionBody, request: Request):
global model, tokenizer,model_name
if request.headers.get("Authorization").split(" ")[1] not in tokens:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
if not model:
raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!")
prompt = body.prompt
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
# now = datetime.datetime.now()
# time = now.strftime("%Y-%m-%d %H:%M:%S")
print(prompt,len(prompt_tokens) )
return JSONResponse(content=generate_response(str(len(prompt_tokens)), chat=False))
@app.post("/v1/chat/completions")
async def chat_completions(body: ChatBody, request: Request):
global model, tokenizer,model_name
if request.headers.get("Authorization").split(" ")[1] not in tokens:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
if not model:
raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!")
question = body.messages[-1]
if question.role == 'user':
question = question.content
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
history = []
user_question = ''
if model_name == "chatglm3-6b":
for message in body.messages[:-1]:
history.append({"role":message.role, "content":message.content})
# history.extend(body.messages[:-1])
else:
for message in body.messages:
if message.role == 'system':
history.append((message.content, "OK"))
if message.role == 'user':
user_question = message.content
elif message.role == 'assistant':
assistant_answer = message.content
history.append((user_question, assistant_answer))
print(f"question = {question}, history = {history}")
if body.stream:
async def eval_llm():
first = True
model_args = init_model_args({
"temperature": body.temperature,
"top_p": body.top_p,
"max_tokens": body.max_tokens,
})
sends = 0
for response, _ in model.stream_chat(
tokenizer, question, history,
temperature=model_args['temperature'],
top_p=model_args['top_p'],
max_length=max(2048, model_args['max_tokens'])):
ret = response[sends:]
# https://github.com/THUDM/ChatGLM-6B/issues/478
# work around garbled emoji in streamed output
if "\uFFFD" == ret[-1:]:
continue
sends = len(response)
if first:
first = False
yield json.dumps(generate_stream_response_start(),
ensure_ascii=False)
yield json.dumps(generate_stream_response(ret), ensure_ascii=False)
yield json.dumps(generate_stream_response_stop(), ensure_ascii=False)
yield "[DONE]"
return EventSourceResponse(eval_llm(), ping=10000)
else:
model_args = init_model_args({
"temperature": body.temperature,
"top_p": body.top_p,
"max_tokens": body.max_tokens,
})
response, _ = model.chat(
tokenizer, question, history,
temperature=model_args['temperature'],
top_p=model_args['top_p'],
max_length=max(2048, model_args['max_tokens']))
return JSONResponse(content=generate_response(response))
@app.post("/v1/completions")
async def completions(body: CompletionBody, request: Request):
print(body)
if request.headers.get("Authorization").split(" ")[1] not in tokens:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
if not model:
raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!")
if isinstance(body.prompt, list):
question = body.prompt[0]
else:
question = body.prompt
print(f"question = {question}")
if body.stream:
async def eval_llm():
model_args = init_model_args({
"temperature": body.temperature,
"top_p": body.top_p,
"max_tokens": body.max_tokens,
})
sends = 0
for response, _ in model.stream_chat(
tokenizer, question, [],
temperature=model_args['temperature'],
top_p=model_args['top_p'],
max_length=max(2048, model_args['max_tokens'])):
ret = response[sends:]
# https://github.com/THUDM/ChatGLM-6B/issues/478
# work around garbled emoji in streamed output
if "\uFFFD" == ret[-1:]:
continue
sends = len(response)
yield json.dumps(generate_stream_response(ret, chat=False), ensure_ascii=False)
yield json.dumps(generate_stream_response_stop(chat=False), ensure_ascii=False)
yield "[DONE]"
return EventSourceResponse(eval_llm(), ping=10000)
else:
model_args = init_model_args({
"temperature": body.temperature,
"top_p": body.top_p,
"max_tokens": body.max_tokens,
})
response, _ = model.chat(
tokenizer, question, [],
temperature=model_args['temperature'],
top_p=model_args['top_p'],
max_length=max(2048, model_args['max_tokens']))
print(response)
return JSONResponse(content=generate_response(response, chat=False))
def main():
global model, tokenizer,model_name
parser = argparse.ArgumentParser(
description='Start LLM and Embeddings models as a service.')
parser.add_argument('--model_name_or_path', type=str, help='LLM model to load',
default='/model/chatglm3-6b')
parser.add_argument('--device', type=str,
help='Device to run the service, gpu/cpu/mps',
default='gpu')
parser.add_argument('--port', type=int, help='Port number to run the service',
default=8000)
parser.add_argument('--host', type=str, help='host to run the service',
default="0.0.0.0")
parser.add_argument('--checkpoint', type=str, help='model checkpoint to load',
default=None)
parser.add_argument('--checkpoint_path', type=str, help='model checkpoint to load',
default=None)
parser.add_argument('--pre_seq_len', type=int, help='ptuning train pre_seq_len',
default=None)
parser.add_argument('--quantization_bit', type=int, help='quantization_bit 4 or 8, default not set',
default=None)
args = parser.parse_args()
print("> Load config and arguments...")
print(f"Language Model: {args.model_name_or_path}")
print(f"Device: {args.device}")
print(f"Port: {args.port}")
print(f"Host: {args.host}")
print(f"Quantization_bit: {args.quantization_bit}")
print(f"Checkpoint: {args.checkpoint}")
print(f"Checkpoint_path: {args.checkpoint_path}")
model_name = os.path.basename(args.model_name_or_path)
print(model_name)
if args.checkpoint == "lora":
# load model with a LoRA fine-tuned checkpoint
loader = ModelLoader(args.model_name_or_path)
loader.load_lora(args.checkpoint_path)
elif args.checkpoint == "ptuning":
# load model with a P-Tuning v2 fine-tuned checkpoint
loader = ModelLoader(args.model_name_or_path, args.pre_seq_len, False)
loader.load_prefix(args.checkpoint_path)
else:
loader = ModelLoader(args.model_name_or_path)
model,tokenizer = loader.models()
if args.quantization_bit is not None:
model = loader.quantize(args.quantization_bit)
model.cuda().eval()
uvicorn.run(app, host=args.host, port=args.port, workers=1)
if __name__ == '__main__':
main()
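# Client-side sketch for the OpenAI-style endpoint above (separate script; host/port
# are illustrative, and the bearer token must be one of the entries in `tokens`):
import requests

def _chat_demo(base_url: str = "http://localhost:8000"):
    headers = {"Authorization": "Bearer token1"}
    body = {
        "model": "chatglm3-6b",
        "messages": [{"role": "user", "content": "你好"}],
        "stream": False,
    }
    resp = requests.post(f"{base_url}/v1/chat/completions", json=body, headers=headers)
    return resp.json()["choices"][0]["message"]["content"]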
"""会话信息相关表"""
import psycopg2
from psycopg2 import OperationalError, InterfaceError
class UPostgresDB:
"""
psycopg2.connect(
dsn                 # connection parameters (keyword form or a DSN string)
host                # database host name
dbname              # database name
user                # user name for the connection
password            # password for the connection
port                # database port
connection_factory  # factory class for connection objects
cursor_factory      # factory class for cursor objects
async_              # whether to connect asynchronously (default False)
sslmode             # SSL mode
sslrootcert         # root certificate file name
sslkey              # private key file name
sslcert             # client certificate file name
)
"""
def __init__(self, host, database, user, password, port=5432):
self.host = host
self.database = database
self.user = user
self.password = password
self.port = port
self.conn = None
self.cur = None
def connect(self):
try:
self.conn = psycopg2.connect(
host=self.host,
database=self.database,
user=self.user,
password=self.password,
port=self.port
)
self.cur = self.conn.cursor()
except Exception as e:
print(f"连接数据库出现错误: {e}")
def execute(self, query):
try:
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query)
self.conn.commit()
except InterfaceError as e:
print(f"Database connection is already closed: {e}")
except OperationalError as e:
print(f"Database connection problem: {e}")
self.connect()
self.retry_execute(query)
except Exception as e:
print(f"Error executing SQL statement: {e}")
self.conn.rollback()
def retry_execute(self, query):
try:
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"重新执行sql语句再次出现错误: {type(e).__name__}: {e}")
self.conn.rollback()
def execute_args(self, query, args):
try:
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query, args)
self.conn.commit()
except InterfaceError as e:
print(f"Database connection is already closed: {e}")
except OperationalError as e:
print(f"Database operation problem: {e}")
self.connect()
self.retry_execute_args(query, args)
except Exception as e:
print(f"Error executing SQL statement: {e}")
self.conn.rollback()
def retry_execute_args(self, query, args):
try:
self.cur.execute(query, args)
self.conn.commit()
except Exception as e:
print(f"重新执行sql语句再次出现错误: {type(e).__name__}: {e}")
self.conn.rollback()
def search(self, query, params=None):
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query, params)
def fetchall(self):
return self.cur.fetchall()
def fetchone(self):
return self.cur.fetchone()
def close(self):
self.cur.close()
self.conn.close()
def format(self, query):
try:
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
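# Usage sketch (hypothetical connection parameters; requires a reachable PostgreSQL):
def _upostgres_demo():
    db = UPostgresDB(host="localhost", database="lae", user="postgres", password="111111", port="5434")
    db.connect()
    db.search("SELECT 1")
    print(db.fetchone())  # (1,)
    db.close()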
from .c_db import UPostgresDB
TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
user_id varchar(1000) PRIMARY KEY,
account varchar(20) NOT NULL,
password varchar(50) NOT NULL
);
COMMENT ON COLUMN "c_user"."user_id" IS '用户id';
COMMENT ON COLUMN "c_user"."account" IS '用户帐户';
COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
"""
class CUser:
def __init__(self, _user_id, _account, _password) -> None:
self.user_id = _user_id
self.account = _account
self.password = _password
from .c_db import UPostgresDB
TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
chat_id varchar(1000) PRIMARY KEY,
user_id int,
info text,
create_time timestamp(6) DEFAULT current_timestamp,
deleted int2
);
COMMENT ON COLUMN "chat"."chat_id" IS '会话id';
COMMENT ON COLUMN "chat"."user_id" IS '会话创建用户id';
COMMENT ON COLUMN "chat"."info" IS '会话简介';
COMMENT ON COLUMN "chat"."create_time" IS '会话创建时间,默认为当前时间';
COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
"""
class Chat:
def __init__(self, _chat_id, _user_id, _info, _create_time, _deleted) -> None:
self.chat_id = _chat_id
self.user_id = _user_id
self.info = _info
self.create_time = _create_time
self.deleted = _deleted
from .c_db import UPostgresDB
TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
chat_id varchar(1000) PRIMARY KEY,
user_id varchar(1000),
info text,
create_time timestamp(6) DEFAULT current_timestamp,
deleted int2
);
COMMENT ON COLUMN "chat"."chat_id" IS '会话id';
COMMENT ON COLUMN "chat"."user_id" IS '会话创建用户id';
COMMENT ON COLUMN "chat"."info" IS '会话简介';
COMMENT ON COLUMN "chat"."create_time" IS '会话创建时间,默认为当前时间';
COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
DROP SEQUENCE IF EXISTS "chat_seq";
CREATE SEQUENCE "chat_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE chat ALTER COLUMN chat_id SET DEFAULT nextval('chat_seq'::regclass);
"""
TABLE_TURN_QA = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
turn_id varchar(1000) PRIMARY KEY,
chat_id varchar(1000),
question text,
answer text,
similar_docs text,
create_time timestamp(6) DEFAULT current_timestamp,
turn_number int,
is_last int2
);
COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."similar_docs" IS '该轮会话相似文档 hash 索引';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
DROP SEQUENCE IF EXISTS "turn_qa_seq";
CREATE SEQUENCE "turn_qa_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE turn_qa ALTER COLUMN turn_id SET DEFAULT nextval('turn_qa_seq'::regclass);
"""
TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
user_id varchar(1000) PRIMARY KEY,
account varchar(20) NOT NULL,
password varchar(50) NOT NULL
);
COMMENT ON COLUMN "c_user"."user_id" IS '用户id';
COMMENT ON COLUMN "c_user"."account" IS '用户帐户';
COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
DROP SEQUENCE IF EXISTS "c_user_seq";
CREATE SEQUENCE "c_user_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE c_user ALTER COLUMN user_id SET DEFAULT nextval('c_user_seq'::regclass);
"""
class CRUD:
def __init__(self, _db: UPostgresDB):
self.db = _db
def create_table(self):
self.db.execute(TABLE_CHAT)
self.db.execute(TABLE_TURN_QA)
self.db.execute(TABLE_USER)
def get_history(self, _chat_id):
query = f'SELECT turn_number,question,answer,is_last,similar_docs FROM turn_qa WHERE chat_id=(%s) ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,))
ans = self.db.fetchall()
return ans
def get_last_history(self, _chat_id):
query = f'SELECT question,answer,similar_docs FROM turn_qa WHERE chat_id=(%s) and is_last=1 ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,))
ans = self.db.fetchall()
return ans
def get_last_history_before_turn_id(self, _chat_id,turn_id):
query = f'SELECT question,answer,similar_docs FROM turn_qa WHERE chat_id=(%s) and is_last=1 and turn_number<(%s) ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,turn_id))
ans = self.db.fetchall()
return ans
def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last, similar_docs=None):
query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last, similar_docs) VALUES (%s,%s,%s,%s,%s,%s)'
self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last, similar_docs))
def insert_c_user(self, account, password):
query = f'INSERT INTO c_user(account, password) VALUES (%s,%s)'
self.db.execute_args(query, (account, password))
def insert_chat(self, user_id, info, deleted):
query = f'INSERT INTO chat(user_id, info, deleted) VALUES (%s,%s,%s)'
self.db.execute_args(query, (user_id, info, deleted))
def update_last(self, chat_id):
query = f'UPDATE turn_qa SET is_last = 0 WHERE chat_id = (%s) AND is_last = 1'
self.db.execute_args(query, (chat_id,))
def update_turn_last(self, chat_id, turn_number):
query = f'UPDATE turn_qa SET is_last = 0 WHERE chat_id = (%s) AND turn_number = (%s)'
self.db.execute_args(query, (chat_id, turn_number))
def user_exist_id(self, _user_id):
query = f'SELECT * FROM c_user WHERE user_id = (%s)'
self.db.execute_args(query, (_user_id,))
return self.db.fetchone()
def user_exist_account(self, _account):
query = f'SELECT * FROM c_user WHERE account = (%s)'
self.db.execute_args(query, (_account,))
return self.db.fetchone()
def user_exist_account_password(self, _account, _password):
query = f'SELECT user_id FROM c_user WHERE account = (%s) AND password = (%s)'
self.db.execute_args(query, (_account, _password))
return self.db.fetchone()
def chat_exist_chatid_userid(self, _chat_id, _user_id):
query = f'SELECT * FROM chat WHERE chat_id = (%s) AND user_id = (%s)'
self.db.execute_args(query, (_chat_id, _user_id))
return self.db.fetchone()
def get_chat_list_userid(self, _user_id):
query = f'SELECT chat_id FROM chat WHERE user_id = (%s) AND deleted = 0 order by create_time desc'
self.db.execute_args(query, (_user_id,))
return self.db.fetchall()
def get_chatinfo_from_chatid(self, _chat_id):
query = f'SELECT info FROM chat WHERE chat_id = (%s)'
self.db.execute_args(query, (_chat_id,))
return self.db.fetchone()
def delete_chat(self, _chat_id):
query = f'UPDATE chat SET deleted = 1 WHERE chat_id = (%s)'
self.db.execute_args(query, (_chat_id,))
def get_last_question(self, _chat_id):
query = f'SELECT question FROM turn_qa WHERE chat_id = (%s) AND turn_number = 1'
self.db.execute_args(query, (_chat_id,))
return self.db.fetchone()[0]
def get_users(self):
query = f'SELECT account FROM c_user'
self.db.execute(query)
return self.db.fetchall()
def get_chats(self, account):
query = f'SELECT chat.chat_id,chat.info FROM chat JOIN c_user ON chat.user_id = c_user.user_id WHERE c_user.account = (%s) ORDER BY chat.create_time DESC;'
self.db.execute_args(query, (account,))
return self.db.fetchall()
def create_chat(self, user_id, info, deleted):
query = f'INSERT INTO chat(user_id, info, deleted) VALUES (%s,%s,%s) RETURNING chat_id'
self.db.execute_args(query, (user_id, info, deleted))
ans = self.db.fetchall()[0][0]
return ans
def get_uersid_from_account(self, account):
query = f'SELECT user_id FROM c_user WHERE account = (%s)'
self.db.execute_args(query, (account, ))
ans = self.db.fetchall()[0][0]
print(ans)
return ans
def get_chat_info(self, chat_id):
query = f'SELECT info FROM chat WHERE chat_id = (%s)'
self.db.execute_args(query, (chat_id,))
ans = self.db.fetchall()[0][0]
print(ans)
return ans
def set_info(self, chat_id, info):
query = f'UPDATE chat SET info = (%s) WHERE chat_id = (%s)'
self.db.execute_args(query, (info, chat_id))
def get_last_turn_num(self,chat_id):
query = f'SELECT max(turn_number) FROM turn_qa WHERE chat_id = (%s)'
self.db.execute_args(query, (chat_id,))
ans = self.db.fetchone()[0]
return ans
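# Usage sketch (assumes a reachable chat database; values are illustrative):
def _crud_demo(db: UPostgresDB):
    crud = CRUD(db)
    crud.create_table()  # (re)creates the chat / turn_qa / c_user tables
    crud.insert_c_user("alice", "secret")
    user_id = crud.get_uersid_from_account("alice")
    chat_id = crud.create_chat(user_id, "demo chat", 0)
    crud.insert_turn_qa(chat_id, "你好", "Hi!", 1, 1)
    print(crud.get_history(chat_id))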
TABLE_TURN_QA = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
turn_id varchar(1000) PRIMARY KEY,
chat_id varchar(1000),
question text,
answer text,
similar_docs text,
create_time timestamp(6) DEFAULT current_timestamp,
turn_number int,
is_last int2
);
COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."similar_docs" IS '该轮会话相似文档 hash 索引';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
"""
class TurnQa:
def __init__(self, _turn_id, _chat_id, _question, _answer, _create_time, _turn_number, _is_last) -> None:
self.turn_id = _turn_id
self.chat_id = _chat_id
self.question = _question
self.answer = _answer
self.create_time = _create_time
self.turn_number = _turn_number
self.is_last = _is_last
"""资料存储相关"""
import sys
from abc import ABC, abstractmethod
import json
from typing import List, Tuple
from langchain_core.documents import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64
sys.path.append("../")
class DocumentCallback(ABC):
@abstractmethod  # document processing before storing into the vector store
def before_store(self, docstore: PgSqlDocstore, documents):
pass
@abstractmethod  # document processing after a vector-store search, used to rebuild structure
def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
List[Tuple[Document, float]]:
pass
class DefaultDocumentCallback(DocumentCallback):
def before_store(self, docstore: PgSqlDocstore, documents):
output_doc = []
for doc in documents:
if "next_doc" in doc.metadata:
doc.metadata["next_hash"] = str2hash_base64(doc.metadata["next_doc"])
doc.metadata.pop("next_doc")
output_doc.append(doc)
return output_doc
def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
List[Tuple[Document, float]]:  # post-search document processing
output_doc: List[Tuple[Document, float]] = []
exist_hash = []
for doc, score in documents:
print(exist_hash)
dochash = str2hash_base64(doc.page_content)
if dochash in exist_hash:
continue
else:
exist_hash.append(dochash)
output_doc.append((doc, score))
if len(output_doc) > number:
return output_doc
fordoc = doc
while "next_hash" in fordoc.metadata:
if len(fordoc.metadata["next_hash"]) > 0:
if fordoc.metadata["next_hash"] in exist_hash:
break
else:
exist_hash.append(fordoc.metadata["next_hash"])
content = docstore.TXT_DOC.search(fordoc.metadata["next_hash"])
if content:
fordoc = Document(page_content=content[0], metadata=json.loads(content[1]))
output_doc.append((fordoc, score))
if len(output_doc) > number:
return output_doc
else:
break
else:
break
return output_doc
import psycopg2
class PostgresDB:
"""
psycopg2.connect(
dsn                 # connection parameters (keyword form or a DSN string)
host                # database host name
dbname              # database name
user                # user name for the connection
password            # password for the connection
port                # database port
connection_factory  # factory class for connection objects
cursor_factory      # factory class for cursor objects
async_              # whether to connect asynchronously (default False)
sslmode             # SSL mode
sslrootcert         # root certificate file name
sslkey              # private key file name
sslcert             # client certificate file name
)
"""
def __init__(self, host, database, user, password, port=5432):
self.host = host
self.database = database
self.user = user
self.password = password
self.port = port
self.conn = None
self.cur = None
def connect(self):
self.conn = psycopg2.connect(
host=self.host,
database=self.database,
user=self.user,
password=self.password,
port=self.port
)
self.cur = self.conn.cursor()
def execute(self, query):
try:
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
def execute_args(self, query, args):
try:
self.cur.execute(query, args)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
def search(self, query, params=None):
self.cur.execute(query, params)
def fetchall(self):
return self.cur.fetchall()
def close(self):
self.cur.close()
self.conn.close()
def format(self, query):
try:
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
import sys
from os import path
# add the current directory to PYTHONPATH
sys.path.append(path.dirname(path.abspath(__file__)))
from typing import List, Union, Dict, Optional
from langchain_community.docstore.base import AddableMixin, Docstore
from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector
import json, hashlib, base64
from langchain_core.documents import Document
def str2hash_base64(inp: str) -> str:
# return f"%s" % hash(input)
return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()
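# A SHA-1 digest is 20 bytes, so the result is always a 28-character base64 string,
# e.g. str2hash_base64("abc") == 'qZk+NkcGgWq6PiVxeFDCbJzQ2J0='.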
class PgSqlDocstore(Docstore, AddableMixin):
host: str
dbname: str
username: str
password: str
port: str
'''
Note: __getstate__ and __setstate__ are overridden for langchain's pickle-based
serialization; the returned dict contains the pgsql connection info.
'''
def __getstate__(self):
return {"host": self.host, "dbname": self.dbname, "username": self.username, "password": self.password,
"port": self.port}
def __setstate__(self, info):
self.__init__(info)
def __init__(self, info: dict, reset: bool = False):
self.host = info["host"]
self.dbname = info["dbname"]
self.username = info["username"]
self.password = info["password"]
self.port = info["port"] if "port" in info else "5432"
self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password, port=self.port)
self.TXT_DOC = TxtDoc(self.pgdb)
self.VEC_TXT = TxtVector(self.pgdb)
if reset:
self.__sub_init__()
self.TXT_DOC.drop_table()
self.VEC_TXT.drop_table()
self.TXT_DOC.create_table()
self.VEC_TXT.create_table()
def __sub_init__(self):
if not self.pgdb.conn:
self.pgdb.connect()
'''
Look up the paragraph for a vector in the local store and return it wrapped in a Document
'''
def search(self, search: str) -> Union[str, Document]:
if not self.pgdb.conn:
self.__sub_init__()
answer = self.VEC_TXT.search(search)
content = self.TXT_DOC.search(answer[0])
if content:
meta = json.loads(content[1])
meta.update({"hash": answer[0]})  # paragraph_id = hash, kept in metadata for later paragraph lookup
return Document(page_content=content[0], metadata=meta)
else:
return Document()
'''
Delete the texts for the given vectors from the local store (batch delete)
'''
def delete(self, ids: List) -> None:
if not self.pgdb.conn:
self.__sub_init__()
pids = []
for item in ids:
answer = self.VEC_TXT.search(item)
pids.append(answer[0])
self.VEC_TXT.delete(ids)
self.TXT_DOC.delete(pids)
'''
Add vector and text entries to the local store
[vector_id, Document(page_content=question, metadata=dict(paragraph=paragraph_text))]
'''
def add(self, texts: Dict[str, Document]) -> None:
# for vec,doc in texts.items():
# paragraph_id = self.TXT_DOC.insert(doc.metadata["paragraph"])
# self.VEC_TXT.insert(vector_id=vec,paragraph_id=paragraph_id,text=doc.page_content)
if not self.pgdb.conn:
self.__sub_init__()
paragraph_hashs = [] # hash,text
paragraph_txts = []
vec_inserts = []
for vec, doc in texts.items():
txt_hash = str2hash_base64(doc.metadata["paragraph"])
print(txt_hash)
vec_inserts.append((vec, doc.page_content, txt_hash))
if txt_hash not in paragraph_hashs:
paragraph_hashs.append(txt_hash)
paragraph = doc.metadata["paragraph"]
doc.metadata.pop("paragraph")
paragraph_txts.append((txt_hash, paragraph, json.dumps(doc.metadata, ensure_ascii=False)))
# print(paragraph_txts)
self.TXT_DOC.insert(paragraph_txts)
self.VEC_TXT.insert(vec_inserts)
class InMemorySecondaryDocstore(Docstore, AddableMixin):
"""Simple in memory docstore in the form of a dict."""
def __init__(self, _dict: Optional[Dict[str, Document]] = None, _sec_dict: Optional[Dict[str, Document]] = None):
"""Initialize with dict."""
self._dict = _dict if _dict is not None else {}
self._sec_dict = _sec_dict if _sec_dict is not None else {}
def add(self, texts: Dict[str, Document]) -> None:
"""Add texts to in memory dictionary.
Args:
texts: dictionary of id -> document.
Returns:
None
"""
overlapping = set(texts).intersection(self._dict)
if overlapping:
raise ValueError(f"Tried to add ids that already exist: {overlapping}")
self._dict = {**self._dict, **texts}
dict1 = {}
dict_sec = {}
for vec, doc in texts.items():
txt_hash = str2hash_base64(doc.metadata["paragraph"])
metadata = doc.metadata
paragraph = metadata.pop('paragraph')
# metadata.update({"paragraph_id":txt_hash})
metadata['paragraph_id'] = txt_hash
dict_sec[txt_hash] = Document(page_content=paragraph, metadata=metadata)
dict1[vec] = Document(page_content=doc.page_content, metadata={'paragraph_id': txt_hash})
self._dict = {**self._dict, **dict1}
self._sec_dict = {**self._sec_dict, **dict_sec}
def delete(self, ids: List) -> None:
"""Deleting IDs from in memory dictionary."""
overlapping = set(ids).intersection(self._dict)
if not overlapping:
raise ValueError(f"Tried to delete ids that does not exist: {ids}")
for _id in ids:
self._sec_dict.pop(self._dict[_id].metadata['paragraph_id'])
self._dict.pop(_id)
def search(self, search: str) -> Union[str, Document]:
"""Search via direct lookup.
Args:
search: id of a document to search for.
Returns:
Document if found, else error message.
"""
if search not in self._dict:
return f"ID {search} not found."
else:
print(self._dict[search].page_content)
return self._sec_dict[self._dict[search].metadata['paragraph_id']]
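# Usage sketch (sample strings are illustrative):
def _docstore_demo():
    store = InMemorySecondaryDocstore()
    store.add({"vec-1": Document(page_content="什么是低空经济?",
                                 metadata={"paragraph": "低空经济是围绕低空空域开展的综合性经济形态。"})})
    print(store.search("vec-1").page_content)  # the full paragraph, not the question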
import os
import sys
import re
from os import path
import copy
from typing import List, OrderedDict, Any, Optional, Tuple, Dict
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore
from langchain.embeddings.huggingface import (
HuggingFaceEmbeddings,
)
import math
import faiss
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from src.loader import load
from langchain_core.embeddings import Embeddings
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback
import operator
import numpy as np
sys.path.append("../")
def singleton(cls):
instances = {}
def get_instance(*args, **kwargs):
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
@singleton
class EmbeddingFactory:
def __init__(self, _path: str):
self.path = _path
self.embedding = HuggingFaceEmbeddings(model_name=_path)
def get_embedding(self):
return self.embedding
def get_embding(_path: str) -> Embeddings:
# return HuggingFaceEmbeddings(model_name=path)
return EmbeddingFactory(_path).get_embedding()
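# The @singleton decorator caches one instance per class: repeated calls return the
# first instance regardless of arguments, which is why EmbeddingFactory loads the
# HuggingFace model only once per process. Quick illustration (hypothetical class):
def _singleton_demo():
    @singleton
    class Cfg:
        def __init__(self, name):
            self.name = name
    assert Cfg("a") is Cfg("b")  # the second call returns the cached instance
    return Cfg("a").name         # 'a'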
class RE_FAISS(FAISS):
# deduplicate while keeping metadata
@staticmethod
def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
deduplicated_dict = OrderedDict()
print("--------------oedereddict type--------------", type(deduplicated_dict))
for doc, scores in tuple_input:
page_content = doc.page_content
metadata = doc.metadata
if page_content not in deduplicated_dict:
deduplicated_dict[page_content] = (metadata, scores)
print("--------------------------du--------------------------\n", deduplicated_dict)
deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in
deduplicated_dict.items()]
return deduplicated_documents
# def similarity_search_with_score_by_vector(
# self,
# embedding: List[float],
# k: int = 4,
# filter: Optional[Dict[str, Any]] = None,
# fetch_k: int = 20,
# **kwargs: Any,
# ) -> List[Tuple[Document, float]]:
# faiss = dependable_faiss_import()
# vector = np.array([embedding], dtype=np.float32)
# if self._normalize_L2:
# faiss.normalize_L2(vector)
# scores, indices = self.index.search(vector, k if filter is None else fetch_k)
# docs = []
# for j, i in enumerate(indices[0]):
# if i == -1:
# # This happens when not enough docs are returned.
# continue
# _id = self.index_to_docstore_id[i]
# doc = self.docstore.search(_id)
# if not isinstance(doc, Document):
# raise ValueError(f"Could not find document for id {_id}, got {doc}")
# if filter is not None:
# filter = {
# key: [value] if not isinstance(value, list) else value
# for key, value in filter.items()
# }
# if all(doc.metadata.get(key) in value for key, value in filter.items()):
# docs.append((doc, scores[0][j]))
# else:
# docs.append((doc, scores[0][j]))
# docs = self._tuple_deduplication(docs)
# score_threshold = kwargs.get("score_threshold")
# if score_threshold is not None:
# cmp = (
# operator.ge
# if self.distance_strategy
# in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
# else operator.le
# )
# docs = [
# (doc, similarity)
# for doc, similarity in docs
# if cmp(similarity, score_threshold)
# ]
#
# if "doc_callback" in kwargs:
# if hasattr(kwargs["doc_callback"], 'after_search'):
# docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k)
# return docs[:k]
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch before filtering to
pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
)
docs_and_scores = self._tuple_deduplication(docs_and_scores)
if "doc_callback" in kwargs:
if hasattr(kwargs["doc_callback"], 'after_search'):
docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k)
return [doc for doc, _ in docs_and_scores]
def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
embeddings = get_embding(_path=embedding_model_name)
docstore1: PgSqlDocstore = None
if is_pgsql:
if info and "host" in info and "dbname" in info and "username" in info and "password" in info:
docstore1 = PgSqlDocstore(info, reset=reset)
else:
docstore1 = InMemorySecondaryDocstore()
if not path.exists(store_path):
os.makedirs(store_path, exist_ok=True)
if store_path is None or len(store_path) <= 0 or not path.exists(
path.join(store_path, index_name + ".faiss")) or reset:
print("create new faiss")
index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0]))  # sized to the embedding dimension
return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore1,
index_to_docstore_id={})
else:
print("load_local faiss")
_faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings, allow_dangerous_deserialization=True)
if docstore1 and is_pgsql:  # update the docstore if external parameters changed
_faiss.docstore = docstore1
return _faiss
class VectorStore_FAISS(FAISS):
def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index", info: dict = None,
is_pgsql: bool = True, show_number=5, threshold=0.8, reset: bool = False,
doc_callback: DocumentCallback = DefaultDocumentCallback()):
self.info = info
self.embedding_model_name = embedding_model_name
self.store_path = path.join(store_path, index_name)
if not path.exists(self.store_path):
os.makedirs(self.store_path, exist_ok=True)
self.index_name = index_name
self.show_number = show_number
self.search_number = self.show_number * 3
self.threshold = threshold
self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name,
is_pgsql=is_pgsql, reset=reset)
self.doc_callback = doc_callback
def get_text_similarity_with_score(self, text: str, **kwargs):
score_threshold = (1 - self.threshold) * math.sqrt(2)
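# For L2-normalized embeddings, the exact relation between cosine similarity s and
# L2 distance d is d = sqrt(2 * (1 - s)); the linear mapping above is a simpler,
# stricter cutoff (threshold 0.8 gives d <= 0.28 rather than d <= 0.63).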
docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number,
score_threshold=score_threshold, doc_callback=self.doc_callback,
**kwargs)
return [doc for doc, similarity in docs][:self.show_number]
def get_text_similarity(self, text: str, **kwargs):
docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback, **kwargs)
return docs[:self.show_number]
# # deduplicate while keeping metadata
# def _tuple_deduplication(self, tuple_input:List[Document]) -> List[Document]:
# deduplicated_dict = OrderedDict()
# for doc in tuple_input:
# page_content = doc.page_content
# metadata = doc.metadata
# if page_content not in deduplicated_dict:
# deduplicated_dict[page_content] = metadata
# deduplicated_documents = [Document(page_content=key,metadata=value) for key, value in deduplicated_dict.items()]
# return deduplicated_documents
@staticmethod
def join_document(docs: List[Document]) -> str:
return "".join([doc.page_content for doc in docs])
@staticmethod
def get_local_doc(docs: List[Document]):
ans = []
for doc in docs:
ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"],
"filename": doc.metadata["filename"]})
return ans
# def _join_document_location(self, docs:List[Document]) -> str:
# persist the index to local disk
def _save_local(self):
self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name)
# Add documents
# Document {
#   page_content  paragraph text
#   metadata {
#     page  page number
#   }
# }
def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'):
list_of_documents: List[Document] = []
if self.doc_callback:
new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs)
if need_split:
for doc in new_docs:
words_list = re.split(pattern, doc.page_content)
# remove duplicates
words_list = set(words_list)
words_list = [str(words) for words in words_list]
for words in words_list:
if not words.strip() == '':
metadata = copy.deepcopy(doc.metadata)
metadata["paragraph"] = doc.page_content
list_of_documents.append(Document(page_content=words, metadata=metadata))
else:
list_of_documents = new_docs
self._faiss.add_documents(list_of_documents)
def _add_documents_from_dir(self, filepaths=None, load_kwargs=None):
if load_kwargs is None:
load_kwargs = {"mode": "paged"}
if filepaths is None:
filepaths = []
self._add_documents(load.loads(filepaths, **load_kwargs))
def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
"""
Return VectorStoreRetriever initialized from this VectorStore.
Args:
search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be "similarity" (default), "mmr", or
"similarity_score_threshold".
search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
k: Amount of documents to return (Default: 4)
score_threshold: Minimum relevance threshold
for similarity_score_threshold
fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
lambda_mult: Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5)
filter: Filter by document metadata
Returns:
VectorStoreRetriever: Retriever class for VectorStore.
Examples:
.. code-block:: python
# Retrieve more documents with higher diversity
# Useful if your dataset has many similar documents
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 6, 'lambda_mult': 0.25}
)
# Fetch more documents for the MMR algorithm to consider
# But only return the top 5
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 5, 'fetch_k': 50}
)
# Only retrieve documents that have a relevance score
# Above a certain threshold
docsearch.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={'score_threshold': 0.8}
)
# Only get the single most similar document from the dataset
docsearch.as_retriever(search_kwargs={'k': 1})
# Use a filter to only retrieve documents from a specific paper
docsearch.as_retriever(
search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
)
"""
if not kwargs or "similarity_score_threshold" != kwargs["search_type"]:
default_kwargs = {'k': self.show_number}
if "search_kwargs" in kwargs:
default_kwargs.update(kwargs["search_kwargs"])
kwargs["search_kwargs"] = default_kwargs
elif "similarity_score_threshold" == kwargs["search_type"]:
default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number}
if "search_kwargs" in kwargs:
default_kwargs.update(kwargs["search_kwargs"])
kwargs["search_kwargs"] = default_kwargs
kwargs["search_kwargs"]["doc_callback"] = self.doc_callback
tags = kwargs.pop("tags", None) or []
tags.extend(self._faiss._get_retriever_tags())
print(kwargs)
return VectorStoreRetriever_FAISS(vectorstore=self._faiss, **kwargs, tags=tags)
class VectorStoreRetriever_FAISS(VectorStoreRetriever):
search_k = 5
def __init__(self, **kwargs):
super().__init__(**kwargs)
if "k" in self.search_kwargs:
self.search_k = self.search_kwargs["k"]
self.search_kwargs["k"] = self.search_k * 2
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
docs = super()._get_relevant_documents(query=query, run_manager=run_manager)
return docs[:self.search_k]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
docs = super()._aget_relevant_documents(query=query, run_manager=run_manager)
return docs[:self.search_k]
from .k_db import PostgresDB
# paragraph_id BIGSERIAL primary key,
TABLE_TXT_DOC = """
create table txt_doc (
hash varchar(40) primary key,
text text not null,
matadate text
);
"""
TABLE_TXT_DOC_HASH_INDEX = """
CREATE UNIQUE INDEX hash_index ON txt_doc (hash);
"""
# CREATE UNIQUE INDEX idx_name ON your_table (column_name);
class TxtDoc:
def __init__(self, db: PostgresDB) -> None:
self.db = db
def insert(self, texts):
query = f"INSERT INTO txt_doc(hash,text,matadate) VALUES "
args = []
for value in texts:
value = list(value)
query += "(%s,%s,%s),"
args.extend(value)
query = query[:len(query) - 1]
query += f"ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;"
self.db.execute_args(query, args)
def delete(self, ids):
for item in ids:
query = "DELETE FROM txt_doc WHERE hash = %s"
self.db.execute_args(query, [item])
def search(self, item):
query = "SELECT text,matadate FROM txt_doc WHERE hash = %s"
self.db.execute_args(query, [item])
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
# return Document(page_content=self.db.fetchall()[0][0], metadata=dict(page=self.db.fetchall()[0][1]))
# answer = self.db.fetchall()[0][0]
# return answer
def create_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'txt_doc')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if not exists:
query = TABLE_TXT_DOC
self.db.execute(query)
# self.db.execute(TABLE_TXT_DOC_HASH_INDEX)
def drop_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'txt_doc')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if exists:
query = "DROP TABLE txt_doc"
self.db.format(query)
print("drop table txt_doc ok")
def find_like_doc(self, item: str):
# parameterized LIKE query avoids SQL injection from `item`
query = "SELECT text,matadate FROM txt_doc WHERE matadate LIKE %s OR text LIKE %s"
pattern = f"%{item}%"
self.db.execute_args(query, [pattern, pattern])
answer = self.db.fetchall()
if len(answer) > 0:
return answer
else:
return None
from .k_db import PostgresDB
TABLE_VEC_TXT = """
CREATE TABLE vec_txt (
vector_id varchar(36) PRIMARY KEY,
text text,
paragraph_id varchar(40) not null
)
"""
# example vector_id: 025a9bee-2eb2-47f5-9722-525e05a0442b
class TxtVector:
def __init__(self, db: PostgresDB) -> None:
self.db = db
def insert(self, vectors):
query = f"INSERT INTO vec_txt(vector_id,text,paragraph_id) VALUES"
args = []
for value in vectors:
value = list(value)
query += "(%s,%s,%s),"
args.extend(value)
query = query[:len(query) - 1]
query += f"ON conflict(vector_id) DO UPDATE SET text = EXCLUDED.text,paragraph_id = EXCLUDED.paragraph_id;"
# query += ";"
self.db.execute_args(query, args)
def delete(self, ids):
for item in ids:
query = "DELETE FROM vec_txt WHERE vector_id = %s"
self.db.execute_args(query, [item])
def search(self, search: str):
query = f"SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
self.db.execute_args(query, [search])
answer = self.db.fetchall()
# print(answer)
return answer[0]
def create_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'vec_txt')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if not exists:
query = TABLE_VEC_TXT
self.db.execute(query)
def drop_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'vec_txt')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if exists:
query = "DROP TABLE vec_txt"
self.db.format(query)
print("drop table vec_txt ok")