# -*- coding: utf-8 -*-
import asyncio
import logging
import os
import sys
import time
from typing import Awaitable

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from qianfan import ChatCompletion

# Make the project root importable when this file is run as a script. The
# relative path is resolved against the current working directory, and the
# append must come before the `src.*` imports below to have any effect on them.
sys.path.append("../..")

from src.config.consts import (  # noqa: E402
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER,
)
from src.llm.ernie_with_sdk import ChatERNIESerLLM  # noqa: E402
from src.pgdb.chat.c_db import UPostgresDB  # noqa: E402
from src.pgdb.chat.crud import CRUD  # noqa: E402
from src.pgdb.knowledge.similarity import VectorStore_FAISS  # noqa: E402
from src.server.get_similarity import GetSimilarity  # noqa: E402

logger = logging.getLogger(__name__)

# Prompt template: the retrieved reference text ({context}) is wrapped in '''
# quotes, followed by an instruction (in Chinese) asking the model to answer
# the question below ({question}) based on that known material.
prompt1 = """''' {context} ''' 请你根据上述已知资料回答下面的问题,问题如下: {question}"""

PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)


class QA:
    """Retrieval-augmented question answering bound to one chat session.

    For each question the class retrieves related text through
    ``GetSimilarity`` over ``_faiss_db``, fills it into ``_prompt`` together
    with the question, runs the result through an ``LLMChain``, and can
    persist the question/answer turn through ``CRUD``.

    Args:
        _prompt: ``PromptTemplate`` used for every turn.
        _base_llm: LangChain LLM wrapped by the ``LLMChain``.
        _llm_kwargs: Extra keyword arguments forwarded to ``LLMChain`` as
            ``llm_kwargs``.
        _prompt_kwargs: Prompt variable names; the first receives the
            retrieved context and the second the user's question
            (e.g. ``['context', 'question']``).
        _db: Database handle passed to ``CRUD``.
        _chat_id: Identifier of the chat session whose history is loaded
            and extended.
        _faiss_db: Vector store passed to ``GetSimilarity``.
    """

    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id, _faiss_db):
        self.prompt = _prompt
        self.base_llm = _base_llm
        self.llm_kwargs = _llm_kwargs
        self.prompt_kwargs = _prompt_kwargs
        self.db = _db
        self.chat_id = _chat_id
        self.faiss_db = _faiss_db
        self.crud = CRUD(self.db)
        # Stored turns for this chat, fetched once via CRUD at construction time.
        self.history = self.crud.get_history(self.chat_id)
        self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
        # State of the current turn.
        self.cur_answer = ""      # answer text (accumulated token by token when streaming)
        self.cur_question = ""    # fully formatted prompt for this turn
        self.cur_similarity = ""  # context text retrieved for this turn
        self.cur_oquestion = ""   # the user's original, unformatted question

    def get_similarity(self, _aquestion):
        """Return the text retrieved from the vector store for ``_aquestion``."""
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()

    def _prompt_inputs(self):
        """Map the prompt variable names onto (retrieved context, original question)."""
        # zip() pairs positionally: first name -> context, second -> question;
        # any further names in prompt_kwargs are silently ignored.
        return dict(zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion)))

    def _prepare_turn(self, question):
        """Start a new turn for ``question``: retrieve context and format the prompt.

        Returns:
            ``False`` (skipping retrieval) when ``question`` is empty/falsy,
            otherwise ``True``.
        """
        self.cur_oquestion = question
        self.cur_answer = ""
        if not question:
            # Nothing to answer: avoid a pointless vector-store lookup and
            # clear context left over from a previous turn.
            self.cur_similarity = ""
            self.cur_question = ""
            return False
        self.cur_similarity = self.get_similarity(_aquestion=question)
        self.cur_question = self.prompt.format(**self._prompt_inputs())
        return True

    def chat(self, _question):
        """Answer ``_question`` in a single, blocking LLM call.

        Returns:
            The complete answer text, or ``""`` (without querying the vector
            store or the LLM) when ``_question`` is empty.
        """
        if not self._prepare_turn(_question):
            return ""
        self.cur_answer = self.llm.run(self._prompt_inputs())
        return self.cur_answer

    async def async_chat(self, _question):
        """Stream the answer to ``_question`` as an async generator.

        After every token received from the LLM, yields the *cumulative*
        answer text so far (not just the new token). Yields nothing when
        ``_question`` is empty, matching ``chat``.

        Errors raised by the LLM call are logged and end the stream early
        instead of propagating. If the consumer stops iterating before the
        stream ends, the background LLM task is cancelled.
        """
        if not self._prepare_turn(_question):
            return
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            """Await ``fn``, then always set ``event`` so ``callback.aiter()`` ends."""
            try:
                await fn
            except Exception:
                # Deliberate best-effort: log and fall through so the token
                # iterator below still terminates via the done event.
                logger.exception("Streaming LLM call failed")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(self._prompt_inputs(), callbacks=[callback]), callback.done))
        try:
            async for token in callback.aiter():
                self.cur_answer += token
                yield self.cur_answer
            await task
        finally:
            # Reached early if the consumer closes the generator mid-stream;
            # don't leave the LLM request running unobserved.
            if not task.done():
                task.cancel()

    def get_history(self):
        """Return the in-memory history of this chat (loaded at init, extended by ``update_history``)."""
        return self.history

    def update_history(self):
        """Append the current question/answer turn to the history and persist it.

        Does nothing when there is neither a question nor an answer. Calling
        it twice for the same turn records that turn twice.
        """
        if not self.cur_oquestion and not self.cur_answer:
            return
        self.history.append((self.cur_oquestion, self.cur_answer))
        # NOTE(review): update_last presumably clears the "last turn" marker on
        # earlier rows, so the row inserted below with is_last=1 is the only one.
        self.crud.update_last(chat_id=self.chat_id)
        self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion, answer=self.cur_answer,
                                 turn_number=len(self.history), is_last=1)


if __name__ == "__main__":
    # API credentials come from the environment instead of being hard-coded;
    # the key pair previously committed in this file should be treated as
    # leaked and rotated.
    qianfan_ak = os.environ.get("QIANFAN_AK")
    qianfan_sk = os.environ.get("QIANFAN_SK")
    if not (qianfan_ak and qianfan_sk):
        sys.exit("QIANFAN_AK and QIANFAN_SK must be set in the environment")

    # Chat-history database connection.
    c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
                       port=CHAT_DB_PORT, )
    base_llm = ChatERNIESerLLM(chat_completion=ChatCompletion(ak=qianfan_ak, sk=qianfan_sk))
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2',
                 _faiss_db=vecstore_faiss)
    # Sample question: "What is the low-altitude economy?"
    print(my_chat.chat("什么是低空经济"))
    my_chat.update_history()
    # NOTE(review): chat() is synchronous, so this pause does not wait for any
    # pending work; its purpose is unclear.
    time.sleep(20)
    print(my_chat.cur_answer)