import logging
import os
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

from langchain.callbacks.manager import Callbacks
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseLLM, LLM
from langchain_core.outputs import LLMResult
from langchain_core.utils import get_from_dict_or_env
from pydantic import Field, root_validator

from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_message

logger = logging.getLogger(__name__)


class ModelType(Enum):
    """Model endpoints supported by the Baidu wenxinworkshop chat service."""

    # ERNIE family
    ERNIE = "ernie"
    ERNIE_LITE = "ernie-lite"
    # Sheet-oriented models
    SHEETS1 = "sheets1"
    SHEETS2 = "sheets2"
    SHEET_COMB = "sheet-comb"
    # Hosted open-source models
    LLAMA2_7B = "llama2-7b"
    LLAMA2_13B = "llama2-13b"
    LLAMA2_70B = "llama2-70b"
    QFCN_LLAMA2_7B = "qfcn-llama2-7b"
    BLOOMZ_7B = "bloomz-7b"


# Root of Baidu's AI platform RPC API; per-model suffixes are appended to it.
MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"

# Shared path prefix for every chat endpoint under wenxinworkshop.
_CHAT_ENDPOINT_PREFIX = "ai_custom/v1/wenxinworkshop/chat/"

# Maps each ModelType to its endpoint path relative to MODEL_SERVICE_BASE_URL.
MODEL_SERVICE_Suffix = {
    ModelType.ERNIE: _CHAT_ENDPOINT_PREFIX + "completions",
    ModelType.ERNIE_LITE: _CHAT_ENDPOINT_PREFIX + "eb-instant",
    ModelType.SHEETS1: _CHAT_ENDPOINT_PREFIX + "besheet",
    ModelType.SHEETS2: _CHAT_ENDPOINT_PREFIX + "besheets2",
    ModelType.SHEET_COMB: _CHAT_ENDPOINT_PREFIX + "sheet_comb1",
    ModelType.LLAMA2_7B: _CHAT_ENDPOINT_PREFIX + "llama_2_7b",
    ModelType.LLAMA2_13B: _CHAT_ENDPOINT_PREFIX + "llama_2_13b",
    ModelType.LLAMA2_70B: _CHAT_ENDPOINT_PREFIX + "llama_2_70b",
    ModelType.QFCN_LLAMA2_7B: _CHAT_ENDPOINT_PREFIX + "qianfan_chinese_llama_2_7b",
    ModelType.BLOOMZ_7B: _CHAT_ENDPOINT_PREFIX + "bloomz_7b1",
}


class ErnieLLM(LLM):
    """
    ErnieLLM is a LLM that uses Ernie to generate text.

    Sends a single-turn prompt to the ERNIE endpoint selected by
    ``model_name`` and returns the generated text.
    """

    # Endpoint selector; resolved in validate_environment (defaults to ERNIE).
    model_name: Optional[ModelType] = None
    # Baidu AIP access token; falls back to the ERNIE_ACCESS_TOKEN env var.
    access_token: Optional[str] = ""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the environment.

        Resolves ``model_name`` and ``access_token`` from the constructor
        kwargs or the environment.

        Raises:
            ValueError: if no access token is available from either source.
        """
        # BUG FIX: the fallback default must be the enum *value* ("ernie").
        # The original passed str(ModelType.ERNIE), which is "ModelType.ERNIE"
        # and makes ModelType(...) raise ValueError when model_name is unset.
        model_name = ModelType(
            get_from_dict_or_env(values, "model_name", "model_name", ModelType.ERNIE.value)
        )
        access_token = get_from_dict_or_env(values, "access_token", "ERNIE_ACCESS_TOKEN", "")

        if not access_token:
            raise ValueError("No access token provided.")

        values["model_name"] = model_name
        values["access_token"] = access_token
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        On failure the exception message is returned as the completion text
        (best-effort behaviour kept from the original implementation), but it
        is now logged instead of printed to stdout.
        """
        request = CompletionRequest(messages=[Message("user", prompt)])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
        try:
            return bot.get_response().result
        except Exception as e:  # deliberate best-effort fallback, not a crash
            logger.exception("ERNIE completion request failed")
            return str(e)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ernie"


def _get_model_service_url(model_name) -> str:
    """Build the full request URL for the given ``ModelType`` member."""
    suffix = MODEL_SERVICE_Suffix[model_name]
    return f"{MODEL_SERVICE_BASE_URL}{suffix}"


class ErnieChat(LLM):
    """Stateful multi-turn chat wrapper around the ERNIE service.

    Keeps the running conversation in ``prefix_messages`` and replays the
    full history on every call so the model sees the whole dialogue.
    """

    # Endpoint selector (required).
    model_name: ModelType
    # Baidu AIP access token (required).
    access_token: str
    # Conversation history, oldest first; grows after each successful call.
    prefix_messages: List = Field(default_factory=list)
    # Conversation id assigned by the service on the first successful reply.
    id: str = ""

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> str:
        """Append ``prompt`` to the conversation and return the model's reply.

        The history is only extended when the request succeeds; any failure
        propagates to the caller unchanged.
        """
        msg = user_message(prompt)
        request = CompletionRequest(messages=self.prefix_messages + [msg])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token, request)
        # BUG FIX: fetch the response exactly once. The original called
        # bot.get_response() a second time just to read .id, which issued a
        # second network request for the same prompt.
        response = bot.get_response()
        if not self.id:
            self.id = response.id
        self.prefix_messages.append(msg)
        self.prefix_messages.append(bot_message(response.result))
        return response.result

    def _get_id(self) -> str:
        """Return the conversation id ("" until the first successful call)."""
        return self.id

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ernie"