import logging
import os

from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM, LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks

from enum import Enum

from pydantic import root_validator, Field

from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_message

logger = logging.getLogger(__name__)


class ModelType(Enum):
    """Backend model endpoints supported by the Baidu Qianfan (wenxinworkshop) service.

    Each member's value is the short public model name; the matching request
    URL suffix lives in MODEL_SERVICE_Suffix below.
    """

    ERNIE = "ernie"
    ERNIE_LITE = "ernie-lite"
    SHEETS1 = "sheets1"
    SHEETS2 = "sheets2"
    SHEET_COMB = "sheet-comb"
    LLAMA2_7B = "llama2-7b"
    LLAMA2_13B = "llama2-13b"
    LLAMA2_70B = "llama2-70b"
    QFCN_LLAMA2_7B = "qfcn-llama2-7b"
    BLOOMZ_7B = "bloomz-7b"


# Common prefix of every Qianfan chat-completion endpoint.
MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"

# Per-model URL suffix, appended to MODEL_SERVICE_BASE_URL by
# _get_model_service_url().  NOTE(review): name is not UPPER_SNAKE_CASE,
# kept as-is for backward compatibility with existing importers.
MODEL_SERVICE_Suffix = {
    ModelType.ERNIE: "ai_custom/v1/wenxinworkshop/chat/completions",
    ModelType.ERNIE_LITE: "ai_custom/v1/wenxinworkshop/chat/eb-instant",
    ModelType.SHEETS1: "ai_custom/v1/wenxinworkshop/chat/besheet",
    ModelType.SHEETS2: "ai_custom/v1/wenxinworkshop/chat/besheets2",
    ModelType.SHEET_COMB: "ai_custom/v1/wenxinworkshop/chat/sheet_comb1",
    ModelType.LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/llama_2_7b",
    ModelType.LLAMA2_13B: "ai_custom/v1/wenxinworkshop/chat/llama_2_13b",
    ModelType.LLAMA2_70B: "ai_custom/v1/wenxinworkshop/chat/llama_2_70b",
    ModelType.QFCN_LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/qianfan_chinese_llama_2_7b",
    ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
}

class ErnieLLM(LLM):
    """LangChain LLM wrapper that generates text via Baidu's ERNIE service.

    Attributes:
        model_name: Which backend endpoint to call (defaults to ModelType.ERNIE).
        access_token: Baidu API access token; may also be supplied through the
            ERNIE_ACCESS_TOKEN environment variable.
    """

    model_name: Optional[ModelType] = None
    access_token: Optional[str] = ""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve model_name and access_token from init kwargs or the environment.

        Raises:
            ValueError: if no access token is found in either place.
        """
        # BUGFIX: the default used to be str(ModelType.ERNIE), which renders as
        # "ModelType.ERNIE" and is NOT a valid enum value, so omitting
        # model_name raised ValueError.  Use the member's .value ("ernie").
        model_name = ModelType(
            get_from_dict_or_env(values, "model_name", "model_name", ModelType.ERNIE.value)
        )
        access_token = get_from_dict_or_env(values, "access_token", "ERNIE_ACCESS_TOKEN", "")

        if not access_token:
            raise ValueError("No access token provided.")

        values["model_name"] = model_name
        values["access_token"] = access_token
        return values

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
        """Send *prompt* to the ERNIE service and return the generated text.

        On failure the error text is returned instead of raising, preserving the
        original best-effort contract for callers.
        """
        request = CompletionRequest(messages=[Message("user", prompt)])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
        try:
            return bot.get_response().result
        except Exception as e:
            # Log with traceback via the module logger instead of print().
            logger.exception("ErnieLLM request failed: %s", e)
            return str(e)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ernie"
def _get_model_service_url(model_name) -> str:
    """Return the full chat endpoint URL for *model_name*.

    Joins the shared base URL with the per-model suffix looked up in
    MODEL_SERVICE_Suffix; raises KeyError for an unknown model.
    """
    suffix = MODEL_SERVICE_Suffix[model_name]
    return f"{MODEL_SERVICE_BASE_URL}{suffix}"
class ErnieChat(LLM):
    """Chat-style ERNIE wrapper that accumulates conversation state in memory.

    Every successful call appends the user prompt and the bot reply to
    ``prefix_messages``, so subsequent calls carry the full history.
    """

    # Backend endpoint selector (required).
    model_name: ModelType
    # Baidu API access token (required).
    access_token: str
    # Conversation history sent ahead of each new prompt.
    prefix_messages: List = Field(default_factory=list)
    # Conversation id reported by the service on the first reply; "" until then.
    id: str = ""

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
        """Send *prompt* plus accumulated history; record and return the reply.

        Raises:
            Exception: re-raises whatever the underlying ErnieBot call raised.
        """
        msg = user_message(prompt)
        request = CompletionRequest(messages=self.prefix_messages + [msg])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token, request)
        try:
            # BUGFIX: fetch the response once and reuse it.  The original called
            # bot.get_response() a second time just to read .id, issuing a
            # redundant network request on the first turn.
            response = bot.get_response()
            result = response.result
            if self.id == "":
                self.id = response.id
            self.prefix_messages.append(msg)
            self.prefix_messages.append(bot_message(result))
            return result
        except Exception:
            # Log with traceback, then re-raise without resetting the traceback
            # (bare `raise` instead of `raise e`).
            logger.exception("ErnieChat request failed")
            raise

    def _get_id(self) -> str:
        """Return the service-assigned conversation id ("" before the first reply)."""
        return self.id

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ernie"