#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import datetime
import json
import os
from typing import List, Optional, Any

import uvicorn
from fastapi import FastAPI, HTTPException, Request, status, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from src.llm.loader import ModelLoader

# Bearer tokens accepted by the API.
tokens = ["token1"]

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)


class Message(BaseModel):
    role: str
    content: str


class ChatBody(BaseModel):
    messages: List[Message]
    model: str
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4096
    temperature: Optional[float] = 0.9
    top_p: Optional[float] = 0.7  # top_p is a probability; the earlier default of 5 was out of range


class CompletionBody(BaseModel):
    prompt: Any
    model: str
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4096
    temperature: Optional[float] = 0.9
    top_p: Optional[float] = 0.7  # top_p is a probability; the earlier default of 5 was out of range


class EmbeddingsBody(BaseModel):
    # Python 3.8 does not support str | List[str]
    input: Any
    model: Optional[str] = None


@app.get("/")
def read_root():
    return {"Hello": "World!"}


@app.get("/v1/models")
def get_models():
    global model
    ret = {"data": [], "object": "list"}
    if model:
        ret['data'].append({
            "created": 1677610602,
            "id": "gpt-3.5-turbo",
            "object": "model",
            "owned_by": "openai",
            "permission": [
                {
                    "created": 1680818747,
                    "id": "modelperm-fTUZTbzFp7uLLTeMSo9ks6oT",
                    "object": "model_permission",
                    "allow_create_engine": False,
                    "allow_sampling": True,
                    "allow_logprobs": True,
                    "allow_search_indices": False,
                    "allow_view": True,
                    "allow_fine_tuning": False,
                    "organization": "*",
                    "group": None,
                    "is_blocking": False
                }
            ],
            "root": "gpt-3.5-turbo",
            "parent": None,
        })
    return ret


def generate_response(content: str, chat: bool = True):
    global model_name
    if chat:
        return {
            "id": "chatcmpl-77PZm95TtxE0oYLRx3cxa6HtIDI7s",
            "object": "chat.completion",
            "created": 1682000966,
            "model": model_name,
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            },
            "choices": [{
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
                "index": 0}
            ]
        }
    else:
        return {
            "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
            "object": "text_completion",
            "created": 1589478378,
            "model": "text-davinci-003",
            "choices": [
                {
                    "text": content,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
        }


def generate_stream_response_start():
    return {
        "id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
        "object": "chat.completion.chunk",
        "created": 1682004627,
        "model": "gpt-3.5-turbo-0301",
        "choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": None}]
    }


def generate_stream_response(content: str, chat: bool = True):
    global model_name
    if chat:
        return {
            "id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
            "object": "chat.completion.chunk",
            "created": 1682004627,
            "model": model_name,
            "choices": [{"delta": {"content": content}, "index": 0, "finish_reason": None}]
        }
    else:
        return {
            "id": "cmpl-7GfnvmcsDmmTVbPHmTBcNqlMtaEVj",
            "object": "text_completion",
            "created": 1684208299,
            "choices": [
                {
                    "text": content,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": None,
                }
            ],
            "model": "text-davinci-003"
        }
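
# Illustrative only: for a streamed chat completion, the generate_stream_response_*
# helpers in this file produce a chunk sequence roughly like the one below (field
# values are the placeholders hard-coded above), which EventSourceResponse emits
# as server-sent events:
#
#   data: {"object": "chat.completion.chunk", "choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": null}], ...}
#   data: {"object": "chat.completion.chunk", "choices": [{"delta": {"content": "Hello"}, "index": 0, "finish_reason": null}], ...}
#   data: {"object": "chat.completion.chunk", "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], ...}
#   data: [DONE]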
def generate_stream_response_stop(chat: bool = True):
    if chat:
        return {
            "id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
            "object": "chat.completion.chunk",
            "created": 1682004627,
            "model": "gpt-3.5-turbo-0301",
            "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}]
        }
    else:
        return {
            "id": "cmpl-7GfnvmcsDmmTVbPHmTBcNqlMtaEVj",
            "object": "text_completion",
            "created": 1684208299,
            "choices": [
                {"text": "", "index": 0, "logprobs": None, "finish_reason": "stop"}],
            "model": "text-davinci-003",
        }


# @app.post("/v1/embeddings")
# async def embeddings(body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks):
#     return do_embeddings(body, request, background_tasks)


# def do_embeddings(body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks):
#     background_tasks.add_task(torch_gc)
#     if request.headers.get("Authorization").split(" ")[1] not in context.tokens:
#         raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
#     if not context.embeddings_model:
#         raise HTTPException(status.HTTP_404_NOT_FOUND, "Embeddings model not found!")
#     embeddings = context.embeddings_model.encode(body.input)
#     data = []
#     if isinstance(body.input, str):
#         data.append({
#             "object": "embedding",
#             "index": 0,
#             "embedding": embeddings.tolist(),
#         })
#     else:
#         for i, embed in enumerate(embeddings):
#             data.append({
#                 "object": "embedding",
#                 "index": i,
#                 "embedding": embed.tolist(),
#             })
#     content = {
#         "object": "list",
#         "data": data,
#         "model": "text-embedding-ada-002-v2",
#         "usage": {
#             "prompt_tokens": 0,
#             "total_tokens": 0
#         }
#     }
#     return JSONResponse(status_code=200, content=content)


# @app.post("/v1/engines/{engine}/embeddings")
# async def engines_embeddings(engine: str, body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks):
#     return do_embeddings(body, request, background_tasks)


def init_model_args(model_args=None):
    if model_args is None:
        model_args = {}
    model_args['temperature'] = model_args['temperature'] if model_args.get('temperature') is not None else 0.95
    if model_args['temperature'] <= 0:
        model_args['temperature'] = 0.1
    if model_args['temperature'] > 1:
        model_args['temperature'] = 1
    model_args['top_p'] = model_args['top_p'] if model_args.get('top_p') else 0.7
    # top_p must be a probability in (0, 1]; fall back to 0.7 for out-of-range values.
    if model_args['top_p'] <= 0 or model_args['top_p'] > 1:
        model_args['top_p'] = 0.7
    model_args['max_tokens'] = model_args['max_tokens'] if model_args.get('max_tokens') is not None else 512
    return model_args


@app.post("/v1/num_tokens")
async def get_num_tokens(body: CompletionBody, request: Request):
    global model, tokenizer, model_name
    if request.headers.get("Authorization").split(" ")[1] not in tokens:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
    if not model:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!")
    prompt = body.prompt
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
    # now = datetime.datetime.now()
    # time = now.strftime("%Y-%m-%d %H:%M:%S")
    print(prompt, len(prompt_tokens))
    return JSONResponse(content=generate_response(str(len(prompt_tokens)), chat=False))
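
# Illustrative request for the chat endpoint below (assumes the server runs on
# localhost:8000 with the default "token1" bearer token; adjust host, port, and
# token for your deployment):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer token1" \
#     -d '{"model": "chatglm3-6b", "stream": false,
#          "messages": [{"role": "user", "content": "Hello"}]}'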
@app.post("/v1/chat/completions")
async def chat_completions(body: ChatBody, request: Request):
    global model, tokenizer, model_name
    if request.headers.get("Authorization").split(" ")[1] not in tokens:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
    if not model:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!")
    question = body.messages[-1]
    if question.role == 'user':
        question = question.content
    else:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")

    history = []
    user_question = ''
    if model_name == "chatglm3-6b":
        # ChatGLM3 expects the history as a list of {"role", "content"} dicts.
        for message in body.messages[:-1]:
            history.append({"role": message.role, "content": message.content})
        # history.extend(body.messages[:-1])
    else:
        # Older ChatGLM models expect the history as (question, answer) tuples.
        for message in body.messages:
            if message.role == 'system':
                history.append((message.content, "OK"))
            if message.role == 'user':
                user_question = message.content
            elif message.role == 'assistant':
                assistant_answer = message.content
                history.append((user_question, assistant_answer))

    print(f"question = {question}, history = {history}")

    if body.stream:
        async def eval_llm():
            first = True
            model_args = init_model_args({
                "temperature": body.temperature,
                "top_p": body.top_p,
                "max_tokens": body.max_tokens,
            })
            sends = 0
            for response, _ in model.stream_chat(
                    tokenizer, question, history,
                    temperature=model_args['temperature'],
                    top_p=model_args['top_p'],
                    max_length=max(2048, model_args['max_tokens'])):
                ret = response[sends:]
                # https://github.com/THUDM/ChatGLM-6B/issues/478
                # Skip chunks ending in U+FFFD: emoji and other multi-byte
                # characters can be split across tokens and would be garbled.
                if "\uFFFD" == ret[-1:]:
                    continue
                sends = len(response)
                if first:
                    first = False
                    yield json.dumps(generate_stream_response_start(), ensure_ascii=False)
                yield json.dumps(generate_stream_response(ret), ensure_ascii=False)
            yield json.dumps(generate_stream_response_stop(), ensure_ascii=False)
            yield "[DONE]"
        return EventSourceResponse(eval_llm(), ping=10000)
    else:
        model_args = init_model_args({
            "temperature": body.temperature,
            "top_p": body.top_p,
            "max_tokens": body.max_tokens,
        })
        response, _ = model.chat(
            tokenizer, question, history,
            temperature=model_args['temperature'],
            top_p=model_args['top_p'],
            max_length=max(2048, model_args['max_tokens']))
        return JSONResponse(content=generate_response(response))


@app.post("/v1/completions")
async def completions(body: CompletionBody, request: Request):
    print(body)
    if request.headers.get("Authorization").split(" ")[1] not in tokens:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")
    if not model:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!")
    if isinstance(body.prompt, list):
        question = body.prompt[0]
    else:
        question = body.prompt
    print(f"question = {question}")

    if body.stream:
        async def eval_llm():
            model_args = init_model_args({
                "temperature": body.temperature,
                "top_p": body.top_p,
                "max_tokens": body.max_tokens,
            })
            sends = 0
            for response, _ in model.stream_chat(
                    tokenizer, question, [],
                    temperature=model_args['temperature'],
                    top_p=model_args['top_p'],
                    max_length=max(2048, model_args['max_tokens'])):
                ret = response[sends:]
                # https://github.com/THUDM/ChatGLM-6B/issues/478
                # Skip chunks ending in U+FFFD (see chat_completions above).
                if "\uFFFD" == ret[-1:]:
                    continue
                sends = len(response)
                yield json.dumps(generate_stream_response(ret, chat=False), ensure_ascii=False)
            yield json.dumps(generate_stream_response_stop(chat=False), ensure_ascii=False)
            yield "[DONE]"
        return EventSourceResponse(eval_llm(), ping=10000)
    else:
        model_args = init_model_args({
            "temperature": body.temperature,
            "top_p": body.top_p,
            "max_tokens": body.max_tokens,
        })
        response, _ = model.chat(
            tokenizer, question, [],
            temperature=model_args['temperature'],
            top_p=model_args['top_p'],
            max_length=max(2048, model_args['max_tokens']))
        print(response)
        return JSONResponse(content=generate_response(response, chat=False))
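
# Illustrative streaming request for the completions endpoint above (same
# assumptions as the chat example: localhost:8000, bearer token "token1"):
#
#   curl -N http://localhost:8000/v1/completions \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer token1" \
#     -d '{"model": "chatglm3-6b", "prompt": "Write a haiku about the sea", "stream": true}'
#
# With "stream": true the response is delivered as server-sent events and ends
# with a final "[DONE]" message.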
def main():
    global model, tokenizer, model_name
    parser = argparse.ArgumentParser(
        description='Start LLM and Embeddings models as a service.')
    parser.add_argument('--model_name_or_path', type=str,
                        help='Chosen LLM model', default='/model/chatglm3-6b')
    parser.add_argument('--device', type=str,
                        help='Device to run the service, gpu/cpu/mps', default='gpu')
    parser.add_argument('--port', type=int,
                        help='Port number to run the service', default=8000)
    parser.add_argument('--host', type=str,
                        help='Host to run the service', default="0.0.0.0")
    parser.add_argument('--checkpoint', type=str,
                        help='Fine-tuning checkpoint type: lora or ptuning', default=None)
    parser.add_argument('--checkpoint_path', type=str,
                        help='Path of the fine-tuned checkpoint to load', default=None)
    parser.add_argument('--pre_seq_len', type=int,
                        help='pre_seq_len used for P-Tuning training', default=None)
    parser.add_argument('--quantization_bit', type=int,
                        help='quantization_bit 4 or 8, default not set', default=None)
    args = parser.parse_args()

    print("> Load config and arguments...")
    print(f"Language Model: {args.model_name_or_path}")
    print(f"Device: {args.device}")
    print(f"Port: {args.port}")
    print(f"Host: {args.host}")
    print(f"Quantization_bit: {args.quantization_bit}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Checkpoint_path: {args.checkpoint_path}")

    model_name = os.path.basename(args.model_name_or_path)
    print(model_name)

    if args.checkpoint == "lora":
        # Load the base model together with a LoRA fine-tuned checkpoint.
        loader = ModelLoader(args.model_name_or_path)
        loader.load_lora(args.checkpoint_path)
    elif args.checkpoint == "ptuning":
        # Load the base model together with a P-Tuning v2 prefix checkpoint.
        loader = ModelLoader(args.model_name_or_path, args.pre_seq_len, False)
        loader.load_prefix(args.checkpoint_path)
    else:
        loader = ModelLoader(args.model_name_or_path)

    model, tokenizer = loader.models()
    if args.quantization_bit is not None:
        model = loader.quantize(args.quantization_bit)

    # NOTE: inference currently runs on the GPU regardless of --device.
    model.cuda().eval()

    uvicorn.run(app, host=args.host, port=args.port, workers=1)


if __name__ == '__main__':
    main()
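
# Illustrative launch commands (model and checkpoint paths are examples only,
# and the filename openai_api.py is an assumption about how this file is saved):
#
#   python openai_api.py --model_name_or_path /model/chatglm3-6b --port 8000
#   python openai_api.py --model_name_or_path /model/chatglm3-6b \
#       --checkpoint lora --checkpoint_path ./output/lora_checkpoint
#   python openai_api.py --model_name_or_path /model/chatglm3-6b --quantization_bit 4
#
# Any OpenAI-compatible client can then point its base URL at
# http://<host>:<port>/v1 and use "token1" as the API key.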