import sys
sys.path.append("../..") 

from common import consts
## --------- config -------------
from argparse import Namespace
cfg = Namespace()

#dataset
cfg.train_file = '../../../data/train_spdsvb_v7.csv' 
cfg.val_file = '../../../data/val_spdsvb_v4.csv' 
cfg.prompt_column = 'prompt'
cfg.response_column = 'response'
cfg.history_column = None
cfg.source_prefix = consts.INSTRUCTION_V1  # instruction prefix prepended to every prompt

cfg.max_source_length = 64 
cfg.max_target_length = 128

#model
cfg.model_name_or_path = consts.MODEL_PATH_ChatGLM  # remote hub id: 'THUDM/chatglm-6b'
cfg.quantization_bit = None  # 4 or 8; intended for inference only

#train
cfg.epochs = 50
cfg.lr = 1e-3
cfg.batch_size = 1
cfg.gradient_accumulation_steps = 16  # gradient accumulation

cfg.ckpt_path =  '../../../model/ckpt/chatglm-6b-lora-single-INSq1-{:.0e}-{}'.format(cfg.lr, cfg.epochs)
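# e.g. with lr=1e-3 and epochs=50 this resolves to
# '../../../model/ckpt/chatglm-6b-lora-single-INSq1-1e-03-50'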

## --------- end config -------------

## --------- load model -------------
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq

config = AutoConfig.from_pretrained(cfg.model_name_or_path, trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(
    cfg.model_name_or_path, trust_remote_code=True)

model = AutoModel.from_pretrained(cfg.model_name_or_path, config=config,
                                  trust_remote_code=True).half()  # cast weights to fp16

# quantize first to shrink memory use
if cfg.quantization_bit is not None:
    print(f"Quantized to {cfg.quantization_bit} bit")
    model = model.quantize(cfg.quantization_bit)
    
# then move the model to the GPU
model = model.cuda()
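# Optional sanity check (illustrative): the weights should now be fp16 tensors on the GPU.
# print(next(model.parameters()).dtype, next(model.parameters()).device)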
## --------- end load model -------------


## --------- load data -------------
def preprocess(examples):
    max_seq_length = cfg.max_source_length + cfg.max_target_length
    model_inputs = {
        "input_ids": [],
        "labels": [],
    }
    for i in range(len(examples[cfg.prompt_column])):
        if examples[cfg.prompt_column][i] and examples[cfg.response_column][i]:
            query, answer = examples[cfg.prompt_column][i], examples[cfg.response_column][i]

            history = examples[cfg.history_column][i] if cfg.history_column is not None else None
            # prompt = tokenizer.build_prompt(query, history)
            prompt = query

            prompt = cfg.source_prefix + prompt
            a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
                                     max_length=cfg.max_source_length)
            b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
                                     max_length=cfg.max_target_length)

            context_length = len(a_ids)
            input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
            labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]

            pad_len = max_seq_length - len(input_ids)
            input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
            labels = labels + [tokenizer.pad_token_id] * pad_len
            labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)
    return model_inputs
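# Layout produced per example (the prompt span is loss-masked; only the response and EOS are learned):
#   input_ids: [prompt ids][response ids][eos][pad ... pad]   padded to max_source_length + max_target_length
#   labels:    [ -100 ... ][response ids][eos][-100 ... -100]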

import datasets
from datasets import load_dataset

# data_files = {"train": cfg.train_file, "val": cfg.val_file}
# data_sets = load_dataset(cfg.train_file.split(".")[-1], data_files=data_files)



import pandas as pd

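# Instead of loading the CSV files configured above, build a tiny in-memory toy dataset:
# one business-rule description (kept in Chinese, matching the fine-tuning domain) paired
# with several phrasings of the same question about the keyword.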
description = """1.只有产品代码输入正确时才能进入购买页面;
2.取当前操作员可操作账号,如果账号为监管账户,则不允许操作。
3.产品说明书、产品合同,必须要客户客户阅读后才能进行提交;
4.交易全部落地流程引擎处理;
5.会根据活期账户对应的核心客户号查询该客户是否设置KYC笔笔落地的标识。设置KYC笔笔落地会发送对应的KYC落地邮件信息。(RM, CAS@spd-svbank.com,asu@spd-svbank.com)
"""

keyword = "结构性存款购买的业务规则"

def get_prompt_list(keyword):
    return [f'{keyword}', 
            f'{keyword}是什么?',
            f'介绍一下{keyword}',
            f'你听过{keyword}吗?',
            f'啥是{keyword}?',
            f'{keyword}是什么样的?',
            f'{keyword}有什么要求?',
           ]

data =[{'prompt':x,'response':description} for x in get_prompt_list(keyword) ]
dfdata = pd.DataFrame(data)

for index, row in dfdata.iterrows():
    print(f"[{index}] {row['prompt']}: {row['response'][:10]}...")

ds_train_raw = ds_val_raw = datasets.Dataset.from_pandas(dfdata)  # train and val share the same toy dataset

data_sets = {
    "train": ds_train_raw,
    "val": ds_val_raw
}

ds_train = ds_val = None
if cfg.train_file is not None:
    ds_train = data_sets["train"].map(preprocess, batched=True, remove_columns=data_sets["train"].column_names)
    print(data_sets["train"])
if cfg.val_file is not None:
    ds_val = data_sets["val"].map(preprocess, batched=True, remove_columns=data_sets["val"].column_names)
    print(data_sets["val"])


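# preprocess() already pads every example to the fixed length max_source_length + max_target_length,
# so the collator mainly stacks the features into tensors; padding=False avoids re-padding and
# label_pad_token_id=-100 matches the masking applied above.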
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=None,
    label_pad_token_id=-100,
    pad_to_multiple_of=None,
    padding=False
)

from torch.utils.data import DataLoader
dl_train = DataLoader(ds_train, batch_size=cfg.batch_size,
                      num_workers=2, shuffle=True, collate_fn=data_collator)
dl_val = DataLoader(ds_val, batch_size=cfg.batch_size,
                    num_workers=2, shuffle=False, collate_fn=data_collator)
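# Optional sanity check (illustrative): both input_ids and labels in a batch should have
# shape (batch_size, max_source_length + max_target_length).
# batch = next(iter(dl_train))
# print({k: v.shape for k, v in batch.items()})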
## --------- end load data -------------



## --------- train -------------
from peft import get_peft_model, AdaLoraConfig, TaskType

# save GPU memory during training
model.config.use_cache = False
model.supports_gradient_checkpointing = True
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

peft_config = AdaLoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False,
    r=8,
    lora_alpha=32, lora_dropout=0.1,
    target_modules=["query", "value"]
)
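# Note (assumption about peft's matching rule): with a list of target_modules, peft matches
# module names by suffix, so "value" should also catch ChatGLM's fused attention projection
# "query_key_value"; verify with print_trainable_parameters() below.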

peft_model = get_peft_model(model, peft_config)

peft_model.is_parallelizable = True
peft_model.model_parallel = True
peft_model.print_trainable_parameters()


from keras_model import KerasModel
import torch

optimizer = torch.optim.AdamW(peft_model.parameters(), lr=cfg.lr)
keras_model = KerasModel(peft_model, loss_fn=None, optimizer=optimizer)
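# loss_fn=None assumes the local KerasModel step runner takes the loss directly from the
# model output (ChatGLM returns a loss when `labels` are supplied) rather than applying an
# external criterion; this is an assumption about keras_model.py, not verified here.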

df = keras_model.fit(train_data=dl_train,
                     val_data=dl_val,
                     epochs=cfg.epochs,
                     patience=20,
                     monitor='val_loss',
                     mode='min',
                     ckpt_path=cfg.ckpt_path,
                     mixed_precision='fp16',
                     gradient_accumulation_steps=cfg.gradient_accumulation_steps)
df.to_json(f"{cfg.ckpt_path}/train_history.json")