Commit 14a6b87b by 陈正乐

修改在询问模型时,所需参数个数

parent 1e9efa79
......@@ -41,4 +41,4 @@ INDEX_NAME = 'know'
# =============================
# 知识相关资料配置
# =============================
KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\work\\llm_gjjs\\兴火燎原知识库\\兴火燎原知识库\\law\\pdf'
KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\低空经济数据库'
......@@ -227,7 +227,7 @@ class VectorStore_FAISS(FAISS):
# return deduplicated_documents
@staticmethod
def _join_document(docs: List[Document]) -> str:
def join_document(docs: List[Document]) -> str:
print(docs)
return "".join([doc.page_content for doc in docs])
......
from src.pgdb.knowledge.similarity import VectorStore_FAISS
class GetSimilarity:
def __init__(self, _question, _faiss_db: VectorStore_FAISS):
self.question = _question
self.faiss_db = _faiss_db
self.similarity_doc = self.faiss_db.join_document(self.faiss_db.get_text_similarity("什么是低空飞行"))
def get_similarity_doc(self):
return self.similarity_doc
......@@ -11,12 +11,23 @@ from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.server.get_similarity import GetSimilarity
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD
CHAT_DB_PASSWORD,
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
INDEX_NAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER
)
sys.path.append("../..")
......@@ -29,32 +40,40 @@ PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=promp
class QA:
def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id):
def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id, _faiss_db):
self.prompt = _prompt
self.base_llm = _base_llm
self.llm_kwargs = _llm_kwargs
self.prompt_kwargs = _prompt_kwargs
self.db = _db
self.chat_id = _chat_id
self.faiss_db = _faiss_db
self.crud = CRUD(self.db)
self.history = self.crud.get_history(self.chat_id)
self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
self.cur_answer = ""
self.cur_question = ""
# 为所给问题返回similarity文本
def get_similarity(self, _aquestion):
return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()
# 一次性直接给出所有的答案
async def chat(self, *args):
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)})
async def chat(self, _question):
similarity = self.get_similarity(_aquestion=_question)
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))})
self.cur_answer = ""
if not args:
if not _question:
return ""
self.cur_answer = self.llm.run({k: v for k, v in zip(self.prompt_kwargs, args)})
self.cur_answer = self.llm.run({k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))})
return self.cur_answer
# 异步输出,逐渐输出答案
async def async_chat(self, *args):
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)})
async def async_chat(self, _question):
similarity = self.get_similarity(_aquestion=_question)
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))})
callback = AsyncIteratorCallbackHandler()
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
......@@ -64,8 +83,9 @@ class QA:
print(f"Caught exception: {e}")
finally:
event.set()
task = asyncio.create_task(
wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, args)}, callbacks=[callback]),
wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))}, callbacks=[callback]),
callback.done))
self.cur_answer = ""
async for token in callback.aiter():
......@@ -73,11 +93,10 @@ class QA:
yield f"{self.cur_answer}"
print(datetime.now())
await task
print('----------------',self.cur_question)
print('================',self.cur_answer)
print('----------------', self.cur_question)
print('================', self.cur_answer)
print(datetime.now())
def get_history(self):
return self.history
......@@ -97,8 +116,16 @@ if __name__ == "__main__":
port=CHAT_DB_PORT, )
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2')
print(my_chat.async_chat("当别人想你说你好的时候,你也应该说你好", "你好"))
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', _faiss_db=vecstore_faiss)
print(my_chat.chat("什么是低空经济"))
my_chat.update_history()
time.sleep(20)
print(my_chat.cur_answer)
import sys
sys.path.append("../")
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser
......@@ -12,14 +13,14 @@ from src.config.consts import (
CHAT_DB_PASSWORD
)
sys.path.append("../")
"""测试会话相关数据可的连接"""
def test():
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
print(c_db)
crud = CRUD(c_db)
crud.create_table()
crud.insert_turn_qa("2", "wen4", "da1", 1, 0)
......
# -*- coding: utf-8 -*-
import gradio as gr
from langchain.prompts import PromptTemplate
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from src.pgdb.chat.c_db import UPostgresDB
from src.server.get_similarity import GetSimilarity
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD,
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
INDEX_NAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER
)
from src.server.qa import QA
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
def main():
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', _faiss_db=vecstore_faiss)
with gr.Blocks() as demo:
with gr.Row():
in1 = gr.Textbox(show_label=True, lines=10, visible=False)
in2 = gr.Textbox(show_label=True, lines=10)
with gr.Row():
qabtn = gr.Button("SUBMIT")
out = gr.Textbox(show_label=True, lines=10)
qabtn.click(my_chat.async_chat, [in2], [out])
demo.queue().launch(share=False, inbrowser=True, server_name="192.168.100.76", server_port=8888)
if __name__ == "__main__":
main()
\ No newline at end of file
......@@ -78,9 +78,9 @@ def test_faiss_load():
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("征信业务有什么情况")))
print(vecstore_faiss.join_document(vecstore_faiss.get_text_similarity("什么是低空飞行")))
if __name__ == "__main__":
# test_faiss_from_dir()
test_faiss_from_dir()
test_faiss_load()
from functools import reduce
def add_three(x,y):
return x + y
li = [1,2,3,5]
reduce(add_three, li)#=> 11
class Solution:
@staticmethod
def numDecodings(s: str) -> int:
length = len(s)
ans = [0 for i in range(length+1)]
ans[0] = 0
ans[1] = 1
for i in range(1, length+1):
print(i)
if s[i - 1] + s[i] == '10' or s[i - 1] + s[i] == '11' or s[i - 1] + s[i] == '12' or s[i - 1] + s[
i] == '13' or s[i - 1] + s[i] == '14' or s[i - 1] + s[i] == '15' or s[i - 1] + s[i] == '16' or s[
i - 1] + s[i] == '17' or s[i - 1] + s[i] == '18' or s[i - 1] + s[i] == '19' or s[i - 1] + s[
i] == '20' or s[i - 1] + s[i] == '21' or s[i - 1] + s[i] == '22' or s[i - 1] + s[i] == '23' or s[
i - 1] + s[i] == '24' or s[i - 1] + s[i] == '25' or s[i - 1] + s[i] == '26':
if s[i] == '0':
ans[i] = ans[i - 1] + 1
else:
ans[i] = ans[i - 1] + 2
else:
ans[i] = ans[i - 1] + 1
print(ans)
return ans[length-1]
Solution.numDecodings("226")
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment