From 14a6b87b82fd815e69e77fc02c8efef5cb7919f6 Mon Sep 17 00:00:00 2001
From: 陈正乐 <chenzhengle@brilliance.com.cn>
Date: Tue, 30 Apr 2024 10:06:00 +0800
Subject: [PATCH] 修改在询问模型时,所需参数个数

---
 src/config/consts.py             |  2 +-
 src/pgdb/knowledge/similarity.py |  2 +-
 src/server/get_similarity.py     | 12 ++++++++++++
 src/server/qa.py                 | 55 +++++++++++++++++++++++++++++++++++++++++--------------
 test/chat_table_test.py          |  5 +++--
 test/gradio_text.py              | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 test/k_store_test.py             |  4 ++--
 test/lk_test.py                  | 29 ++++++++++++++++++++++++-----
 8 files changed, 145 insertions(+), 25 deletions(-)
 create mode 100644 src/server/get_similarity.py
 create mode 100644 test/gradio_text.py

diff --git a/src/config/consts.py b/src/config/consts.py
index d738547..3569297 100644
--- a/src/config/consts.py
+++ b/src/config/consts.py
@@ -41,4 +41,4 @@ INDEX_NAME = 'know'
 # =============================
 # 知识相关资料配置
 # =============================
-KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\work\\llm_gjjs\\兴火燎原知识库\\兴火燎原知识库\\law\\pdf'
+KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\低空经济数据库'
diff --git a/src/pgdb/knowledge/similarity.py b/src/pgdb/knowledge/similarity.py
index 422a798..3459596 100644
--- a/src/pgdb/knowledge/similarity.py
+++ b/src/pgdb/knowledge/similarity.py
@@ -227,7 +227,7 @@ class VectorStore_FAISS(FAISS):
     #     return deduplicated_documents
 
     @staticmethod
-    def _join_document(docs: List[Document]) -> str:
+    def join_document(docs: List[Document]) -> str:
         print(docs)
         return "".join([doc.page_content for doc in docs])
 
diff --git a/src/server/get_similarity.py b/src/server/get_similarity.py
new file mode 100644
index 0000000..d386792
--- /dev/null
+++ b/src/server/get_similarity.py
@@ -0,0 +1,12 @@
+from src.pgdb.knowledge.similarity import VectorStore_FAISS
+
+
+class GetSimilarity:
+    def __init__(self, _question, _faiss_db: VectorStore_FAISS):
+        self.question = _question
+        self.faiss_db = _faiss_db
+        self.similarity_doc = self.faiss_db.join_document(self.faiss_db.get_text_similarity("什么是低空飞行"))
+
+    def get_similarity_doc(self):
+        return self.similarity_doc
+
diff --git a/src/server/qa.py b/src/server/qa.py
index 8cf3577..244593e 100644
--- a/src/server/qa.py
+++ b/src/server/qa.py
@@ -11,12 +11,23 @@ from src.llm.ernie_with_sdk import ChatERNIESerLLM
 from qianfan import ChatCompletion
 from src.pgdb.chat.c_db import UPostgresDB
 from src.pgdb.chat.crud import CRUD
+from src.server.get_similarity import GetSimilarity
+from src.pgdb.knowledge.similarity import VectorStore_FAISS
 from src.config.consts import (
     CHAT_DB_USER,
     CHAT_DB_HOST,
     CHAT_DB_PORT,
     CHAT_DB_DBNAME,
-    CHAT_DB_PASSWORD
+    CHAT_DB_PASSWORD,
+    EMBEEDING_MODEL_PATH,
+    FAISS_STORE_PATH,
+    INDEX_NAME,
+    VEC_DB_HOST,
+    VEC_DB_PASSWORD,
+    VEC_DB_PORT,
+    VEC_DB_USER,
+    VEC_DB_DBNAME,
+    SIMILARITY_SHOW_NUMBER
 )
 
 sys.path.append("../..")
@@ -29,32 +40,40 @@ PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=promp
 
 
 class QA:
-    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id):
+    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id, _faiss_db):
         self.prompt = _prompt
         self.base_llm = _base_llm
         self.llm_kwargs = _llm_kwargs
         self.prompt_kwargs = _prompt_kwargs
         self.db = _db
         self.chat_id = _chat_id
