# -*- coding: utf-8 -*-
import sys
import openpyxl
import time
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.server.get_similarity import GetSimilarity, GetSimilarityWithExt
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    RERANK_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER
)
from .rerank import BgeRerank

sys.path.append("../..")
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""




PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
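# Illustrative rendering of PROMPT1 (placeholder values, not real retrieval output):
#   PROMPT1.format(context="<retrieved passages>", question="<user question>")
# wraps the retrieved context in ''' quotes, then appends the instruction line and the question.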



# Preset safe response returned when a blocked keyword is detected
SAFE_RESPONSE = "您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提供定义和解释及建议。如果您有任何问题,请随时向我提问。"
BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"]


class QA:
    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _faiss_db, rerank: bool = False):
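        """Retrieval-augmented QA wrapper around an LLMChain.

        _prompt: PromptTemplate used to build the final LLM input.
        _base_llm / _llm_kwargs: underlying chat model and extra kwargs for the LLMChain.
        _prompt_kwargs: prompt variable names (stored, not used directly here).
        _db: chat-history database handle (UPostgresDB); _faiss_db: VectorStore_FAISS retriever.
        rerank: when True, retrieved documents are re-ordered with the BgeRerank model.
        """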
        self.prompt = _prompt
        self.base_llm = _base_llm
        self.llm_kwargs = _llm_kwargs
        self.prompt_kwargs = _prompt_kwargs
        self.db = _db
        self.chat_id = None
        self.faiss_db = _faiss_db
        self.crud = CRUD(self.db)
        self.history = None
        self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
        self.cur_answer = ""
        self.cur_question = ""
        self.cur_similarity = ""
        self.cur_oquestion = ""
        self.rerank = rerank

        if rerank:
            self.rerank_model = BgeRerank(RERANK_MODEL_PATH)

    # Check whether the text contains any blocked (sensitive) keyword
    def contains_blocked_keywords(self, text):
        return any(keyword in text for keyword in BLOCKED_KEYWORDS)

    # Return the similarity text retrieved for the given question
    def get_similarity(self, _aquestion):
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()
    
    def get_similarity_origin(self, _aquestion):
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db)

    def get_similarity_with_ext_origin(self, _ext):
        return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db)

    # Give the complete answer in a single pass (non-streaming)
    def chat(self, _question, with_similarity=False):
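        """Retrieve similar documents for _question, build the prompt and run the chain once.

        Returns the answer string; when with_similarity is True, also returns the reranked
        documents (if rerank is enabled) or the raw similarity documents.
        """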
        self.cur_oquestion = _question
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
        else:
            # self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
            similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
            if self.rerank:
                self.cur_similarity = similarity.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity.get_similarity_doc()
            similarity_docs = similarity.get_similarity_docs()
            rerank_docs = similarity.get_rerank_docs()
            self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
            if not _question:
                return ""
            self.cur_answer = self.llm.run(context=self.cur_similarity, question=self.cur_oquestion)
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE

        self.update_history()
        if with_similarity:
            if self.rerank:
                return self.cur_answer, rerank_docs
            else:
                return self.cur_answer, similarity_docs
        return self.cur_answer

    def chat_with_history(self, _question, history, with_similarity=False):
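        """Same flow as chat(), but the conversation history is also passed to the prompt.

        Note: this assumes self.prompt defines a `history` input variable in addition to
        `context` and `question`.
        """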
        self.cur_oquestion = _question
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
        else:
            # self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
            similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
            if self.rerank:
                self.cur_similarity = similarity.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity.get_similarity_doc()
            similarity_docs = similarity.get_similarity_docs()
            rerank_docs = similarity.get_rerank_docs()
            self.cur_question = self.prompt.format(history=history,context=self.cur_similarity, question=self.cur_oquestion)
            if not _question:
                return ""
            self.cur_answer = self.llm.run(history=history,context=self.cur_similarity, question=self.cur_oquestion)
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE

        self.update_history()
        if with_similarity:
            if self.rerank:
                return self.cur_answer, rerank_docs
            else:
                return self.cur_answer, similarity_docs
        return self.cur_answer

    def chat_with_history_with_ext(self, _question, ext, history, with_similarity=False):
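        """Like chat_with_history(), but retrieval is driven by the externally supplied
        `ext` payload (via GetSimilarityWithExt) instead of the raw question.
        """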
        self.cur_oquestion = _question
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
        else:
            # self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
            similarity = self.get_similarity_with_ext_origin(ext)
            if self.rerank:
                self.cur_similarity = similarity.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity.get_similarity_doc()
            similarity_docs = similarity.get_similarity_docs()
            rerank_docs = similarity.get_rerank_docs()
            self.cur_question = self.prompt.format(history=history,context=self.cur_similarity, question=self.cur_oquestion)
            if not _question:
                return ""
            self.cur_answer = self.llm.run(history=history,context=self.cur_similarity, question=self.cur_oquestion)
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE

        self.update_history()
        if with_similarity:
            if self.rerank:
                return self.cur_answer, rerank_docs
            else:
                return self.cur_answer, similarity_docs
        return self.cur_answer

    def chat_with_history_with_ext_with_save_excel(self, _question, ext, history, with_similarity=False):
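        """Like chat_with_history_with_ext(), but additionally retrieves with the raw
        question and appends one comparison row per document (question, plain similarity
        hit, reranked hit, ext-reranked hit) to an Excel sheet via save_excel().
        """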
        self.cur_oquestion = _question
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
        else:
            # self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
            similarity1 = self.get_similarity_origin(_aquestion=self.cur_oquestion)
            if self.rerank:
                self.cur_similarity = similarity1.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity1.get_similarity_doc()
            similarity_docs1 = similarity1.get_similarity_docs()

            rerank_docs1 = similarity1.get_rerank_docs()

            similarity = self.get_similarity_with_ext_origin(ext)
            if self.rerank:
                self.cur_similarity = similarity.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity.get_similarity_doc()
            similarity_docs = similarity.get_similarity_docs()
            rerank_docs = similarity.get_rerank_docs()
            print("rerank_docs_with_ext================================================")
            # Guard against the three document lists having different lengths
            for i in range(min(len(similarity_docs1), len(rerank_docs1), len(rerank_docs))):
                save_excel(_question, similarity_docs1[i].page_content, rerank_docs1[i].page_content, rerank_docs[i].page_content)


            self.cur_question = self.prompt.format(history=history,context=self.cur_similarity, question=self.cur_oquestion)
            if not _question:
                return ""
            self.cur_answer = self.llm.run(history=history,context=self.cur_similarity, question=self.cur_oquestion)
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE

        self.update_history()
        if with_similarity:
            if self.rerank:
                return self.cur_answer, rerank_docs
            else:
                return self.cur_answer, similarity_docs
        return self.cur_answer


    # Asynchronous output: stream the answer incrementally
    async def async_chat(self, _question):
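        """Stream the answer: retrieve context, run the chain asynchronously and yield the
        growing history list as tokens arrive from the callback iterator.
        """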
        self.cur_oquestion = _question
        history = self.get_history()

        # if self.contains_blocked_keywords(_question):
        #     self.cur_answer = SAFE_RESPONSE
        #     yield [(self.cur_oquestion, self.cur_answer)]
            # return

        self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=self.cur_oquestion, callbacks=[callback]),
                      callback.done))

        self.cur_answer = ""
        history.append((self.cur_oquestion, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            if self.contains_blocked_keywords(self.cur_oquestion):
                self.cur_answer = SAFE_RESPONSE
                history[-1] = (self.cur_oquestion, self.cur_answer)
                yield history
                return
            history[-1] = (self.cur_oquestion, self.cur_answer)
            yield history
        await task

    async def async_chat2(self, history):
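        """Streaming variant that takes the whole history list; the last entry holds the
        new question. Yields the updated history as the answer is generated.
        """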
        _question = history[-1][0]
        history = history[:-1]

        # if self.contains_blocked_keywords(_question):
        #     self.cur_answer = SAFE_RESPONSE
        #     yield [(_question, self.cur_answer)]
        #     return

        self.cur_similarity = self.get_similarity(_aquestion=_question)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=_question)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=_question, callbacks=[callback]),
                      callback.done))

        self.cur_answer = ""
        print(_question, self.cur_answer)
        history.append((_question, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            if self.contains_blocked_keywords(_question):
                self.cur_answer = SAFE_RESPONSE
                history[-1] = (_question, self.cur_answer)
                yield history
                return
            history[-1] = (_question, self.cur_answer)
            yield history
        await task

    def get_history(self):
        self.history = self.crud.get_history(self.chat_id)
        return self.history

    def update_history(self):
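        """Persist the finished turn: drop any placeholder entry for the current question,
        append (question, answer) to the in-memory history and write the turn to the chat
        database through CRUD.
        """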
        if self.cur_oquestion and self.cur_answer:
            if not self.history:
                self.history = []
            # Avoid adding duplicate entries
            self.history = [(q, a) for q, a in self.history if q != self.cur_oquestion or a != ""]
            self.history.append((self.cur_oquestion, self.cur_answer))
            self.crud.update_last(chat_id=self.chat_id)
            self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion, answer=self.cur_answer,
                                     turn_number=len(self.history), is_last=1)

    def get_users(self):
        return self.crud.get_users()

    def get_chats(self, user_account):
        return self.crud.get_chats(user_account)

    def set_chat_id(self, chat_id):
        self.chat_id = chat_id
        self.history = self.crud.get_history(self.chat_id)

    def create_chat(self, user_account):
        user_id = self.crud.get_uersid_from_account(user_account)
        self.chat_id = self.crud.create_chat(user_id, '\t\t', '0')

    def set_info(self, question):
        info = self.crud.get_chat_info(self.chat_id)
        if info == '\t\t':
            if len(question) <= 10:
                n_info = question
                self.crud.set_info(self.chat_id, n_info)
            else:
                info_prompt = """'''
                        {question}
                        '''
                        请你用十个字之内总结上述问题,你的输出不得大于10个字。
                        """
                info_prompt_t = PromptTemplate(input_variables=["question"], template=info_prompt)
                info_llm = LLMChain(llm=self.base_llm, prompt=info_prompt_t, llm_kwargs=self.llm_kwargs)
                n_info = info_llm.run(question=question)
                self.crud.set_info(self.chat_id, n_info)

def save_excel(question, similar, rerank, rerank_ext):
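    """Append one comparison row (question, similarity hit, reranked hit, ext-reranked hit)
    to a fixed local workbook; the file path and sheet name are hard-coded below.
    """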
    file_full_path = r'D:\work\py\LAE\testdoc\test.xlsx'
    # Sheet name
    sheet_name = 'Sheet1'

    # Open the target workbook
    wb = openpyxl.load_workbook(file_full_path)
    # Get the target sheet
    ws = wb[sheet_name]
    # Current number of used rows
    max_row_num = ws.max_row
    # Current number of used columns
    max_col_num = ws.max_column
    ws.cell(row=max_row_num+1, column=1, value=question)
    ws.cell(row=max_row_num+1, column=2, value=similar)
    ws.cell(row=max_row_num+1, column=3, value=rerank)
    ws.cell(row=max_row_num+1, column=4, value=rerank_ext)
    # Save the workbook
    wb.save(file_full_path)




if __name__ == "__main__":
    # Database connection
    c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
                       port=CHAT_DB_PORT, )
    base_llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss)
    my_chat.set_chat_id('1')
    print(my_chat.get_history())
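    # Illustrative follow-up call (hypothetical question), left commented out:
    # answer = my_chat.chat("请简要介绍知识库中收录的资料")
    # print(answer)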