From 14a6b87b82fd815e69e77fc02c8efef5cb7919f6 Mon Sep 17 00:00:00 2001 From: 陈正乐 <chenzhengle@brilliance.com.cn> Date: Tue, 30 Apr 2024 10:06:00 +0800 Subject: [PATCH] 修改在询问模型时,所需参数个数 --- src/config/consts.py | 2 +- src/pgdb/knowledge/similarity.py | 2 +- src/server/get_similarity.py | 12 ++++++++++++ src/server/qa.py | 55 +++++++++++++++++++++++++++++++++++++++++-------------- test/chat_table_test.py | 5 +++-- test/gradio_text.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ test/k_store_test.py | 4 ++-- test/lk_test.py | 29 ++++++++++++++++++++++++----- 8 files changed, 145 insertions(+), 25 deletions(-) create mode 100644 src/server/get_similarity.py create mode 100644 test/gradio_text.py diff --git a/src/config/consts.py b/src/config/consts.py index d738547..3569297 100644 --- a/src/config/consts.py +++ b/src/config/consts.py @@ -41,4 +41,4 @@ INDEX_NAME = 'know' # ============================= # 知识相关资料配置 # ============================= -KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\work\\llm_gjjs\\兴火燎原知识库\\兴火燎原知识库\\law\\pdf' +KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\低空经济数据库' diff --git a/src/pgdb/knowledge/similarity.py b/src/pgdb/knowledge/similarity.py index 422a798..3459596 100644 --- a/src/pgdb/knowledge/similarity.py +++ b/src/pgdb/knowledge/similarity.py @@ -227,7 +227,7 @@ class VectorStore_FAISS(FAISS): # return deduplicated_documents @staticmethod - def _join_document(docs: List[Document]) -> str: + def join_document(docs: List[Document]) -> str: print(docs) return "".join([doc.page_content for doc in docs]) diff --git a/src/server/get_similarity.py b/src/server/get_similarity.py new file mode 100644 index 0000000..d386792 --- /dev/null +++ b/src/server/get_similarity.py @@ -0,0 +1,12 @@ +from src.pgdb.knowledge.similarity import VectorStore_FAISS + + +class GetSimilarity: + def __init__(self, _question, _faiss_db: VectorStore_FAISS): + self.question = _question + self.faiss_db = _faiss_db + self.similarity_doc = self.faiss_db.join_document(self.faiss_db.get_text_similarity("什么是低空飞行")) + + def get_similarity_doc(self): + return self.similarity_doc + diff --git a/src/server/qa.py b/src/server/qa.py index 8cf3577..244593e 100644 --- a/src/server/qa.py +++ b/src/server/qa.py @@ -11,12 +11,23 @@ from src.llm.ernie_with_sdk import ChatERNIESerLLM from qianfan import ChatCompletion from src.pgdb.chat.c_db import UPostgresDB from src.pgdb.chat.crud import CRUD +from src.server.get_similarity import GetSimilarity +from src.pgdb.knowledge.similarity import VectorStore_FAISS from src.config.consts import ( CHAT_DB_USER, CHAT_DB_HOST, CHAT_DB_PORT, CHAT_DB_DBNAME, - CHAT_DB_PASSWORD + CHAT_DB_PASSWORD, + EMBEEDING_MODEL_PATH, + FAISS_STORE_PATH, + INDEX_NAME, + VEC_DB_HOST, + VEC_DB_PASSWORD, + VEC_DB_PORT, + VEC_DB_USER, + VEC_DB_DBNAME, + SIMILARITY_SHOW_NUMBER ) sys.path.append("../..") @@ -29,32 +40,40 @@ PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=promp class QA: - def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id): + def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id, _faiss_db): self.prompt = _prompt self.base_llm = _base_llm self.llm_kwargs = _llm_kwargs self.prompt_kwargs = _prompt_kwargs self.db = _db self.chat_id = _chat_id + self.faiss_db = _faiss_db self.crud = CRUD(self.db) self.history = self.crud.get_history(self.chat_id) self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs) self.cur_answer = "" self.cur_question = "" + # 为所给问题返回similarity文本 + def get_similarity(self, _aquestion): + return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc() + # 一次性直接给出所有的答案 - async def chat(self, *args): - self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)}) + async def chat(self, _question): + similarity = self.get_similarity(_aquestion=_question) + self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))}) self.cur_answer = "" - if not args: + if not _question: return "" - self.cur_answer = self.llm.run({k: v for k, v in zip(self.prompt_kwargs, args)}) + self.cur_answer = self.llm.run({k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))}) return self.cur_answer # 异步输出,逐渐输出答案 - async def async_chat(self, *args): - self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)}) + async def async_chat(self, _question): + similarity = self.get_similarity(_aquestion=_question) + self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))}) callback = AsyncIteratorCallbackHandler() + async def wrap_done(fn: Awaitable, event: asyncio.Event): try: await fn @@ -64,8 +83,9 @@ class QA: print(f"Caught exception: {e}") finally: event.set() + task = asyncio.create_task( - wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, args)}, callbacks=[callback]), + wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))}, callbacks=[callback]), callback.done)) self.cur_answer = "" async for token in callback.aiter(): @@ -73,11 +93,10 @@ class QA: yield f"{self.cur_answer}" print(datetime.now()) await task - print('----------------',self.cur_question) - print('================',self.cur_answer) + print('----------------', self.cur_question) + print('================', self.cur_answer) print(datetime.now()) - def get_history(self): return self.history @@ -97,8 +116,16 @@ if __name__ == "__main__": port=CHAT_DB_PORT, ) base_llm = ChatERNIESerLLM( chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")) - my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2') - print(my_chat.async_chat("当别人想你说你好的时候,你也应该说你好", "你好")) + vecstore_faiss = VectorStore_FAISS( + embedding_model_name=EMBEEDING_MODEL_PATH, + store_path=FAISS_STORE_PATH, + index_name=INDEX_NAME, + info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER, + "password": VEC_DB_PASSWORD}, + show_number=SIMILARITY_SHOW_NUMBER, + reset=False) + my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', _faiss_db=vecstore_faiss) + print(my_chat.chat("什么是低空经济")) my_chat.update_history() time.sleep(20) print(my_chat.cur_answer) diff --git a/test/chat_table_test.py b/test/chat_table_test.py index 63a41da..468a0a4 100644 --- a/test/chat_table_test.py +++ b/test/chat_table_test.py @@ -1,4 +1,5 @@ import sys +sys.path.append("../") from src.pgdb.chat.c_db import UPostgresDB from src.pgdb.chat.chat_table import Chat from src.pgdb.chat.c_user_table import CUser @@ -12,14 +13,14 @@ from src.config.consts import ( CHAT_DB_PASSWORD ) -sys.path.append("../") + """测试会话相关数据可的连接""" def test(): c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD, port=CHAT_DB_PORT, ) - + print(c_db) crud = CRUD(c_db) crud.create_table() crud.insert_turn_qa("2", "wen4", "da1", 1, 0) diff --git a/test/gradio_text.py b/test/gradio_text.py new file mode 100644 index 0000000..c561acf --- /dev/null +++ b/test/gradio_text.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +import gradio as gr +from langchain.prompts import PromptTemplate +from src.llm.ernie_with_sdk import ChatERNIESerLLM +from qianfan import ChatCompletion +from src.pgdb.chat.c_db import UPostgresDB +from src.server.get_similarity import GetSimilarity +from src.pgdb.knowledge.similarity import VectorStore_FAISS +from src.config.consts import ( + CHAT_DB_USER, + CHAT_DB_HOST, + CHAT_DB_PORT, + CHAT_DB_DBNAME, + CHAT_DB_PASSWORD, + EMBEEDING_MODEL_PATH, + FAISS_STORE_PATH, + INDEX_NAME, + VEC_DB_HOST, + VEC_DB_PASSWORD, + VEC_DB_PORT, + VEC_DB_USER, + VEC_DB_DBNAME, + SIMILARITY_SHOW_NUMBER +) +from src.server.qa import QA + +prompt1 = """''' +{context} +''' +请你根据上述已知资料回答下面的问题,问题如下: +{question}""" + +PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1) +def main(): + + c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD, + port=CHAT_DB_PORT, ) + vecstore_faiss = VectorStore_FAISS( + embedding_model_name=EMBEEDING_MODEL_PATH, + store_path=FAISS_STORE_PATH, + index_name=INDEX_NAME, + info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER, + "password": VEC_DB_PASSWORD}, + show_number=SIMILARITY_SHOW_NUMBER, + reset=False) + base_llm = ChatERNIESerLLM( + chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")) + my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', _faiss_db=vecstore_faiss) + with gr.Blocks() as demo: + with gr.Row(): + in1 = gr.Textbox(show_label=True, lines=10, visible=False) + in2 = gr.Textbox(show_label=True, lines=10) + with gr.Row(): + qabtn = gr.Button("SUBMIT") + out = gr.Textbox(show_label=True, lines=10) + qabtn.click(my_chat.async_chat, [in2], [out]) + demo.queue().launch(share=False, inbrowser=True, server_name="192.168.100.76", server_port=8888) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/k_store_test.py b/test/k_store_test.py index 1804b99..29b2f6b 100644 --- a/test/k_store_test.py +++ b/test/k_store_test.py @@ -78,9 +78,9 @@ def test_faiss_load(): "password": VEC_DB_PASSWORD}, show_number=SIMILARITY_SHOW_NUMBER, reset=False) - print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("征信业务有什么情况"))) + print(vecstore_faiss.join_document(vecstore_faiss.get_text_similarity("什么是低空飞行"))) if __name__ == "__main__": - # test_faiss_from_dir() + test_faiss_from_dir() test_faiss_load() diff --git a/test/lk_test.py b/test/lk_test.py index 37d10d4..da637ce 100644 --- a/test/lk_test.py +++ b/test/lk_test.py @@ -1,5 +1,24 @@ -from functools import reduce -def add_three(x,y): - return x + y -li = [1,2,3,5] -reduce(add_three, li)#=> 11 +class Solution: + @staticmethod + def numDecodings(s: str) -> int: + length = len(s) + ans = [0 for i in range(length+1)] + ans[0] = 0 + ans[1] = 1 + for i in range(1, length+1): + print(i) + if s[i - 1] + s[i] == '10' or s[i - 1] + s[i] == '11' or s[i - 1] + s[i] == '12' or s[i - 1] + s[ + i] == '13' or s[i - 1] + s[i] == '14' or s[i - 1] + s[i] == '15' or s[i - 1] + s[i] == '16' or s[ + i - 1] + s[i] == '17' or s[i - 1] + s[i] == '18' or s[i - 1] + s[i] == '19' or s[i - 1] + s[ + i] == '20' or s[i - 1] + s[i] == '21' or s[i - 1] + s[i] == '22' or s[i - 1] + s[i] == '23' or s[ + i - 1] + s[i] == '24' or s[i - 1] + s[i] == '25' or s[i - 1] + s[i] == '26': + if s[i] == '0': + ans[i] = ans[i - 1] + 1 + else: + ans[i] = ans[i - 1] + 2 + else: + ans[i] = ans[i - 1] + 1 + print(ans) + return ans[length-1] + +Solution.numDecodings("226") \ No newline at end of file -- libgit2 0.26.0