+        self.faiss_db = _faiss_db
         self.crud = CRUD(self.db)
         self.history = self.crud.get_history(self.chat_id)
         self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
         self.cur_answer = ""
         self.cur_question = ""
 
+    # 为所给问题返回similarity文本
+    def get_similarity(self, _aquestion):
+        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()
+
     # 一次性直接给出所有的答案
-    async def chat(self, *args):
-        self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)})
+    async def chat(self, _question):
+        similarity = self.get_similarity(_aquestion=_question)
+        self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))})
         self.cur_answer = ""
-        if not args:
+        if not _question:
             return ""
-        self.cur_answer = self.llm.run({k: v for k, v in zip(self.prompt_kwargs, args)})
+        self.cur_answer = self.llm.run({k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))})
         return self.cur_answer
 
     # 异步输出,逐渐输出答案
-    async def async_chat(self, *args):
-        self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)})
+    async def async_chat(self, _question):
+        similarity = self.get_similarity(_aquestion=_question)
+        self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))})
         callback = AsyncIteratorCallbackHandler()
+
         async def wrap_done(fn: Awaitable, event: asyncio.Event):
             try:
                 await fn
@@ -64,8 +83,9 @@ class QA:
                 print(f"Caught exception: {e}")
             finally:
                 event.set()
+
         task = asyncio.create_task(
-            wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, args)}, callbacks=[callback]),
+            wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, (similarity, _question))}, callbacks=[callback]),
                       callback.done))
         self.cur_answer = ""
         async for token in callback.aiter():
@@ -73,11 +93,10 @@ class QA:
             yield f"{self.cur_answer}"
         print(datetime.now())
         await task
-        print('----------------',self.cur_question)
-        print('================',self.cur_answer)
+        print('----------------', self.cur_question)
+        print('================', self.cur_answer)
         print(datetime.now())
 
-
     def get_history(self):
         return self.history
 
@@ -97,8 +116,16 @@ if __name__ == "__main__":
                        port=CHAT_DB_PORT, )
     base_llm = ChatERNIESerLLM(
         chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
-    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2')
-    print(my_chat.async_chat("当别人想你说你好的时候,你也应该说你好", "你好"))
+    vecstore_faiss = VectorStore_FAISS(
+        embedding_model_name=EMBEEDING_MODEL_PATH,
+        store_path=FAISS_STORE_PATH,
+        index_name=INDEX_NAME,
+        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
+              "password": VEC_DB_PASSWORD},
+        show_number=SIMILARITY_SHOW_NUMBER,
+        reset=False)
+    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', _faiss_db=vecstore_faiss)
+    print(my_chat.chat("什么是低空经济"))
     my_chat.update_history()
     time.sleep(20)
     print(my_chat.cur_answer)
diff --git a/test/chat_table_test.py b/test/chat_table_test.py
index 63a41da..468a0a4 100644
--- a/test/chat_table_test.py
+++ b/test/chat_table_test.py
@@ -1,4 +1,5 @@
 import sys
+sys.path.append("../")
 from src.pgdb.chat.c_db import UPostgresDB
 from src.pgdb.chat.chat_table import Chat
 from src.pgdb.chat.c_user_table import CUser
@@ -12,14 +13,14 @@ from src.config.consts import (
     CHAT_DB_PASSWORD
 )
 
-sys.path.append("../")
+
 """测试会话相关数据可的连接"""
 
 
 def test():
     c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
                        port=CHAT_DB_PORT, )
-
+    print(c_db)
     crud = CRUD(c_db)
     crud.create_table()
     crud.insert_turn_qa("2", "wen4", "da1", 1, 0)
