qa.py 11.1 KB
Newer Older
陈正乐 committed
1 2
# -*- coding: utf-8 -*-
import sys
陈正乐 committed
3
import time
陈正乐 committed
4
from langchain.chains import LLMChain
5
from langchain_core.prompts import PromptTemplate
陈正乐 committed
6 7 8 9 10
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
11 12
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
13 14
from src.server.get_similarity import GetSimilarity
from src.pgdb.knowledge.similarity import VectorStore_FAISS
15 16 17 18 19
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
20 21
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
tinywell committed
22
    RERANK_MODEL_PATH,
23 24 25 26 27 28 29 30
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER
31
)
tinywell committed
32
from .rerank import BgeRerank
陈正乐 committed
33 34 35 36 37 38 39

sys.path.append("../..")

# RAG prompt: retrieved context documents come first, the user question follows.
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""

PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)

# Canned reply returned whenever a question or answer trips the keyword filter.
# (Fixed typo in the original string: "提是供" -> "提供".)
SAFE_RESPONSE = "您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提供定义和解释及建议。如果您有任何问题,请随时向我提问。"
# Keywords that must never appear in a question or an answer.
BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"]
class QA:
    """Retrieval-augmented QA session.

    Retrieves similar documents from a FAISS vector store (optionally
    re-ranked with a BGE reranker), feeds them to an LLM chain, screens
    questions/answers against ``BLOCKED_KEYWORDS``, and persists turns in
    the chat database via ``CRUD``.
    """

    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _faiss_db, rerank: bool = False):
        """Wire up the prompt, LLM chain, chat DB and vector store.

        :param _prompt: PromptTemplate used to build the final LLM prompt.
        :param _base_llm: underlying LLM passed to ``LLMChain``.
        :param _llm_kwargs: extra kwargs forwarded to the chain's LLM calls.
        :param _prompt_kwargs: prompt variable names (kept for callers; unused here).
        :param _db: chat-history database handle.
        :param _faiss_db: FAISS-backed vector store for similarity search.
        :param rerank: when True, load a BGE rerank model and use it in chat().
        """
        self.prompt = _prompt
        self.base_llm = _base_llm
        self.llm_kwargs = _llm_kwargs
        self.prompt_kwargs = _prompt_kwargs
        self.db = _db
        self.chat_id = None          # set later via set_chat_id()/create_chat()
        self.faiss_db = _faiss_db
        self.crud = CRUD(self.db)
        self.history = None          # list[(question, answer)] once loaded
        self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
        self.cur_answer = ""
        self.cur_question = ""       # fully formatted prompt of the last turn
        self.cur_similarity = ""     # retrieved context of the last turn
        self.cur_oquestion = ""      # original (unformatted) user question
        self.rerank = rerank
        if rerank:
            self.rerank_model = BgeRerank(RERANK_MODEL_PATH)

    # Check whether the text mentions any blocked keyword.
    def contains_blocked_keywords(self, text):
        return any(keyword in text for keyword in BLOCKED_KEYWORDS)

    # Return the similarity context text for the given question.
    def get_similarity(self, _aquestion):
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()

    # Return the GetSimilarity helper itself (caller picks plain or reranked docs).
    def get_similarity_origin(self, _aquestion):
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db)

    # Answer the question in one shot (no streaming).
    def chat(self, _question, with_similarity=False):
        """Return the answer string, or ``(answer, docs)`` when *with_similarity*.

        Blocked questions short-circuit to SAFE_RESPONSE (docs are then None).
        An empty question returns "" immediately, before any retrieval work.
        """
        if not _question:
            # Guard first: the original checked this only after retrieval and
            # prompt formatting had already run.
            return ""
        self.cur_oquestion = _question
        # Pre-bind so the with_similarity return below cannot raise
        # UnboundLocalError when the blocked-keyword branch is taken.
        similarity_docs = None
        rerank_docs = None
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
        else:
            similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
            if self.rerank:
                self.cur_similarity = similarity.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity.get_similarity_doc()
            similarity_docs = similarity.get_similarity_docs()
            rerank_docs = similarity.get_rerank_docs()
            self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
            self.cur_answer = self.llm.run(context=self.cur_similarity, question=self.cur_oquestion)
            # Screen the generated answer as well as the question.
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE

        self.update_history()
        if with_similarity:
            return self.cur_answer, (rerank_docs if self.rerank else similarity_docs)
        return self.cur_answer

    # Same as chat(), but the prompt additionally receives the chat history.
    def chat_with_history(self, _question, history, with_similarity=False):
        """Answer *_question* with *history* interpolated into the prompt.

        Requires ``self.prompt`` to accept a ``history`` variable.
        """
        if not _question:
            return ""
        self.cur_oquestion = _question
        # Pre-bind to avoid UnboundLocalError on the blocked-keyword path.
        similarity_docs = None
        rerank_docs = None
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
        else:
            similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
            if self.rerank:
                self.cur_similarity = similarity.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity.get_similarity_doc()
            similarity_docs = similarity.get_similarity_docs()
            rerank_docs = similarity.get_rerank_docs()
            self.cur_question = self.prompt.format(history=history, context=self.cur_similarity,
                                                   question=self.cur_oquestion)
            self.cur_answer = self.llm.run(history=history, context=self.cur_similarity,
                                           question=self.cur_oquestion)
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE

        self.update_history()
        if with_similarity:
            return self.cur_answer, (rerank_docs if self.rerank else similarity_docs)
        return self.cur_answer

    # Streamed variant: yields the growing history as tokens arrive.
    async def async_chat(self, _question):
        """Async generator; yields ``history`` (list of (q, a)) per token."""
        self.cur_oquestion = _question
        history = self.get_history()

        self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            # Await the LLM coroutine; always signal completion, even on error,
            # so the aiter() consumer below cannot hang.
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=self.cur_oquestion, callbacks=[callback]),
                      callback.done))

        self.cur_answer = ""
        history.append((self.cur_oquestion, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            # NOTE(review): this screens the *question* (unchanged per token),
            # while chat() screens the *answer* — confirm which is intended.
            if self.contains_blocked_keywords(self.cur_oquestion):
                self.cur_answer = SAFE_RESPONSE
                history[-1] = (self.cur_oquestion, self.cur_answer)
                yield history
                return
            history[-1] = (self.cur_oquestion, self.cur_answer)
            yield history
        await task

    # Streamed variant that takes the UI-style history list (last entry = new question).
    async def async_chat2(self, history):
        """Async generator like async_chat(), but the question is ``history[-1][0]``."""
        _question = history[-1][0]
        history = history[:-1]

        self.cur_similarity = self.get_similarity(_aquestion=_question)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=_question)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            # Always signal completion so the consumer loop terminates.
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=_question, callbacks=[callback]),
                      callback.done))

        self.cur_answer = ""
        print(_question, self.cur_answer)
        history.append((_question, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            # NOTE(review): screens the question, not the streamed answer — see async_chat().
            if self.contains_blocked_keywords(_question):
                self.cur_answer = SAFE_RESPONSE
                history[-1] = (_question, self.cur_answer)
                yield history
                return
            history[-1] = (_question, self.cur_answer)
            yield history
        await task

    # Load this chat's history from the database.
    def get_history(self):
        self.history = self.crud.get_history(self.chat_id)
        return self.history

    def update_history(self):
        """Persist the current Q/A turn and append it to the in-memory history."""
        if self.cur_oquestion and self.cur_answer:
            if not self.history:
                self.history = []
            # Drop any placeholder entry (same question, empty answer) to avoid duplicates.
            self.history = [(q, a) for q, a in self.history if q != self.cur_oquestion or a != ""]
            self.history.append((self.cur_oquestion, self.cur_answer))
            self.crud.update_last(chat_id=self.chat_id)
            self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion, answer=self.cur_answer,
                                     turn_number=len(self.history), is_last=1)

    def get_users(self):
        return self.crud.get_users()

    def get_chats(self, user_account):
        return self.crud.get_chats(user_account)

    def set_chat_id(self, chat_id):
        """Switch to an existing chat and load its history."""
        self.chat_id = chat_id
        self.history = self.crud.get_history(self.chat_id)

    def create_chat(self, user_account):
        """Create a new chat for *user_account* with a placeholder ('\\t\\t') title."""
        user_id = self.crud.get_uersid_from_account(user_account)
        self.chat_id = self.crud.create_chat(user_id, '\t\t', '0')

    def set_info(self, question):
        """Give the chat a short title derived from its first question.

        Only runs while the title is still the placeholder; questions longer
        than 10 characters are summarized by the LLM.
        """
        info = self.crud.get_chat_info(self.chat_id)
        if info == '\t\t':
            if len(question) <= 10:
                n_info = question
                self.crud.set_info(self.chat_id, n_info)
            else:
                info_prompt = """'''
                        {question}
                        '''
                        请你用十个字之内总结上述问题,你的输出不得大于10个字。
                        """
                info_prompt_t = PromptTemplate(input_variables=["question"], template=info_prompt)
                info_llm = LLMChain(llm=self.base_llm, prompt=info_prompt_t, llm_kwargs=self.llm_kwargs)
                n_info = info_llm.run(question=question)
                self.crud.set_info(self.chat_id, n_info)
陈正乐 committed
271 272

if __name__ == "__main__":
    import os  # local import: only needed by this demo entry point

    # Chat-history database connection.
    c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
                       port=CHAT_DB_PORT, )
    # SECURITY: API credentials were hard-coded here. They can now be supplied
    # via QIANFAN_AK / QIANFAN_SK; the literal fallbacks preserve the previous
    # behaviour but should be rotated and removed.
    base_llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(ak=os.environ.get("QIANFAN_AK", "pT7sV1smp4AeDl0LjyZuHBV9"),
                                       sk=os.environ.get("QIANFAN_SK", "b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")))
    # FAISS vector store backed by the vector database.
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss)
    my_chat.set_chat_id('1')
    print(my_chat.get_history())