Commit 54455eae by 陈正乐

Add model interface options: local and OpenAI

parent 847e155f
......@@ -35,7 +35,7 @@ SIMILARITY_THRESHOLD = 0.8
# =============================
# FAISS vector store file path configuration
# =============================
FAISS_STORE_PATH = '../faiss'
FAISS_STORE_PATH = 'C:\\Users\\15663\\code\\dkjj-llm\\LAE\\faiss'
INDEX_NAME = 'know'
# =============================
......@@ -48,3 +48,12 @@ KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\低空经济数据库'
# =============================
GR_SERVER_NAME = 'localhost'
GR_PORT = 8888
# =============================
# Prompt configuration
# (prompt1 wraps retrieved {context} and then asks the model, in Chinese, to answer {question} from it)
# =============================
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
\ No newline at end of file
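For reference, the {context} and {question} placeholders make a template like prompt1 usable directly with str.format (or an equivalent langchain PromptTemplate). A minimal sketch with hypothetical sample values:

# Hypothetical example; only the placeholder names come from prompt1 above.
filled = prompt1.format(
    context="retrieved knowledge-base excerpts",
    question="什么是低空经济?",
)
print(filled)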
#!/bin/bash
# Set the path to the server.py script
SERVER_PATH=server.py
# Set the default values for the arguments
MODEL_NAME_OR_PATH="../../../model/chatglm2-6b"
# "None" is passed as a literal string; server.py treats any value other than lora/ptuning as no checkpoint
CHECKPOINT=None
CHECKPOINT_PATH="../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000"
PRE_SEQ_LEN=128
QUANTIZATION_BIT=8
PORT=8002
# Call the server.py script with the parsed arguments
python "$SERVER_PATH" \
    --model_name_or_path "$MODEL_NAME_OR_PATH" \
    --checkpoint "$CHECKPOINT" \
    --checkpoint_path "$CHECKPOINT_PATH" \
    --pre_seq_len "$PRE_SEQ_LEN" \
    --quantization_bit "$QUANTIZATION_BIT" \
    --port "$PORT"
\ No newline at end of file
#!/bin/bash
# Set the path to the server.py script
SERVER_PATH=server.py
# Set the default values for the arguments
MODEL_NAME_OR_PATH="../../../model/chatglm2-6b"
CHECKPOINT=lora
CHECKPOINT_PATH="../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000"
QUANTIZATION_BIT=8
PORT=8001
# Call the server.py script with the parsed arguments
python "$SERVER_PATH" \
    --model_name_or_path "$MODEL_NAME_OR_PATH" \
    --checkpoint "$CHECKPOINT" \
    --checkpoint_path "$CHECKPOINT_PATH" \
    --quantization_bit "$QUANTIZATION_BIT" \
    --port "$PORT"
\ No newline at end of file
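Once either launch script is running, the HTTP endpoints defined in server.py (below) can be exercised directly. A minimal sketch in Python, assuming the LoRA variant is reachable on localhost:8001:

import requests

# history uses the [{'q': ..., 'a': ...}] shape that build_history below expects
resp = requests.post(
    "http://localhost:8001/",
    json={"prompt": "你好", "history": []},
)
print(resp.json()["response"])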
import argparse
import time
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import json
import datetime
import torch
from typing import AsyncIterable
from pydantic import BaseModel
import uvicorn
import signal
from src.llm.loader import ModelLoader
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
def build_history(history):
    # Convert [{'q': ..., 'a': ...}, ...] from the request into the [(q, a), ...] pairs the model expects
    result = []
    for item in history if history else []:
        result.append((item['q'], item['a']))
    return result

def convert_data(data):
    # Inverse of build_history: [(q, a), ...] back to [{'q': ..., 'a': ...}, ...]
    result = []
    for item in data:
        result.append({'q': item[0], 'a': item[1]})
    return result
class StreamRequest(BaseModel):
    """Request body for streaming."""
    message: str

stop_stream = False

def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True
async def send_message(message: str, history=None, max_length=2048, top_p=0.7, temperature=0.95) -> AsyncIterable[str]:
    global model, tokenizer, stop_stream
    history = history if history else []
    old_len = 0
    print(message)
    output = ''
    for response, history in model.stream_chat(tokenizer, message, history=history,
                                               max_length=max_length,
                                               top_p=top_p,
                                               temperature=temperature):
        if stop_stream:
            stop_stream = False
            break
        else:
            # stream_chat yields the full response so far; emit only the new suffix
            output = response[old_len:]
            print(output, end='', flush=True)
            old_len = len(response)
            signal.signal(signal.SIGINT, signal_handler)
            yield f"{output}"
    print("")
app = FastAPI()

@app.post("/stream")
async def stream(request: Request):
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = build_history(json_post_list.get('history'))
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    return StreamingResponse(send_message(prompt, history=history, max_length=max_length if max_length else 2048,
                                          top_p=top_p if top_p else 0.7,
                                          temperature=temperature if temperature else 0.95), media_type="text/plain")
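# A hypothetical streaming client for the route above, assuming the default port 8000:
#   with requests.post("http://localhost:8000/stream",
#                      json={"prompt": "你好", "history": []}, stream=True) as r:
#       for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)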
@app.post("/")
async def create_item(request: Request):
global model, tokenizer
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
history = build_history(json_post_list.get('history'))
max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature')
response, history = model.chat(tokenizer,
prompt,
history=history,
max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95)
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"history": history,
"status": 200,
"time": time
}
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log)
torch_gc()
return answer
@app.post("/tokens")
async def get_num_tokens(request: Request):
global model, tokenizer
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
tokens = tokenizer.encode(prompt, add_special_tokens=False)
print("=======================================")
print("=======================================")
print(len(tokens), prompt)
print("=======================================")
print("=======================================")
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": len(tokens),
"status": 200,
"time": time
}
return answer
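# A hypothetical client call for the route above, assuming the default port 8000:
#   requests.post("http://localhost:8000/tokens", json={"prompt": "你好"}).json()["response"]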
def parse_args():
    parser = argparse.ArgumentParser(description='ChatGLM2-6B Server')
    parser.add_argument('--model_name_or_path', type=str, default='THUDM/chatglm2-6b', help='model id or local path')
    parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint type (None, ptuning or lora)')
    parser.add_argument('--checkpoint_path', type=str,
                        default='../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000',
                        help='checkpoint path')
    parser.add_argument('--pre_seq_len', type=int, default=128, help='prefix length')
    parser.add_argument('--quantization_bit', type=int, default=None, help='quantization bits (None disables quantization)')
    parser.add_argument('--port', type=int, default=8000, help='port')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
    # parser.add_argument('--max_input_length', type=int, default=512, help='max length of instruction + input')
    # parser.add_argument('--max_output_length', type=int, default=1536, help='max length of output')
    return parser.parse_args()
if __name__ == '__main__':
    cfg = parse_args()
    # ----------- load model --------------
    start = time.time()
    if cfg.checkpoint == "lora":
        # load the base model plus a LoRA fine-tuned checkpoint
        loader = ModelLoader(cfg.model_name_or_path)
        loader.load_lora(cfg.checkpoint_path)
    elif cfg.checkpoint == "ptuning":
        # load the base model plus a P-Tuning v2 fine-tuned checkpoint
        loader = ModelLoader(cfg.model_name_or_path, cfg.pre_seq_len, False)
        loader.load_prefix(cfg.checkpoint_path)
    else:
        loader = ModelLoader(cfg.model_name_or_path)
    model, tokenizer = loader.models()
    if cfg.quantization_bit is not None:
        model = loader.quantize(cfg.quantization_bit)
    model.cuda().eval()
    uvicorn.run(app, host=cfg.host, port=cfg.port, workers=1)
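The commit message names two model interface options, local and OpenAI, but the diffs here only show the local one (ChatGLMSerLLM pointed at an HTTP server like the one above). A minimal sketch of selecting between the two, using langchain's ChatOpenAI as the OpenAI-side wrapper (an assumption; the actual wrapper is not shown in this diff):

import os
from langchain.chat_models import ChatOpenAI  # assumption: not part of this diff
from src.llm.chatglm import ChatGLMSerLLM

def build_llm(backend: str):
    if backend == "local":
        # ChatGLM2-6B server started by the scripts above (URL is hypothetical)
        return ChatGLMSerLLM(url="http://localhost:8001")
    # OpenAI-backed alternative; reads the key from the environment
    return ChatOpenAI(openai_api_key=os.environ["OPENAI_API_KEY"])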
# -*- coding: utf-8 -*-
import sys
import time
from datetime import datetime
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Awaitable
......@@ -118,7 +117,7 @@ if __name__ == "__main__":
        chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path='../../faiss',
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
......
import sys
sys.path.append("../")
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa
from src.pgdb.chat.crud import CRUD
from src.config.consts import (
CHAT_DB_USER,
......@@ -22,17 +17,17 @@ def test():
port=CHAT_DB_PORT, )
    print(c_db)
    crud = CRUD(c_db)
    crud.create_table()
    crud.insert_turn_qa("2", "wen4", "da1", 1, 0)
    crud.insert_turn_qa("2", "wen4", "da1", 2, 0)
    crud.insert_turn_qa("2", "wen4", "da1", 5, 0)
    crud.insert_turn_qa("2", "wen4", "da1", 4, 0)
    crud.insert_turn_qa("2", "wen4", "da1", 3, 0)
    crud.insert_turn_qa("2", "wen4", "da1", 6, 0)
    crud.insert_turn_qa("2", "wen4", "da1", 8, 0)
    crud.insert_turn_qa("2", "wen4", "da1", 7, 0)
    crud.insert_turn_qa("2", "wen4", "da1", 9, 0)
    # crud.create_table()
    # crud.insert_turn_qa("2", "wen4", "da1", 1, 0)
    # crud.insert_turn_qa("2", "wen4", "da1", 2, 0)
    # crud.insert_turn_qa("2", "wen4", "da1", 5, 0)
    # crud.insert_turn_qa("2", "wen4", "da1", 4, 0)
    # crud.insert_turn_qa("2", "wen4", "da1", 3, 0)
    # crud.insert_turn_qa("2", "wen4", "da1", 6, 0)
    # crud.insert_turn_qa("2", "wen4", "da1", 8, 0)
    # crud.insert_turn_qa("2", "wen4", "da1", 7, 0)
    # crud.insert_turn_qa("2", "wen4", "da1", 9, 0)
    crud.insert_c_user('zhangs', '111111')
    print(crud.get_history('2'))
......
# -*- coding: utf-8 -*-
import gradio as gr
from langchain.prompts import PromptTemplate
from src.llm.chatglm import ChatGLMSerLLM
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
import os
from src.pgdb.chat.c_db import UPostgresDB
from src.server.get_similarity import GetSimilarity
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
CHAT_DB_USER,
......@@ -46,10 +48,19 @@ def main():
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
# base_llm = ChatERNIESerLLM(
# chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088')
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2',
_faiss_db=vecstore_faiss)
def clear_q_a():
return '', ''
def show_history():
return my_chat.get_history()
with gr.Blocks() as demo:
with gr.Row():
inn = gr.Textbox(show_label=True, lines=10)
......
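A hypothetical continuation of the collapsed Blocks layout above (none of these component or button names appear in the diff): clear_q_a resets the paired textboxes and show_history reloads the stored dialogue.

#   out = gr.Textbox(show_label=True, lines=10)
#   clear_btn = gr.Button("clear")
#   history_btn = gr.Button("history")
#   clear_btn.click(clear_q_a, outputs=[inn, out])
#   history_btn.click(show_history, outputs=[out])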