diff --git a/test/gradio_text.py b/test/gradio_text.py
new file mode 100644
index 0000000..c561acf
--- /dev/null
+++ b/test/gradio_text.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+import gradio as gr
+from langchain.prompts import PromptTemplate
+from src.llm.ernie_with_sdk import ChatERNIESerLLM
+from qianfan import ChatCompletion
+from src.pgdb.chat.c_db import UPostgresDB
+from src.server.get_similarity import GetSimilarity
+from src.pgdb.knowledge.similarity import VectorStore_FAISS
+from src.config.consts import (
+    CHAT_DB_USER,
+    CHAT_DB_HOST,
+    CHAT_DB_PORT,
+    CHAT_DB_DBNAME,
+    CHAT_DB_PASSWORD,
+    EMBEEDING_MODEL_PATH,
+    FAISS_STORE_PATH,
+    INDEX_NAME,
+    VEC_DB_HOST,
+    VEC_DB_PASSWORD,
+    VEC_DB_PORT,
+    VEC_DB_USER,
+    VEC_DB_DBNAME,
+    SIMILARITY_SHOW_NUMBER
+)
+from src.server.qa import QA
+
+prompt1 = """'''
+{context}
+'''
+请你根据上述已知资料回答下面的问题,问题如下:
+{question}"""
+
+PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
+def main():
+
+    c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
+                       port=CHAT_DB_PORT, )
+    vecstore_faiss = VectorStore_FAISS(
+        embedding_model_name=EMBEEDING_MODEL_PATH,
+        store_path=FAISS_STORE_PATH,
+        index_name=INDEX_NAME,
+        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
+              "password": VEC_DB_PASSWORD},
+        show_number=SIMILARITY_SHOW_NUMBER,
+        reset=False)
+    base_llm = ChatERNIESerLLM(
+        chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
+    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', _faiss_db=vecstore_faiss)
+    with gr.Blocks() as demo:
+        with gr.Row():
+            in1 = gr.Textbox(show_label=True, lines=10, visible=False)
+            in2 = gr.Textbox(show_label=True, lines=10)
+        with gr.Row():
+            qabtn = gr.Button("SUBMIT")
+        out = gr.Textbox(show_label=True, lines=10)
+        qabtn.click(my_chat.async_chat, [in2], [out])
+    demo.queue().launch(share=False, inbrowser=True, server_name="192.168.100.76", server_port=8888)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/test/k_store_test.py b/test/k_store_test.py
index 1804b99..29b2f6b 100644
--- a/test/k_store_test.py
+++ b/test/k_store_test.py
@@ -78,9 +78,9 @@ def test_faiss_load():
               "password": VEC_DB_PASSWORD},
         show_number=SIMILARITY_SHOW_NUMBER,
         reset=False)
-    print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("征信业务有什么情况")))
+    print(vecstore_faiss.join_document(vecstore_faiss.get_text_similarity("什么是低空飞行")))
 
 
 if __name__ == "__main__":
-    # test_faiss_from_dir()
+    test_faiss_from_dir()
     test_faiss_load()
diff --git a/test/lk_test.py b/test/lk_test.py
index 37d10d4..da637ce 100644
--- a/test/lk_test.py
+++ b/test/lk_test.py
@@ -1,5 +1,24 @@
-from functools import reduce
-def add_three(x,y):
-    return x + y
-li = [1,2,3,5]
-reduce(add_three, li)#=> 11
+class Solution:
+    @staticmethod
+    def numDecodings(s: str) -> int:
+        length = len(s)
+        ans = [0 for i in range(length+1)]
+        ans[0] = 0
+        ans[1] = 1
+        for i in range(1, length+1):
+            print(i)
+            if s[i - 1] + s[i] == '10' or s[i - 1] + s[i] == '11' or s[i - 1] + s[i] == '12' or s[i - 1] + s[
+                i] == '13' or s[i - 1] + s[i] == '14' or s[i - 1] + s[i] == '15' or s[i - 1] + s[i] == '16' or s[
+                i - 1] + s[i] == '17' or s[i - 1] + s[i] == '18' or s[i - 1] + s[i] == '19' or s[i - 1] + s[
+                i] == '20' or s[i - 1] + s[i] == '21' or s[i - 1] + s[i] == '22' or s[i - 1] + s[i] == '23' or s[
+                i - 1] + s[i] == '24' or s[i - 1] + s[i] == '25' or s[i - 1] + s[i] == '26':
+                if s[i] == '0':
+                    ans[i] = ans[i - 1] + 1
+                else:
+                    ans[i] = ans[i - 1] + 2
+            else:
+                ans[i] = ans[i - 1] + 1
+        print(ans)
+        return ans[length-1]
+
+Solution.numDecodings("226")
\ No newline at end of file
--
libgit2 0.26.0