qa.py 3.71 KB
Newer Older
陈正乐 committed
1 2
# -*- coding: utf-8 -*-
import sys
陈正乐 committed
3 4
import time
from datetime import datetime
陈正乐 committed
5 6 7 8 9 10 11
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
12 13 14 15 16 17 18 19 20
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD
)
陈正乐 committed
21 22 23 24 25 26 27 28 29 30

# Make the project root importable when this file is run from its own
# directory. NOTE(review): this line executes after the ``src.*`` imports
# above, so it cannot help those imports resolve — consider moving it earlier.
sys.path.append("../..")

# QA prompt: the supplied context is fenced with ''' markers, followed by an
# instruction (in Chinese) meaning "Please answer the question below based on
# the known material above; the question is:" and then the question itself.
prompt1 = (
    "'''\n"
    "{context}\n"
    "'''\n"
    "请你根据上述已知资料回答下面的问题,问题如下:\n"
    "{question}"
)
PROMPT1 = PromptTemplate(template=prompt1, input_variables=["context", "question"])


31 32 33 34 35 36 37 38 39 40 41 42 43
class QA:
    """Question-answering session over an ``LLMChain`` with persisted history.

    Each call to :meth:`chat` or :meth:`async_chat` stores the rendered prompt
    in ``cur_question`` and the model output in ``cur_answer``.
    :meth:`updata_history` then appends that pair to ``history`` and persists
    it through ``CRUD``.
    """

    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id):
        """Build the chain and load the stored history for ``_chat_id``.

        Args:
            _prompt: prompt template. Its ``format`` is called with keyword
                arguments named by ``_prompt_kwargs``.
            _base_llm: language model wrapped by the ``LLMChain``.
            _llm_kwargs: forwarded to ``LLMChain`` as ``llm_kwargs``.
            _prompt_kwargs: ordered prompt-variable names. Positional arguments
                to ``chat``/``async_chat`` are bound to these names in order.
            _db: database connection handed to ``CRUD``.
            _chat_id: conversation whose history is loaded and updated.
        """
        self.prompt = _prompt
        self.base_llm = _base_llm
        self.llm_kwargs = _llm_kwargs
        self.prompt_kwargs = _prompt_kwargs
        self.db = _db
        self.chat_id = _chat_id
        self.crud = CRUD(self.db)
        # Assumed to be a list of (question, answer) pairs; updata_history
        # appends tuples of that shape. TODO confirm CRUD.get_history's type.
        self.history = self.crud.get_history(self.chat_id)
        self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
        self.cur_answer = ""
        self.cur_question = ""

    def _build_inputs(self, args):
        """Bind positional ``args`` to the prompt variable names, in order.

        ``zip`` truncates, so arguments beyond ``prompt_kwargs`` are ignored.
        """
        return dict(zip(self.prompt_kwargs, args))

    async def chat(self, *args):
        """Generate the whole answer in one call and return it.

        Args:
            *args: values for the prompt variables, in ``prompt_kwargs`` order.

        Returns:
            The model's answer, or ``""`` when no arguments were given. In that
            case the model is not called.
        """
        if not args:
            # Check before formatting: rendering the prompt with no inputs
            # would fail first, which made this guard unreachable.
            self.cur_question = ""
            self.cur_answer = ""
            return ""
        inputs = self._build_inputs(args)
        self.cur_question = self.prompt.format(**inputs)
        self.cur_answer = ""
        # Await the async chain API so generation does not block the event
        # loop (the synchronous ``run`` would block it).
        self.cur_answer = await self.llm.arun(inputs)
        return self.cur_answer

    async def async_chat(self, *args):
        """Stream the answer, yielding the accumulated text after each token.

        Each yielded value is the full answer so far, not just the new token.
        Exceptions from the chain are logged and end the stream early; they
        are not propagated to the consumer.

        Args:
            *args: values for the prompt variables, in ``prompt_kwargs`` order.
        """
        import logging  # local import, as the file does elsewhere

        logger = logging.getLogger(__name__)
        inputs = self._build_inputs(args)
        self.cur_question = self.prompt.format(**inputs)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            # Always set the handler's ``done`` event so the token iterator
            # below can finish even when the chain raises.
            try:
                await fn
            except Exception:
                logger.exception("LLM chain failed while streaming")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(inputs, callbacks=[callback]), callback.done))
        self.cur_answer = ""
        try:
            async for token in callback.aiter():
                self.cur_answer += token
                yield self.cur_answer
            await task
        finally:
            # If the consumer stops early (aclose/cancellation), do not leave
            # the generation task running in the background.
            if not task.done():
                task.cancel()
        logger.debug("question: %s", self.cur_question)
        logger.debug("answer: %s", self.cur_answer)

    def get_history(self):
        """Return the in-memory history list (mutations are shared)."""
        return self.history

    def updata_history(self):
        """Record the current question/answer pair and persist it.

        Appends ``(cur_question, cur_answer)`` to ``history``, then calls
        ``crud.update_last`` (presumably clearing the previous turn's
        "last" flag; verify in CRUD). Finally it inserts the new turn with
        ``turn_number=len(history)`` and ``is_last=1``.
        """
        self.history.append((self.cur_question, self.cur_answer))
        self.crud.update_last(chat_id=self.chat_id)
        self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_question, answer=self.cur_answer,
                                 turn_number=len(self.history), is_last=1)

    # Correctly spelled alias; ``updata_history`` is kept for existing callers.
    update_history = updata_history
陈正乐 committed
89 90 91


if __name__ == "__main__":
    import os

    async def _consume(chat):
        """Drive the streaming generator to completion; return the final text."""
        final = ""
        async for final in chat.async_chat("当别人想你说你好的时候,你也应该说你好", "你好"):
            pass
        return final

    # Database connection
    c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
                       port=CHAT_DB_PORT, )
    # NOTE(security): API credentials must not live in source control. They are
    # read from the environment; keys previously committed here should be
    # treated as leaked and rotated.
    base_llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(ak=os.environ["QIANFAN_AK"], sk=os.environ["QIANFAN_SK"]))
    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2')
    # async_chat is an async generator. It must be iterated inside an event
    # loop, otherwise the model is never called; printing it only shows the
    # generator object. Awaiting it also removes the need to sleep.
    asyncio.run(_consume(my_chat))
    my_chat.updata_history()
    print(my_chat.cur_answer)