ernie_sdk.py 1.82 KB
Newer Older
1 2 3 4 5 6 7
from dataclasses import asdict, dataclass, fields
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field

陈正乐 committed
8

9 10 11 12
class MessageRole(str, Enum):
    """Speaker roles used for ``Message.role``.

    Members subclass ``str``, so they compare equal to, and JSON-encode as,
    their plain string values.
    """

    # The human side of the conversation.
    USER = "user"
    # The model side; the wire value is "assistant".
    BOT = "assistant"

陈正乐 committed
13

14 15 16 17 18
@dataclass
class Message:
    """A single chat turn included in a completion request.

    Attributes:
        role: Who produced the turn. ``user_message`` and ``bot_message``
            pass a ``MessageRole`` member, which is a ``str``.
        content: The text of the turn.
    """

    role: str
    content: str

陈正乐 committed
19

20 21 22 23 24 25
@dataclass
class CompletionRequest:
    """Body posted by ``ErnieBot.get_response``.

    The instance is turned into a dict with ``dataclasses.asdict`` and then
    ``json.dumps``. As a result, these field names are exactly the JSON keys
    sent to the endpoint.
    """

    messages: List[Message]  # the chat turns to send
    stream: bool = False
    user: str = ""

陈正乐 committed
26

27 28 29 30 31 32
@dataclass
class Usage:
    """Token accounting carried in the ``usage`` field of a completion reply."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

陈正乐 committed
33

34 35 36 37 38 39 40 41 42 43 44 45 46 47
@dataclass
class CompletionResponse:
    """A successful completion reply parsed from the endpoint's JSON body.

    Field names mirror the JSON keys so the class can be built with
    ``CompletionResponse(**payload)``. When ``usage`` arrives as a plain dict
    (as it does from JSON), it is converted to a ``Usage`` instance after
    construction.
    """

    id: str
    object: str
    created: int
    result: str
    need_clear_history: bool
    ban_round: int = 0
    sentence_id: int = 0
    is_end: bool = False
    usage: Optional[Usage] = None
    is_safe: bool = False
    is_truncated: bool = False

    def __post_init__(self):
        # ``CompletionResponse(**payload)`` passes ``usage`` through as a raw
        # dict. Promote it so ``response.usage.total_tokens`` works as the
        # annotation promises. Keys that ``Usage`` does not model are dropped
        # rather than raising TypeError.
        if isinstance(self.usage, dict):
            known = {f.name for f in fields(Usage)}
            self.usage = Usage(**{k: v for k, v in self.usage.items() if k in known})

陈正乐 committed
48

49 50 51 52 53
class ErrorResponse(BaseModel):
    """Shape of an error body from the service.

    Presumably this is returned in place of a completion reply; confirm
    against the API docs. All three fields are required (``Field(...)``).
    """

    error_code: int = Field(...)  # numeric error identifier
    error_msg: str = Field(...)  # error description text
    id: str = Field(...)

陈正乐 committed
54 55

class ErnieBot:
    """Minimal client that posts one ``CompletionRequest`` to an ERNIE endpoint.

    Attributes:
        url: Endpoint URL the request is POSTed to.
        access_token: Sent as the ``access_token`` query parameter.
        request: The request body to serialize and send.
    """

    url: str
    access_token: str
    request: CompletionRequest

    def __init__(self, url: str, access_token: str, request: CompletionRequest):
        self.url = url
        self.access_token = access_token
        self.request = request

    def get_response(self, timeout: Optional[float] = None) -> CompletionResponse:
        """POST ``self.request`` and parse the reply into a ``CompletionResponse``.

        Args:
            timeout: Optional seconds passed to ``requests.post``. The default
                ``None`` keeps the previous behavior of waiting indefinitely.

        Returns:
            The parsed ``CompletionResponse``.

        Raises:
            Exception: In any of these cases:
                - The body is not JSON. The argument is the raw text.
                - The JSON is an error payload containing ``error_code``.
                - The JSON cannot be mapped onto ``CompletionResponse``.
                In the last two cases the argument is the decoded payload.
                The original cause, when there is one, is chained.
        """
        import json

        import requests

        headers = {'Content-Type': 'application/json'}
        params = {'access_token': self.access_token}
        request_dict = asdict(self.request)
        response = requests.post(
            self.url,
            params=params,
            data=json.dumps(request_dict),
            headers=headers,
            timeout=timeout,
        )

        # Decode exactly once. Re-calling .json() on a bad body would raise
        # again from inside the error handler.
        try:
            payload = response.json()
        except ValueError as err:
            raise Exception(response.text) from err

        # An error body carries error_code/error_msg (cf. ErrorResponse) instead
        # of completion fields. Surface it unchanged, as the original did.
        if not isinstance(payload, dict) or 'error_code' in payload:
            raise Exception(payload)

        # Ignore keys this SDK does not model, so new API fields don't break parsing.
        known = {f.name for f in fields(CompletionResponse)}
        try:
            return CompletionResponse(**{k: v for k, v in payload.items() if k in known})
        except TypeError as err:
            raise Exception(payload) from err


82 83 84
def user_message(prompt: str) -> Message:
    """Build a ``Message`` whose role is ``MessageRole.USER`` and whose content is *prompt*."""
    return Message(role=MessageRole.USER, content=prompt)

陈正乐 committed
85

86
def bot_message(prompt: str) -> Message:
    """Build a ``Message`` whose role is ``MessageRole.BOT`` and whose content is *prompt*."""
    return Message(role=MessageRole.BOT, content=prompt)