Commit 8efec361 by 陈正乐

调用模型对话实现

parent 493cdd59
from .c_db import UPostgresDB
import json
TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
......
from .c_db import UPostgresDB
import json
TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
......
from .c_db import UPostgresDB
import json
TABLE_CHAT = """
DROP TABLE IF EXISTS "turn_qa";
......
import os, sys
from os import path
import sys
sys.path.append("../")
from abc import ABC, abstractmethod
import json
from typing import List, Any, Tuple, Dict
from typing import List, Tuple
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64
sys.path.append("../")
class DocumentCallback(ABC):
@abstractmethod # 向量库储存前文档处理--
......@@ -16,7 +16,7 @@ class DocumentCallback(ABC):
@abstractmethod # 向量库查询后文档处理--用于结构建立
def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
List[Tuple[Document, float]]: # 向量库查询后文档处理
List[Tuple[Document, float]]: # 向量库查询后文档处理
pass
......@@ -31,7 +31,7 @@ class DefaultDocumentCallback(DocumentCallback):
return output_doc
def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
List[Tuple[Document, float]]: # 向量库查询后文档处理
List[Tuple[Document, float]]: # 向量库查询后文档处理
output_doc: List[Tuple[Document, float]] = []
exist_hash = []
for doc, score in documents:
......
import os, sys
import re, time
import os
import sys
import re
from os import path
sys.path.append("../")
import copy
from typing import List, OrderedDict, Any, Optional, Tuple, Dict
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore
......@@ -23,6 +23,10 @@ from langchain.callbacks.manager import (
from src.loader import load
from langchain.embeddings.base import Embeddings
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback
import operator
from langchain.vectorstores.utils import DistanceStrategy
import numpy as np
sys.path.append("../")
def singleton(cls):
......@@ -38,22 +42,20 @@ def singleton(cls):
@singleton
class EmbeddingFactory:
def __init__(self, path: str):
self.path = path
self.embedding = HuggingFaceEmbeddings(model_name=path)
def __init__(self, _path: str):
self.path = _path
self.embedding = HuggingFaceEmbeddings(model_name=_path)
def get_embedding(self):
return self.embedding
def GetEmbding(_path: str) -> Embeddings:
def get_embding(_path: str) -> Embeddings:
# return HuggingFaceEmbeddings(model_name=path)
return EmbeddingFactory(_path).get_embedding()
import operator
from langchain.vectorstores.utils import DistanceStrategy
import numpy as np
class RE_FAISS(FAISS):
......@@ -159,7 +161,7 @@ class RE_FAISS(FAISS):
def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
embeddings = GetEmbding(_path=embedding_model_name)
embeddings = get_embding(_path=embedding_model_name)
docstore1: PgSqlDocstore = None
if is_pgsql:
if info and "host" in info and "dbname" in info and "username" in info and "password" in info:
......@@ -298,7 +300,7 @@ class VectorStore_FAISS(FAISS):
Examples:
.. code-block:: python
. code-block:: python
# Retrieve more documents with higher diversity
# Useful if your dataset has many similar documents
......
# -*- coding: utf-8 -*-
import sys
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
sys.path.append("../..")
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
def chat(context, question):
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
erniellm = LLMChain(llm=base_llm, prompt=PROMPT1, llm_kwargs={"temperature": 0.9})
if not context and not question:
return ""
result = erniellm.run({"context": context, "question": question})
return result
async def async_chat_stc(context, question):
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
erniellm = LLMChain(llm=base_llm, prompt=PROMPT1, llm_kwargs={"temperature": 0.9})
callback = AsyncIteratorCallbackHandler()
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
import traceback
traceback.print_exc()
print(f"Caught exception: {e}")
finally:
event.set()
task = asyncio.create_task(
wrap_done(erniellm.arun({"context": context, "question": question}, callbacks=[callback]), callback.done))
print("*" * 20)
text = ""
async for token in callback.aiter():
text = text + token
yield f"{text}"
await task
if __name__ == "__main__":
print("main函数begin")
print(chat("当别人想你说你好的时候,你也应该说你好", "你好"))
import sys
sys.path.append("../")
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa
sys.path.append("../")
"""测试会话相关数据可的连接"""
......@@ -29,5 +29,5 @@ def test():
turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])
if __name__ == "main":
if __name__ == "__main__":
test()
import sys
sys.path.append("../")
from src.loader.load import loads_path
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
......@@ -17,9 +16,11 @@ from src.config.consts import (
)
from src.loader.callback import BaseCallback
sys.path.append("../")
# 当返回值中带有“思考题”字样的时候,默认将其忽略。
class localCallback(BaseCallback):
class LocalCallBack(BaseCallback):
def filter(self, title: str, content: str) -> bool:
if len(title + content) == 0:
return True
......@@ -38,7 +39,7 @@ def test_faiss_from_dir():
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=True)
docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=512, callbacks=[localCallback()])
docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=512, callbacks=[LocalCallBack()])
print(len(docs))
last_doc = None
docs1 = []
......
from src.server.qa import chat
print(chat("当别人想你说你好的时候,你也应该说你好", "你好"))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment