#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import datetime
import json
import os
from typing import List, Optional, Any

import uvicorn
from fastapi import FastAPI, HTTPException, Request, status, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from src.llm.loader import ModelLoader

# Bearer tokens accepted by the API.
tokens = ["token1"]

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)


class Message(BaseModel):
    role: str
    content: str


class ChatBody(BaseModel):
    messages: List[Message]
    model: str
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4096
    temperature: Optional[float] = 0.9
    top_p: Optional[float] = 0.7  # top_p is a probability; the earlier default of 5 was out of range


class CompletionBody(BaseModel):
    prompt: Any
    model: str
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4096
    temperature: Optional[float] = 0.9
    top_p: Optional[float] = 0.7  # top_p is a probability; the earlier default of 5 was out of range


class EmbeddingsBody(BaseModel):
    # Python 3.8 does not support str | List[str]
    input: Any
    model: Optional[str] = None


@app.get("/")
def read_root():
    return {"Hello": "World!"}


@app.get("/v1/models")
def get_models():
    global model
    ret = {"data": [], "object": "list"}
    if model:
        ret['data'].append({
            "created": 1677610602,
            "id": "gpt-3.5-turbo",
            "object": "model",
            "owned_by": "openai",
            "permission": [
                {
                    "created": 1680818747,
                    "id": "modelperm-fTUZTbzFp7uLLTeMSo9ks6oT",
                    "object": "model_permission",
                    "allow_create_engine": False,
                    "allow_sampling": True,
                    "allow_logprobs": True,
                    "allow_search_indices": False,
                    "allow_view": True,
                    "allow_fine_tuning": False,
                    "organization": "*",
                    "group": None,
                    "is_blocking": False
                }
            ],
            "root": "gpt-3.5-turbo",
            "parent": None,
        })
    return ret


def generate_response(content: str, chat: bool = True):
    global model_name
    if chat:
        return {
            "id": "chatcmpl-77PZm95TtxE0oYLRx3cxa6HtIDI7s",
            "object": "chat.completion",
            "created": 1682000966,
            "model": model_name,
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            },
            "choices": [{
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
                "index": 0}
            ]
        }
    else:
        return {
            "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
            "object": "text_completion",
            "created": 1589478378,
            "model": "text-davinci-003",
            "choices": [
                {
                    "text": content,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
        }


def generate_stream_response_start():
    return {
        "id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
        "object": "chat.completion.chunk",
        "created": 1682004627,
        "model": "gpt-3.5-turbo-0301",
        "choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": None}]
    }


def generate_stream_response(content: str, chat: bool = True):
    global model_name
    if chat:
        return {
            "id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
            "object": "chat.completion.chunk",
            "created": 1682004627,
            "model": model_name,
            "choices": [{"delta": {"content": content}, "index": 0, "finish_reason": None}]
        }
    else:
        return {
            "id": "cmpl-7GfnvmcsDmmTVbPHmTBcNqlMtaEVj",
            "object": "text_completion",
            "created": 1684208299,
            "choices": [
                {
                    "text": content,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": None,
                }
            ],
            "model": "text-davinci-003"
        }
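
# Illustrative only: for a streamed chat completion, the generate_stream_response_*
# helpers in this file produce a chunk sequence roughly like the one below (field
# values are the placeholders hard-coded above), which EventSourceResponse emits
# as server-sent events:
#
#   data: {"object": "chat.completion.chunk", "choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": null}], ...}
#   data: {"object": "chat.completion.chunk", "choices": [{"delta": {"content": "Hello"}, "index": 0, "finish_reason": null}], ...}
#   data: {"object": "chat.completion.chunk", "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], ...}
#   data: [DONE]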
def generate_stream_response_stop(chat: bool = True):
    if chat:
        return {
            "id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
            "object": "chat.completion.chunk",
            "created": 1682004627,
            "model": "gpt-3.5-turbo-0301",
            "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}]
        }
    else:
        return {
            "id": "cmpl-7GfnvmcsDmmTVbPHmTBcNqlMtaEVj",
            "object": "text_completion",
            "created": 1684208299,
            "choices": [
                {"text": "", "index": 0, "logprobs": None, "finish_reason": "stop"}],
            "model": "text-davinci-003",
        }


# @app.post("/v1/embeddings")
# async def embeddings(body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks):
#     return do_embeddings(body, request, background_tasks)


# def do_embeddings(body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks):
#     background_tasks.add_task(torch_gc)
#     if request.headers.get("Authorization").split(" ")[1] not in context.tokens:
#         raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
#     if not context.embeddings_model:
#         raise HTTPException(status.HTTP_404_NOT_FOUND, "Embeddings model not found!")
#     embeddings = context.embeddings_model.encode(body.input)
#     data = []
#     if isinstance(body.input, str):
#         data.append({
#             "object": "embedding",
#             "index": 0,
#             "embedding": embeddings.tolist(),
#         })
#     else:
#         for i, embed in enumerate(embeddings):
#             data.append({
#                 "object": "embedding",
#                 "index": i,
#                 "embedding": embed.tolist(),
#             })
#     content = {
#         "object": "list",
#         "data": data,
#         "model": "text-embedding-ada-002-v2",
#         "usage": {
#             "prompt_tokens": 0,
#             "total_tokens": 0
#         }
#     }
#     return JSONResponse(status_code=200, content=content)


# @app.post("/v1/engines/{engine}/embeddings")
# async def engines_embeddings(engine: str, body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks):
#     return do_embeddings(body, request, background_tasks)


def init_model_args(model_args=None):
    if model_args is None:
        model_args = {}
    model_args['temperature'] = model_args['temperature'] if model_args.get('temperature') is not None else 0.95
    if model_args['temperature'] <= 0:
        model_args['temperature'] = 0.1
    if model_args['temperature'] > 1:
        model_args['temperature'] = 1
    model_args['top_p'] = model_args['top_p'] if model_args.get('top_p') else 0.7
    # top_p must be a probability in (0, 1]; fall back to 0.7 for out-of-range values.
    if model_args['top_p'] <= 0 or model_args['top_p'] > 1:
        model_args['top_p'] = 0.7
    model_args['max_tokens'] = model_args['max_tokens'] if model_args.get('max_tokens') is not None else 512
    return model_args


@app.post("/v1/num_tokens")
async def get_num_tokens(body: CompletionBody, request: Request):
    global model, tokenizer, model_name
    if request.headers.get("Authorization").split(" ")[1] not in tokens:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
    if not model:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!")
    prompt = body.prompt
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
    # now = datetime.datetime.now()
    # time = now.strftime("%Y-%m-%d %H:%M:%S")
    print(prompt, len(prompt_tokens))
    return JSONResponse(content=generate_response(str(len(prompt_tokens)), chat=False))
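
# Illustrative request for the chat endpoint below (assumes the server runs on
# localhost:8000 with the default "token1" bearer token; adjust host, port, and
# token for your deployment):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer token1" \
#     -d '{"model": "chatglm3-6b", "stream": false,
#          "messages": [{"role": "user", "content": "Hello"}]}'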
@app.post("/v1/chat/completions")
async def chat_completions(body: ChatBody, request: Request):
    global model, tokenizer, model_name
    if request.headers.get("Authorization").split(" ")[1] not in tokens:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
    if not model:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!")
    question = body.messages[-1]
    if question.role == 'user':
        question = question.content
    else:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")

    history = []
    user_question = ''
    if model_name == "chatglm3-6b":
        # ChatGLM3 expects the history as a list of {"role", "content"} dicts.
        for message in body.messages[:-1]:
            history.append({"role": message.role, "content": message.content})
        # history.extend(body.messages[:-1])
    else:
        # Older ChatGLM models expect the history as (question, answer) tuples.
        for message in body.messages:
            if message.role == 'system':
                history.append((message.content, "OK"))
            if message.role == 'user':
                user_question = message.content
            elif message.role == 'assistant':
                assistant_answer = message.content
                history.append((user_question, assistant_answer))

    print(f"question = {question}, history = {history}")

    if body.stream:
        async def eval_llm():
            first = True
            model_args = init_model_args({
                "temperature": body.temperature,
                "top_p": body.top_p,
                "max_tokens": body.max_tokens,
            })
            sends = 0
            for response, _ in model.stream_chat(
                    tokenizer, question, history,
                    temperature=model_args['temperature'],
                    top_p=model_args['top_p'],
                    max_length=max(2048, model_args['max_tokens'])):
                ret = response[sends:]
                # https://github.com/THUDM/ChatGLM-6B/issues/478
                # Skip chunks ending in U+FFFD: emoji and other multi-byte
                # characters can be split across tokens and would be garbled.
                if "\uFFFD" == ret[-1:]:
                    continue
                sends = len(response)
                if first:
                    first = False
                    yield json.dumps(generate_stream_response_start(), ensure_ascii=False)
                yield json.dumps(generate_stream_response(ret), ensure_ascii=False)
            yield json.dumps(generate_stream_response_stop(), ensure_ascii=False)
            yield "[DONE]"
        return EventSourceResponse(eval_llm(), ping=10000)
    else:
        model_args = init_model_args({
            "temperature": body.temperature,
            "top_p": body.top_p,
            "max_tokens": body.max_tokens,
        })
        response, _ = model.chat(
            tokenizer, question, history,
            temperature=model_args['temperature'],
            top_p=model_args['top_p'],
            max_length=max(2048, model_args['max_tokens']))
        return JSONResponse(content=generate_response(response))


@app.post("/v1/completions")
async def completions(body: CompletionBody, request: Request):
    print(body)
    if request.headers.get("Authorization").split(" ")[1] not in tokens:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
    if not model:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!")
    if isinstance(body.prompt, list):
        question = body.prompt[0]
    else:
        question = body.prompt
    print(f"question = {question}")

    if body.stream:
        async def eval_llm():
            model_args = init_model_args({
                "temperature": body.temperature,
                "top_p": body.top_p,
                "max_tokens": body.max_tokens,
            })
            sends = 0
            for response, _ in model.stream_chat(
                    tokenizer, question, [],
                    temperature=model_args['temperature'],
                    top_p=model_args['top_p'],
                    max_length=max(2048, model_args['max_tokens'])):
                ret = response[sends:]
                # https://github.com/THUDM/ChatGLM-6B/issues/478
                # Skip chunks ending in U+FFFD (see chat_completions above).
                if "\uFFFD" == ret[-1:]:
                    continue
                sends = len(response)
                yield json.dumps(generate_stream_response(ret, chat=False), ensure_ascii=False)
            yield json.dumps(generate_stream_response_stop(chat=False), ensure_ascii=False)
            yield "[DONE]"
        return EventSourceResponse(eval_llm(), ping=10000)
    else:
        model_args = init_model_args({
            "temperature": body.temperature,
            "top_p": body.top_p,
            "max_tokens": body.max_tokens,
        })
        response, _ = model.chat(
            tokenizer, question, [],
            temperature=model_args['temperature'],
            top_p=model_args['top_p'],
            max_length=max(2048, model_args['max_tokens']))
        print(response)
        return JSONResponse(content=generate_response(response, chat=False))
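
# Illustrative streaming request for the completions endpoint above (same
# assumptions as the chat example: localhost:8000, bearer token "token1"):
#
#   curl -N http://localhost:8000/v1/completions \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer token1" \
#     -d '{"model": "chatglm3-6b", "prompt": "Write a haiku about the sea", "stream": true}'
#
# With "stream": true the response is delivered as server-sent events and ends
# with a final "[DONE]" message.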
def main():
    global model, tokenizer, model_name
    parser = argparse.ArgumentParser(
        description='Start LLM and Embeddings models as a service.')
    parser.add_argument('--model_name_or_path', type=str,
                        help='Chosen LLM model', default='/model/chatglm3-6b')
    parser.add_argument('--device', type=str,
                        help='Device to run the service, gpu/cpu/mps', default='gpu')
    parser.add_argument('--port', type=int,
                        help='Port number to run the service', default=8000)
    parser.add_argument('--host', type=str,
                        help='Host to run the service', default="0.0.0.0")
    parser.add_argument('--checkpoint', type=str,
                        help='Fine-tuning checkpoint type: lora or ptuning', default=None)
    parser.add_argument('--checkpoint_path', type=str,
                        help='Path of the fine-tuned checkpoint to load', default=None)
    parser.add_argument('--pre_seq_len', type=int,
                        help='pre_seq_len used for P-Tuning training', default=None)
    parser.add_argument('--quantization_bit', type=int,
                        help='quantization_bit 4 or 8, default not set', default=None)
    args = parser.parse_args()

    print("> Load config and arguments...")
    print(f"Language Model: {args.model_name_or_path}")
    print(f"Device: {args.device}")
    print(f"Port: {args.port}")
    print(f"Host: {args.host}")
    print(f"Quantization_bit: {args.quantization_bit}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Checkpoint_path: {args.checkpoint_path}")

    model_name = os.path.basename(args.model_name_or_path)
    print(model_name)

    if args.checkpoint == "lora":
        # Load the base model together with a LoRA fine-tuned checkpoint.
        loader = ModelLoader(args.model_name_or_path)
        loader.load_lora(args.checkpoint_path)
    elif args.checkpoint == "ptuning":
        # Load the base model together with a P-Tuning v2 prefix checkpoint.
        loader = ModelLoader(args.model_name_or_path, args.pre_seq_len, False)
        loader.load_prefix(args.checkpoint_path)
    else:
        loader = ModelLoader(args.model_name_or_path)

    model, tokenizer = loader.models()
    if args.quantization_bit is not None:
        model = loader.quantize(args.quantization_bit)

    # NOTE: inference currently runs on the GPU regardless of --device.
    model.cuda().eval()

    uvicorn.run(app, host=args.host, port=args.port, workers=1)


if __name__ == '__main__':
    main()
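
# Illustrative launch commands (model and checkpoint paths are examples only,
# and the filename openai_api.py is an assumption about how this file is saved):
#
#   python openai_api.py --model_name_or_path /model/chatglm3-6b --port 8000
#   python openai_api.py --model_name_or_path /model/chatglm3-6b \
#       --checkpoint lora --checkpoint_path ./output/lora_checkpoint
#   python openai_api.py --model_name_or_path /model/chatglm3-6b --quantization_bit 4
#
# Any OpenAI-compatible client can then point its base URL at
# http://<host>:<port>/v1 and use "token1" as the API key.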