qa.py 9.88 KB
Newer Older
陈正乐 committed
1 2
# -*- coding: utf-8 -*-
import sys
陈正乐 committed
3
import time
陈正乐 committed
4
from langchain.chains import LLMChain
5
from langchain_core.prompts import PromptTemplate
陈正乐 committed
6 7 8 9 10
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
11 12
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
13 14
from src.server.get_similarity import GetSimilarity
from src.pgdb.knowledge.similarity import VectorStore_FAISS
15 16 17 18 19
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
20 21
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
tinywell committed
22
    RERANK_MODEL_PATH,
23 24 25 26 27 28 29 30
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER
31
)
tinywell committed
32
from .rerank import BgeRerank
陈正乐 committed
33 34 35 36 37 38 39 40 41

sys.path.append("../..")

# Prompt that wraps the retrieved context and the user's question.
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)

# Canned reply returned whenever a question or answer trips the keyword filter.
# (Fixed typo: "提是供定义" -> "提供定义".)
SAFE_RESPONSE = "您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提供定义和解释及建议。如果您有任何问题,请随时向我提问。"
# Keywords that must never appear in a question or an answer.
BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"]
陈正乐 committed
45

陈正乐 committed
46

47
class QA:
    """Retrieval-augmented chat session.

    Fetches similar documents from a FAISS store, fills them into the
    prompt, queries the LLM (optionally through a rerank model), and
    persists each question/answer turn through the chat-DB CRUD layer.
    """

    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _faiss_db, rerank: bool = False):
        self.prompt = _prompt
        self.base_llm = _base_llm
        self.llm_kwargs = _llm_kwargs
        self.prompt_kwargs = _prompt_kwargs
        self.db = _db
        self.chat_id = None          # set later via set_chat_id()/create_chat()
        self.faiss_db = _faiss_db
        self.crud = CRUD(self.db)
        self.history = None
        self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
        self.cur_answer = ""
        self.cur_question = ""
        self.cur_similarity = ""
        self.cur_oquestion = ""      # original (unformatted) question
        self.rerank = rerank
        if rerank:
            self.rerank_model = BgeRerank(RERANK_MODEL_PATH)

    # True if the text contains any blocked (sensitive) keyword.
    def contains_blocked_keywords(self, text):
        return any(keyword in text for keyword in BLOCKED_KEYWORDS)

    # Return the concatenated similarity text for the given question.
    def get_similarity(self, _aquestion):
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()

    # Return the GetSimilarity helper itself (needed for rerank access).
    def get_similarity_origin(self, _aquestion):
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db)

    # Synchronous chat: return the complete answer in one shot.
    def chat(self, _question, with_similarity=False):
        self.cur_oquestion = _question
        # Fix: guard empty questions up front — the original only returned
        # after retrieval and prompt formatting, wasting a similarity search.
        if not _question:
            return ""
        # Fix: pre-bind similarity_docs — the original left it unbound on the
        # blocked-keyword path, raising NameError when with_similarity=True.
        similarity_docs = []
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
        else:
            # self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
            similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
            if self.rerank:
                self.cur_similarity = similarity.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity.get_similarity_doc()
            similarity_docs = similarity.get_similarity_docs()
            rerank_docs = similarity.get_rerank_docs()
            print("============== similarity ==============")
            print(similarity_docs)
            print("============== rerank ==============")
            print(rerank_docs)
            self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
            self.cur_answer = self.llm.run(context=self.cur_similarity, question=self.cur_oquestion)
            # Filter the model's output as well as the user's input.
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE

        self.update_history()
        if with_similarity:
            return self.cur_answer, similarity_docs
        return self.cur_answer

    # Async chat: stream the answer token by token, yielding the growing
    # history after each token.
    async def async_chat(self, _question):
        self.cur_oquestion = _question
        history = self.get_history()

        # Fix: the question is invariant, so check it once before spending an
        # LLM call — the original re-tested it on every streamed token and
        # abandoned the running task mid-stream when it matched.
        if self.contains_blocked_keywords(self.cur_oquestion):
            self.cur_answer = SAFE_RESPONSE
            history.append((self.cur_oquestion, self.cur_answer))
            yield history
            return

        self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            """Await fn, log any exception, and always signal completion."""
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=self.cur_oquestion, callbacks=[callback]),
                      callback.done))

        self.cur_answer = ""
        history.append((self.cur_oquestion, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            history[-1] = (self.cur_oquestion, self.cur_answer)
            yield history
        await task

    # Async chat variant driven by a Gradio-style history list whose last
    # entry holds the new question.
    async def async_chat2(self, history):
        _question = history[-1][0]
        history = history[:-1]

        # Fix: invariant question check hoisted out of the token loop
        # (same defect as async_chat).
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
            history.append((_question, self.cur_answer))
            yield history
            return

        self.cur_similarity = self.get_similarity(_aquestion=_question)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=_question)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            """Await fn, log any exception, and always signal completion."""
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=_question, callbacks=[callback]),
                      callback.done))

        self.cur_answer = ""
        print(_question, self.cur_answer)
        history.append((_question, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            history[-1] = (_question, self.cur_answer)
            yield history
        await task

    def get_history(self):
        """Load and cache this chat's history from the database."""
        self.history = self.crud.get_history(self.chat_id)
        return self.history

    def update_history(self):
        """Append the current QA turn to the cached history and persist it."""
        if self.cur_oquestion and self.cur_answer:
            if not self.history:
                self.history = []
            # Drop any empty placeholder entry for this question to avoid
            # duplicate rows.
            self.history = [(q, a) for q, a in self.history if q != self.cur_oquestion or a != ""]
            self.history.append((self.cur_oquestion, self.cur_answer))
            self.crud.update_last(chat_id=self.chat_id)
            self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion, answer=self.cur_answer,
                                     turn_number=len(self.history), is_last=1)

    def get_users(self):
        return self.crud.get_users()

    def get_chats(self, user_account):
        return self.crud.get_chats(user_account)

    def set_chat_id(self, chat_id):
        """Switch to an existing chat and load its history."""
        self.chat_id = chat_id
        self.history = self.crud.get_history(self.chat_id)

    def create_chat(self, user_account):
        """Create a new chat for the account; '\t\t' marks an unset title."""
        user_id = self.crud.get_uersid_from_account(user_account)
        self.chat_id = self.crud.create_chat(user_id, '\t\t', '0')

    def set_info(self, question):
        """Set the chat title: the question itself when short enough,
        otherwise an LLM-generated summary of at most ten characters."""
        info = self.crud.get_chat_info(self.chat_id)
        if info == '\t\t':  # only fill the title if it is still unset
            if len(question) <= 10:
                n_info = question
                self.crud.set_info(self.chat_id, n_info)
            else:
                info_prompt = """'''
                        {question}
                        '''
                        请你用十个字之内总结上述问题,你的输出不得大于10个字。
                        """
                info_prompt_t = PromptTemplate(input_variables=["question"], template=info_prompt)
                info_llm = LLMChain(llm=self.base_llm, prompt=info_prompt_t, llm_kwargs=self.llm_kwargs)
                n_info = info_llm.run(question=question)
                self.crud.set_info(self.chat_id, n_info)
陈正乐 committed
236

陈正乐 committed
237 238

if __name__ == "__main__":
    # Connect to the chat-history database.
    chat_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
                          port=CHAT_DB_PORT, )
    # NOTE(review): ak/sk credentials are hard-coded here — move to config/env.
    llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
    # Open the FAISS vector store backing retrieval.
    faiss_store = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    # Smoke test: attach to chat '1' and dump its stored history.
    my_chat = QA(PROMPT1, llm, {"temperature": 0.9}, ['context', 'question'], _db=chat_db, _faiss_db=faiss_store)
    my_chat.set_chat_id('1')
    print(my_chat.get_history())
    print(my_chat.get_history())