"""FastAPI server exposing chat, streaming chat, and token-count endpoints for ChatGLM2-6B."""

import argparse
import datetime
import json
import signal
import time
from typing import AsyncIterable

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from src.llm.loader import ModelLoader

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    """Release cached GPU memory after a request."""
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


def build_history(history):
    """Convert [{'q': ..., 'a': ...}, ...] from the request body into the
    [(query, answer), ...] tuples expected by model.chat / model.stream_chat."""
    result = []
    for item in history if history else []:
        result.append((item['q'], item['a']))
    return result


def convert_data(data):
    """Inverse of build_history: turn [(query, answer), ...] back into dicts."""
    result = []
    for item in data:
        result.append({'q': item[0], 'a': item[1]})
    return result


class StreamRequest(BaseModel):
    """Request body for streaming (currently unused; the /stream route reads the raw JSON body)."""
    message: str


stop_stream = False


def signal_handler(sig, frame):
    """Allow Ctrl-C to interrupt the current streaming generation."""
    global stop_stream
    stop_stream = True


async def send_message(message: str, history=None, max_length=2048, top_p=0.7,
                       temperature=0.95) -> AsyncIterable[str]:
    """Yield the model response incrementally, one delta per stream_chat step."""
    global model, tokenizer, stop_stream
    history = history if history else []
    old_len = 0
    print(message)
    output = ''
    # Register the handler once per request so Ctrl-C stops the current
    # generation, instead of re-registering it on every yielded chunk.
    signal.signal(signal.SIGINT, signal_handler)
    for response, history in model.stream_chat(tokenizer, message, history=history,
                                               max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        if stop_stream:
            stop_stream = False
            break
        # stream_chat returns the full response so far; yield only the new part.
        output = response[old_len:]
        print(output, end='', flush=True)
        old_len = len(response)
        yield output
    print("")


app = FastAPI()


@app.post("/stream")
async def stream(request: Request):
    """Streaming chat endpoint: returns the response as plain-text chunks."""
    body = await request.json()
    prompt = body.get('prompt')
    history = build_history(body.get('history'))
    max_length = body.get('max_length')
    top_p = body.get('top_p')
    temperature = body.get('temperature')
    return StreamingResponse(
        send_message(prompt, history=history,
                     max_length=max_length if max_length else 2048,
                     top_p=top_p if top_p else 0.7,
                     temperature=temperature if temperature else 0.95),
        media_type="text/plain")


@app.post("/")
async def create_item(request: Request):
    """Non-streaming chat endpoint: returns the full response and updated history."""
    global model, tokenizer
    body = await request.json()
    prompt = body.get('prompt')
    history = build_history(body.get('history'))
    max_length = body.get('max_length')
    top_p = body.get('top_p')
    temperature = body.get('temperature')
    response, history = model.chat(tokenizer, prompt, history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time_str = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time_str
    }
    log = "[" + time_str + "] " + 'prompt: "' + prompt + '", response: "' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer


@app.post("/tokens")
async def get_num_tokens(request: Request):
    """Return the number of tokens the prompt encodes to (without special tokens)."""
    global model, tokenizer
    body = await request.json()
    prompt = body.get('prompt')
    tokens = tokenizer.encode(prompt, add_special_tokens=False)
    print(len(tokens), prompt)
    now = datetime.datetime.now()
    time_str = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": len(tokens),
        "status": 200,
        "time": time_str
    }
    return answer


def parse_args():
    parser = argparse.ArgumentParser(description='ChatGLM2-6B Server')
    parser.add_argument('--model_name_or_path', type=str, default='THUDM/chatglm2-6b',
                        help='model id or local path')
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='checkpoint type (None, ptuning, lora)')
    parser.add_argument('--checkpoint_path', type=str,
                        default='../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000',
                        help='checkpoint path')
    parser.add_argument('--pre_seq_len', type=int, default=128,
                        help='prefix length (P-Tuning v2)')
    parser.add_argument('--quantization_bit', type=int, default=None,
                        help='quantization bits; None means no quantization')
    parser.add_argument('--port', type=int, default=8000, help='port')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
    # parser.add_argument('--max_input_length', type=int, default=512, help='max length of instruction + input')
    # parser.add_argument('--max_output_length', type=int, default=1536, help='max length of output')
    return parser.parse_args()


if __name__ == '__main__':
    cfg = parse_args()

    # ----------- load model --------------
    start = time.time()
    if cfg.checkpoint == "lora":
        # Load the base model plus a LoRA fine-tuned checkpoint.
        loader = ModelLoader(cfg.model_name_or_path)
        loader.load_lora(cfg.checkpoint_path)
    elif cfg.checkpoint == "ptuning":
        # Load the base model plus a P-Tuning v2 prefix checkpoint.
        loader = ModelLoader(cfg.model_name_or_path, cfg.pre_seq_len, False)
        loader.load_prefix(cfg.checkpoint_path)
    else:
        # No fine-tuned checkpoint: load the base model only.
        loader = ModelLoader(cfg.model_name_or_path)

    model, tokenizer = loader.models()
    if cfg.quantization_bit is not None:
        model = loader.quantize(cfg.quantization_bit)

    model.cuda().eval()
    print(f"Model loaded in {time.time() - start:.1f}s")

    uvicorn.run(app, host=cfg.host, port=cfg.port, workers=1)
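# ---------------------------------------------------------------------------
# Minimal client sketch (not part of the server): how the three endpoints can
# be called once the server is running with the defaults from parse_args
# (host 0.0.0.0, port 8000). The `requests` library and the localhost URL are
# assumptions for illustration only.
#
#   import requests
#
#   base = "http://localhost:8000"
#
#   # Non-streaming chat: history uses the [{'q': ..., 'a': ...}] format
#   # expected by build_history above.
#   r = requests.post(base + "/", json={"prompt": "Hello", "history": []})
#   print(r.json()["response"])
#
#   # Streaming chat: the response body arrives as plain-text chunks.
#   with requests.post(base + "/stream", json={"prompt": "Hello"}, stream=True) as r:
#       for chunk in r.iter_content(chunk_size=None):
#           print(chunk.decode("utf-8"), end="", flush=True)
#
#   # Token count for a prompt.
#   print(requests.post(base + "/tokens", json={"prompt": "Hello"}).json()["response"])
# ---------------------------------------------------------------------------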