import logging
import os
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
from langchain.llms.base import LLM, BaseLLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from pydantic import Field, root_validator

from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_message

logger = logging.getLogger(__name__)


class ModelType(Enum):
    """Models served by the wenxinworkshop chat endpoints.

    Each member's value is the short name accepted by ``ErnieLLM`` when the
    model is given as a string (e.g. via the ``model_name`` environment
    variable).
    """

    ERNIE = "ernie"
    ERNIE_LITE = "ernie-lite"
    SHEETS1 = "sheets1"
    SHEETS2 = "sheets2"
    SHEET_COMB = "sheet-comb"
    LLAMA2_7B = "llama2-7b"
    LLAMA2_13B = "llama2-13b"
    LLAMA2_70B = "llama2-70b"
    QFCN_LLAMA2_7B = "qfcn-llama2-7b"
    BLOOMZ_7B = "bloomz-7b"


MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"

# Per-model path appended to MODEL_SERVICE_BASE_URL by _get_model_service_url.
MODEL_SERVICE_Suffix = {
    ModelType.ERNIE: "ai_custom/v1/wenxinworkshop/chat/completions",
    ModelType.ERNIE_LITE: "ai_custom/v1/wenxinworkshop/chat/eb-instant",
    ModelType.SHEETS1: "ai_custom/v1/wenxinworkshop/chat/besheet",
    ModelType.SHEETS2: "ai_custom/v1/wenxinworkshop/chat/besheets2",
    ModelType.SHEET_COMB: "ai_custom/v1/wenxinworkshop/chat/sheet_comb1",
    ModelType.LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/llama_2_7b",
    ModelType.LLAMA2_13B: "ai_custom/v1/wenxinworkshop/chat/llama_2_13b",
    ModelType.LLAMA2_70B: "ai_custom/v1/wenxinworkshop/chat/llama_2_70b",
    ModelType.QFCN_LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/qianfan_chinese_llama_2_7b",
    ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
}


class ErnieLLM(LLM):
    """
    ErnieLLM is a LLM that uses Ernie to generate text.

    Each call is stateless: the prompt is sent as a single user message and no
    conversation history is kept (see ``ErnieChat`` for that).

    Attributes:
        model_name: Endpoint to call. Resolved by ``validate_environment`` from
            the field, then the ``model_name`` environment variable, then
            ``ModelType.ERNIE``.
        access_token: Token handed to ``ErnieBot``. Resolved from the field,
            then the ``ERNIE_ACCESS_TOKEN`` environment variable.
    """

    model_name: Optional[ModelType] = None
    access_token: Optional[str] = ""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve ``model_name`` and ``access_token`` from the values or environment.

        Raises:
            ValueError: if no access token is found, or the model name is not a
                valid ``ModelType`` value.
        """
        # Use ``.value`` ("ernie") as the fallback: ``str(ModelType.ERNIE)`` is
        # "ModelType.ERNIE", which ``ModelType(...)`` rejects, so the previous
        # default made construction without an explicit model always fail.
        model_name = ModelType(
            get_from_dict_or_env(values, "model_name", "model_name", ModelType.ERNIE.value)
        )
        access_token = get_from_dict_or_env(values, "access_token", "ERNIE_ACCESS_TOKEN", "")
        if not access_token:
            raise ValueError("No access token provided.")
        values["model_name"] = model_name
        values["access_token"] = access_token
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Send ``prompt`` as a single user message and return the model's reply.

        Args:
            prompt: Text sent as the user message.
            stop: Optional stop sequences; the reply is cut at the earliest one.
            run_manager: Unused; accepted to satisfy the LangChain ``LLM`` interface.

        Returns:
            The reply text, or the exception's message if the request fails.
        """
        request = CompletionRequest(messages=[Message("user", prompt)])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
        try:
            response = bot.get_response().result
        except Exception as e:
            # Existing contract: a failed request is reported as the returned
            # text rather than raised. Log it with the traceback (instead of
            # print) so the failure is not invisible to operators.
            logger.exception("Ernie request failed")
            return str(e)
        return _enforce_stop(response, stop)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ernie"


def _get_model_service_url(model_name) -> str:
    """Return the full endpoint URL for ``model_name``.

    Raises:
        KeyError: if ``model_name`` has no entry in ``MODEL_SERVICE_Suffix``.
    """
    return MODEL_SERVICE_BASE_URL + MODEL_SERVICE_Suffix[model_name]


def _enforce_stop(text: str, stop: Optional[List[str]]) -> str:
    """Return ``text`` truncated at the earliest occurrence of any stop sequence.

    Empty stop strings are ignored; non-string ``text`` is returned unchanged.
    """
    if not stop or not isinstance(text, str):
        return text
    hits = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
    return text[: min(hits)] if hits else text


class ErnieChat(LLM):
    """Multi-turn Ernie LLM that keeps the conversation in ``prefix_messages``.

    Every call sends the accumulated history plus the new user message, then
    appends both the user message and the bot reply to the history.
    """

    model_name: ModelType
    access_token: str
    # Conversation history sent before each new prompt; grows with every call.
    prefix_messages: List = Field(default_factory=list)
    # Id taken from the first successful response; "" until then.
    id: str = ""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Send the history plus ``prompt`` and return the model's reply.

        Args:
            prompt: Text of the new user turn.
            stop: Optional stop sequences; the reply is cut at the earliest one
                (the truncated text is what gets recorded in the history).
            run_manager: Unused; accepted to satisfy the LangChain ``LLM`` interface.

        Returns:
            The reply text. Exceptions from the request propagate, and in that
            case the history is left unchanged.
        """
        msg = user_message(prompt)
        request = CompletionRequest(messages=self.prefix_messages + [msg])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token, request)
        # Fetch the reply once and reuse it: calling get_response() a second
        # time just to read the id issued a redundant request.
        reply = bot.get_response()
        response = _enforce_stop(reply.result, stop)
        if self.id == "":
            self.id = reply.id
        self.prefix_messages.append(msg)
        self.prefix_messages.append(bot_message(response))
        return response

    def _get_id(self) -> str:
        """Return the id recorded from the first response ("" before any call)."""
        return self.id

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ernie"