# -*- coding: utf-8 -*-
"""Retrieval-augmented chat service.

Wraps an LLM chain with FAISS-based document retrieval (optionally reranked
with a BGE cross-encoder), keyword-based content filtering, and chat-history
persistence in PostgreSQL.
"""
import sys
import time
import asyncio
from typing import Awaitable

from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain.callbacks import AsyncIteratorCallbackHandler
from qianfan import ChatCompletion

from src.llm.ernie_with_sdk import ChatERNIESerLLM
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.server.get_similarity import GetSimilarity
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    RERANK_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER
)
from .rerank import BgeRerank

sys.path.append("../..")

prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""

PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)

# Canned safe response returned whenever a question or answer trips the keyword filter.
SAFE_RESPONSE = "您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提是供定义和解释及建议。如果您有任何问题,请随时向我提问。"
# Keywords that must never appear in a question or an answer.
BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"]


class QA:
    """Retrieval-augmented QA session bound to one chat in the chat database."""

    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _faiss_db, rerank: bool = False):
        """
        Args:
            _prompt: PromptTemplate used to build the final LLM prompt.
            _base_llm: underlying chat LLM.
            _llm_kwargs: extra kwargs forwarded to the LLMChain.
            _prompt_kwargs: prompt variable names (kept for caller compatibility).
            _db: connected UPostgresDB for chat persistence.
            _faiss_db: VectorStore_FAISS used for similarity retrieval.
            rerank: when True, load a BGE rerank model and rerank retrieved docs.
        """
        self.prompt = _prompt
        self.base_llm = _base_llm
        self.llm_kwargs = _llm_kwargs
        self.prompt_kwargs = _prompt_kwargs
        self.db = _db
        self.chat_id = None
        self.faiss_db = _faiss_db
        self.crud = CRUD(self.db)
        self.history = None
        self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
        self.cur_answer = ""
        self.cur_question = ""      # fully formatted prompt sent to the LLM
        self.cur_similarity = ""    # retrieved context text
        self.cur_oquestion = ""     # original user question
        self.rerank = rerank
        if rerank:
            # Rerank model is only loaded when requested — it is expensive.
            self.rerank_model = BgeRerank(RERANK_MODEL_PATH)

    def contains_blocked_keywords(self, text):
        """Return True if *text* contains any blocked keyword."""
        return any(keyword in text for keyword in BLOCKED_KEYWORDS)

    def get_similarity(self, _aquestion):
        """Return the concatenated similarity context text for *_aquestion*."""
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()

    def get_similarity_origin(self, _aquestion):
        """Return the GetSimilarity helper itself (callers pick docs / rerank)."""
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db)

    def _build_context(self):
        """Retrieve context for self.cur_oquestion, setting self.cur_similarity.

        Applies reranking when enabled. Returns (similarity_docs, rerank_docs).
        """
        similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
        if self.rerank:
            self.cur_similarity = similarity.get_rerank(self.rerank_model)
        else:
            self.cur_similarity = similarity.get_similarity_doc()
        return similarity.get_similarity_docs(), similarity.get_rerank_docs()

    def chat(self, _question, with_similarity=False):
        """Answer *_question* in one shot (non-streaming).

        Returns the answer string, or (answer, docs) when with_similarity is
        True — docs are the reranked docs if reranking is enabled, else the
        raw similarity docs. Empty questions return "" without any retrieval.
        """
        if not _question:
            # Guard first: the original checked this only after retrieval,
            # wasting a vector-store round trip on empty input.
            return ""
        self.cur_oquestion = _question
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
            # Fix: these were previously unbound on this path, so
            # with_similarity=True raised NameError below.
            similarity_docs, rerank_docs = [], []
        else:
            similarity_docs, rerank_docs = self._build_context()
            self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
            self.cur_answer = self.llm.run(context=self.cur_similarity, question=self.cur_oquestion)
            # The answer itself is also filtered, not just the question.
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE
        self.update_history()
        if with_similarity:
            if self.rerank:
                return self.cur_answer, rerank_docs
            return self.cur_answer, similarity_docs
        return self.cur_answer

    def chat_with_history(self, _question, history, with_similarity=False):
        """Like chat(), but also feeds *history* into the prompt template."""
        if not _question:
            return ""
        self.cur_oquestion = _question
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
            # Fix: unbound-on-blocked-path NameError (see chat()).
            similarity_docs, rerank_docs = [], []
        else:
            similarity_docs, rerank_docs = self._build_context()
            self.cur_question = self.prompt.format(history=history, context=self.cur_similarity,
                                                   question=self.cur_oquestion)
            self.cur_answer = self.llm.run(history=history, context=self.cur_similarity,
                                           question=self.cur_oquestion)
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE
        self.update_history()
        if with_similarity:
            if self.rerank:
                return self.cur_answer, rerank_docs
            return self.cur_answer, similarity_docs
        return self.cur_answer

    async def async_chat(self, _question):
        """Stream an answer for *_question*, yielding the full history list
        after each received token (gradio-style incremental updates)."""
        self.cur_oquestion = _question
        history = self.get_history()
        # Blocked questions short-circuit to the canned response without
        # calling the LLM. (The original re-checked this loop-invariant
        # condition once per streamed token, after the LLM was launched.)
        if self.contains_blocked_keywords(self.cur_oquestion):
            self.cur_answer = SAFE_RESPONSE
            history.append((self.cur_oquestion, self.cur_answer))
            yield history
            return
        self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            # Surface LLM exceptions instead of hanging the token iterator,
            # then always unblock it.
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=self.cur_oquestion,
                                    callbacks=[callback]),
                      callback.done))
        self.cur_answer = ""
        history.append((self.cur_oquestion, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            history[-1] = (self.cur_oquestion, self.cur_answer)
            yield history
        await task

    async def async_chat2(self, history):
        """Stream an answer for the last question in *history* (its answer
        slot is rebuilt), yielding the updated history per token."""
        _question = history[-1][0]
        history = history[:-1]
        # Same short-circuit as async_chat: no LLM call for blocked questions.
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
            history.append((_question, self.cur_answer))
            yield history
            return
        self.cur_similarity = self.get_similarity(_aquestion=_question)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=_question)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            # Surface LLM exceptions, then always release the iterator.
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=_question,
                                    callbacks=[callback]),
                      callback.done))
        self.cur_answer = ""
        print(_question, self.cur_answer)
        history.append((_question, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            history[-1] = (_question, self.cur_answer)
            yield history
        await task

    def get_history(self):
        """Load and cache this chat's (question, answer) history from the DB."""
        self.history = self.crud.get_history(self.chat_id)
        return self.history

    def update_history(self):
        """Append the current QA turn to the cached history and persist it."""
        if self.cur_oquestion and self.cur_answer:
            if not self.history:
                self.history = []
            # Drop any placeholder entry for this question (avoids duplicates
            # left behind by streaming updates).
            self.history = [(q, a) for q, a in self.history
                            if q != self.cur_oquestion or a != ""]
            self.history.append((self.cur_oquestion, self.cur_answer))
            self.crud.update_last(chat_id=self.chat_id)
            self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion,
                                     answer=self.cur_answer, turn_number=len(self.history),
                                     is_last=1)

    def get_users(self):
        """Return all user accounts."""
        return self.crud.get_users()

    def get_chats(self, user_account):
        """Return all chats belonging to *user_account*."""
        return self.crud.get_chats(user_account)

    def set_chat_id(self, chat_id):
        """Switch the session to *chat_id* and reload its history."""
        self.chat_id = chat_id
        self.history = self.crud.get_history(self.chat_id)

    def create_chat(self, user_account):
        """Create a fresh chat for *user_account* ('\\t\\t' marks an unset title)."""
        user_id = self.crud.get_uersid_from_account(user_account)
        self.chat_id = self.crud.create_chat(user_id, '\t\t', '0')

    def set_info(self, question):
        """Derive and store a short chat title from the first *question*.

        Short questions are used verbatim; longer ones are summarized to at
        most ten characters by the LLM. Only runs while the title is unset.
        """
        info = self.crud.get_chat_info(self.chat_id)
        if info == '\t\t':
            if len(question) <= 10:
                n_info = question
                self.crud.set_info(self.chat_id, n_info)
            else:
                info_prompt = """'''
{question}
'''
请你用十个字之内总结上述问题,你的输出不得大于10个字。
"""
                info_prompt_t = PromptTemplate(input_variables=["question"], template=info_prompt)
                info_llm = LLMChain(llm=self.base_llm, prompt=info_prompt_t, llm_kwargs=self.llm_kwargs)
                n_info = info_llm.run(question=question)
                self.crud.set_info(self.chat_id, n_info)


if __name__ == "__main__":
    # Database connection
    c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER,
                       password=CHAT_DB_PASSWORD, port=CHAT_DB_PORT, )
    # SECURITY: hard-coded API credentials committed to source — these keys
    # should be rotated and loaded from environment variables / secret config.
    base_llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME,
              "username": VEC_DB_USER, "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'],
                 _db=c_db, _faiss_db=vecstore_faiss)
    my_chat.set_chat_id('1')
    print(my_chat.get_history())