Commit 8efec361 by 陈正乐

调用模型对话实现

parent 493cdd59
from .c_db import UPostgresDB from .c_db import UPostgresDB
import json
TABLE_USER = """ TABLE_USER = """
DROP TABLE IF EXISTS "c_user"; DROP TABLE IF EXISTS "c_user";
......
from .c_db import UPostgresDB from .c_db import UPostgresDB
import json
TABLE_CHAT = """ TABLE_CHAT = """
DROP TABLE IF EXISTS "chat"; DROP TABLE IF EXISTS "chat";
......
from .c_db import UPostgresDB from .c_db import UPostgresDB
import json
TABLE_CHAT = """ TABLE_CHAT = """
DROP TABLE IF EXISTS "turn_qa"; DROP TABLE IF EXISTS "turn_qa";
......
import os, sys import sys
from os import path
sys.path.append("../")
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import json import json
from typing import List, Any, Tuple, Dict from typing import List, Tuple
from langchain.schema import Document from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64 from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64
sys.path.append("../")
class DocumentCallback(ABC): class DocumentCallback(ABC):
@abstractmethod # 向量库储存前文档处理-- @abstractmethod # 向量库储存前文档处理--
......
import os, sys import os
import re, time import sys
import re
from os import path from os import path
sys.path.append("../")
import copy import copy
from typing import List, OrderedDict, Any, Optional, Tuple, Dict from typing import List, OrderedDict, Any, Optional, Tuple, Dict
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore
...@@ -23,6 +23,10 @@ from langchain.callbacks.manager import ( ...@@ -23,6 +23,10 @@ from langchain.callbacks.manager import (
from src.loader import load from src.loader import load
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback
import operator
from langchain.vectorstores.utils import DistanceStrategy
import numpy as np
sys.path.append("../")
def singleton(cls): def singleton(cls):
...@@ -38,22 +42,20 @@ def singleton(cls): ...@@ -38,22 +42,20 @@ def singleton(cls):
@singleton @singleton
class EmbeddingFactory: class EmbeddingFactory:
def __init__(self, path: str): def __init__(self, _path: str):
self.path = path self.path = _path
self.embedding = HuggingFaceEmbeddings(model_name=path) self.embedding = HuggingFaceEmbeddings(model_name=_path)
def get_embedding(self): def get_embedding(self):
return self.embedding return self.embedding
def GetEmbding(_path: str) -> Embeddings: def get_embding(_path: str) -> Embeddings:
# return HuggingFaceEmbeddings(model_name=path) # return HuggingFaceEmbeddings(model_name=path)
return EmbeddingFactory(_path).get_embedding() return EmbeddingFactory(_path).get_embedding()
import operator
from langchain.vectorstores.utils import DistanceStrategy
import numpy as np
class RE_FAISS(FAISS): class RE_FAISS(FAISS):
...@@ -159,7 +161,7 @@ class RE_FAISS(FAISS): ...@@ -159,7 +161,7 @@ class RE_FAISS(FAISS):
def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index", def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
is_pgsql: bool = True, reset: bool = False) -> RE_FAISS: is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
embeddings = GetEmbding(_path=embedding_model_name) embeddings = get_embding(_path=embedding_model_name)
docstore1: PgSqlDocstore = None docstore1: PgSqlDocstore = None
if is_pgsql: if is_pgsql:
if info and "host" in info and "dbname" in info and "username" in info and "password" in info: if info and "host" in info and "dbname" in info and "username" in info and "password" in info:
...@@ -298,7 +300,7 @@ class VectorStore_FAISS(FAISS): ...@@ -298,7 +300,7 @@ class VectorStore_FAISS(FAISS):
Examples: Examples:
.. code-block:: python . code-block:: python
# Retrieve more documents with higher diversity # Retrieve more documents with higher diversity
# Useful if your dataset has many similar documents # Useful if your dataset has many similar documents
......
# -*- coding: utf-8 -*-
import sys
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
sys.path.append("../..")
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
def chat(context, question):
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
erniellm = LLMChain(llm=base_llm, prompt=PROMPT1, llm_kwargs={"temperature": 0.9})
if not context and not question:
return ""
result = erniellm.run({"context": context, "question": question})
return result
async def async_chat_stc(context, question):
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
erniellm = LLMChain(llm=base_llm, prompt=PROMPT1, llm_kwargs={"temperature": 0.9})
callback = AsyncIteratorCallbackHandler()
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
import traceback
traceback.print_exc()
print(f"Caught exception: {e}")
finally:
event.set()
task = asyncio.create_task(
wrap_done(erniellm.arun({"context": context, "question": question}, callbacks=[callback]), callback.done))
print("*" * 20)
text = ""
async for token in callback.aiter():
text = text + token
yield f"{text}"
await task
if __name__ == "__main__":
print("main函数begin")
print(chat("当别人想你说你好的时候,你也应该说你好", "你好"))
import sys import sys
sys.path.append("../")
from src.pgdb.chat.c_db import UPostgresDB from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa from src.pgdb.chat.turn_qa_table import TurnQa
sys.path.append("../")
"""测试会话相关数据可的连接""" """测试会话相关数据可的连接"""
...@@ -29,5 +29,5 @@ def test(): ...@@ -29,5 +29,5 @@ def test():
turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0]) turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])
if __name__ == "main": if __name__ == "__main__":
test() test()
import sys import sys
sys.path.append("../")
from src.loader.load import loads_path from src.loader.load import loads_path
from src.pgdb.knowledge.similarity import VectorStore_FAISS from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import ( from src.config.consts import (
...@@ -17,9 +16,11 @@ from src.config.consts import ( ...@@ -17,9 +16,11 @@ from src.config.consts import (
) )
from src.loader.callback import BaseCallback from src.loader.callback import BaseCallback
sys.path.append("../")
# 当返回值中带有“思考题”字样的时候,默认将其忽略。 # 当返回值中带有“思考题”字样的时候,默认将其忽略。
class localCallback(BaseCallback): class LocalCallBack(BaseCallback):
def filter(self, title: str, content: str) -> bool: def filter(self, title: str, content: str) -> bool:
if len(title + content) == 0: if len(title + content) == 0:
return True return True
...@@ -38,7 +39,7 @@ def test_faiss_from_dir(): ...@@ -38,7 +39,7 @@ def test_faiss_from_dir():
"password": VEC_DB_PASSWORD}, "password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER, show_number=SIMILARITY_SHOW_NUMBER,
reset=True) reset=True)
docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=512, callbacks=[localCallback()]) docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=512, callbacks=[LocalCallBack()])
print(len(docs)) print(len(docs))
last_doc = None last_doc = None
docs1 = [] docs1 = []
......
from src.server.qa import chat
print(chat("当别人想你说你好的时候,你也应该说你好", "你好"))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment