├── .DS_Store ├── resources ├── math.png ├── cli-demo.png ├── web-demo.gif ├── web-demo.png ├── wechat.jpg ├── knowledge.png ├── long-context.png └── WECHAT.md ├── requirements.txt ├── scripts ├── web_demo.sh ├── train.sh ├── train_chat.sh ├── train_ptuning.sh ├── train_lora.sh ├── ds_train_ptuning.sh └── ds_train_peft.sh ├── chatglm_model_v1 ├── web_demo.sh ├── deepspeed.json ├── evaluate_finetune.sh ├── evaluate.sh ├── train.sh ├── train_chat.sh ├── ds_train_finetune.sh ├── web_demo.py ├── README.md ├── README_en.md ├── arguments.py ├── trainer_seq2seq.py └── run_ptuning.py ├── evaluation ├── README.md └── evaluate_ceval.py ├── deepspeed └── deepspeed.json ├── chatglm_model_v2 ├── evaluate_finetune.sh ├── evaluate.sh ├── trainer.py ├── README.md ├── web_demo.py ├── arguments.py ├── trainer_seq2seq.py ├── run_ptuning.py └── run_peft.py ├── FAQ.md ├── application ├── cli_demo.py ├── api.py ├── web_demo2.py ├── web_demo.py └── openai_api.py ├── utils.py ├── MODEL_LICENSE ├── README.md └── LICENSE /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/.DS_Store -------------------------------------------------------------------------------- /resources/math.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/math.png -------------------------------------------------------------------------------- /resources/cli-demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/cli-demo.png -------------------------------------------------------------------------------- /resources/web-demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/web-demo.gif 
-------------------------------------------------------------------------------- /resources/web-demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/web-demo.png -------------------------------------------------------------------------------- /resources/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/wechat.jpg -------------------------------------------------------------------------------- /resources/knowledge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/knowledge.png -------------------------------------------------------------------------------- /resources/long-context.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/long-context.png -------------------------------------------------------------------------------- /resources/WECHAT.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 |

扫码关注公众号,加入「ChatGLM交流群」

5 |

Scan the QR code to follow the official account and join the "ChatGLM Discussion Group"

6 |
7 | 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf 2 | transformers==4.30.2 3 | cpm_kernels 4 | torch>=2.0 5 | gradio 6 | mdtex2html 7 | sentencepiece 8 | accelerate 9 | sse-starlette 10 | streamlit>=1.24.0 11 | rouge_chinese 12 | nltk 13 | jieba 14 | datasets 15 | peft -------------------------------------------------------------------------------- /scripts/web_demo.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | 3 | CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \ 4 | --model_name_or_path THUDM/chatglm2-6b \ 5 | --ptuning_checkpoint output/adgen-chatglm2-6b-pt-128-2e-2/checkpoint-3000 \ 6 | --pre_seq_len $PRE_SEQ_LEN 7 | 8 | -------------------------------------------------------------------------------- /chatglm_model_v1/web_demo.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | 3 | CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \ 4 | --model_name_or_path THUDM/chatglm-6b \ 5 | --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \ 6 | --pre_seq_len $PRE_SEQ_LEN 7 | 8 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | First download the preprocessed C-Eval dataset from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/e84444333b6d434ea7b0) and extract it into the `evaluation` directory. Then run 2 | 3 | ```shell 4 | cd evaluation 5 | python evaluate_ceval.py 6 | ``` 7 | 8 | This script runs prediction on the C-Eval validation set and prints the accuracy. To get results on the test set, change `./CEval/val/**/*.jsonl` in the code to `./CEval/test/**/*.jsonl`, save the results in the format required by C-Eval, and submit them on the [official website](https://cevalbenchmark.com/). 9 | 10 | The reported results were produced with an internal parallel evaluation framework, so results may fluctuate slightly. -------------------------------------------------------------------------------- /deepspeed/deepspeed.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "zero_allow_untested_optimizer": true, 4 | "fp16": { 5 | "enabled": "auto", 6 | "loss_scale": 0, 7 | "initial_scale_power": 16, 8 | "loss_scale_window": 1000, 9 | "hysteresis": 2, 10 | "min_loss_scale": 1 11 | }, 12 | "zero_optimization": { 13 | "stage": 2, 14 | "allgather_partitions": true, 15 | "allgather_bucket_size": 5e8, 16 | "overlap_comm": false, 17 | "reduce_scatter": true, 18 | "reduce_bucket_size": 5e8, 19 | "contiguous_gradients" : true 20 | } 21 | } -------------------------------------------------------------------------------- /chatglm_model_v1/deepspeed.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "zero_allow_untested_optimizer": true, 4 | "fp16": { 5 | "enabled": "auto", 6 | "loss_scale": 0, 7 | "initial_scale_power": 16, 8 | "loss_scale_window": 1000, 9 | "hysteresis": 2, 10 | "min_loss_scale": 1 11 | }, 12 | "zero_optimization": { 13 | "stage": 2, 14 | "allgather_partitions": true, 15 | "allgather_bucket_size": 5e8, 16 | "overlap_comm": false, 17 | "reduce_scatter": true, 18 | "reduce_bucket_size": 5e8, 19 | "contiguous_gradients" : true 20 | } 21 | } -------------------------------------------------------------------------------- /chatglm_model_v1/evaluate_finetune.sh: -------------------------------------------------------------------------------- 1 | CHECKPOINT=adgen-chatglm-6b-ft-1e-4 2 | STEP=3000 3 | 4 | CUDA_VISIBLE_DEVICES=0 python3 main.py \ 5 | --do_predict \ 6 | --validation_file AdvertiseGen/dev.json \ 7 | --test_file AdvertiseGen/dev.json \ 8 | --overwrite_cache \ 9 | --prompt_column content \ 10 | --response_column summary \ 11 | --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \ 12 | --output_dir ./output/$CHECKPOINT \ 13 | --overwrite_output_dir \ 14 | --max_source_length 256 \ 15 | 
--max_target_length 256 \ 16 | --per_device_eval_batch_size 1 \ 17 | --predict_with_generate \ 18 | --fp16_full_eval 19 | -------------------------------------------------------------------------------- /chatglm_model_v2/evaluate_finetune.sh: -------------------------------------------------------------------------------- 1 | CHECKPOINT=adgen-chatglm2-6b-ft-1e-4 2 | STEP=3000 3 | NUM_GPUS=1 4 | 5 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \ 6 | --do_predict \ 7 | --validation_file AdvertiseGen/dev.json \ 8 | --test_file AdvertiseGen/dev.json \ 9 | --overwrite_cache \ 10 | --prompt_column content \ 11 | --response_column summary \ 12 | --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \ 13 | --output_dir ./output/$CHECKPOINT \ 14 | --overwrite_output_dir \ 15 | --max_source_length 256 \ 16 | --max_target_length 256 \ 17 | --per_device_eval_batch_size 1 \ 18 | --predict_with_generate \ 19 | --fp16_full_eval 20 | -------------------------------------------------------------------------------- /FAQ.md: -------------------------------------------------------------------------------- 1 | ## Q1 2 | 3 | **Loading the quantized model directly on a Mac fails with `clang: error: unsupported option '-fopenmp'`** 4 | 5 | This happens because macOS does not ship with OpenMP. The model still runs, but only on a single core. Install the OpenMP dependency separately to use OMP on a Mac: 6 | 7 | ```bash 8 | # See `https://mac.r-project.org/openmp/` 9 | ## Assumption: gcc (clang) is version 14.x; for other versions, see the table provided by R-Project 10 | curl -O https://mac.r-project.org/openmp/openmp-14.0.6-darwin20-Release.tar.gz 11 | sudo tar fvxz openmp-14.0.6-darwin20-Release.tar.gz -C / 12 | ``` 13 | This installs the following files: `/usr/local/lib/libomp.dylib`, `/usr/local/include/ompt.h`, `/usr/local/include/omp.h`, `/usr/local/include/omp-tools.h`. 14 | 15 | > Note: if a previous run of the `ChatGLM2-6B` project failed, it is best to clear the Hugging Face cache; by default that is `rm -rf ${HOME}/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4`. Because this uses the `rm` command, make sure you know exactly what you are deleting. 16 | 
-------------------------------------------------------------------------------- /chatglm_model_v1/evaluate.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2 3 | STEP=3000 4 | 5 | CUDA_VISIBLE_DEVICES=0 python3 main.py \ 6 | --do_predict \ 7 | --validation_file AdvertiseGen/dev.json \ 8 | --test_file AdvertiseGen/dev.json \ 9 | --overwrite_cache \ 10 | --prompt_column content \ 11 | --response_column summary \ 12 | --model_name_or_path THUDM/chatglm-6b \ 13 | --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \ 14 | --output_dir ./output/$CHECKPOINT \ 15 | --overwrite_output_dir \ 16 | --max_source_length 64 \ 17 | --max_target_length 64 \ 18 | --per_device_eval_batch_size 1 \ 19 | --predict_with_generate \ 20 | --pre_seq_len $PRE_SEQ_LEN \ 21 | --quantization_bit 4 22 | -------------------------------------------------------------------------------- /chatglm_model_v2/evaluate.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | CHECKPOINT=adgen-chatglm2-6b-pt-128-2e-2 3 | STEP=3000 4 | NUM_GPUS=1 5 | 6 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \ 7 | --do_predict \ 8 | --validation_file AdvertiseGen/dev.json \ 9 | --test_file AdvertiseGen/dev.json \ 10 | --overwrite_cache \ 11 | --prompt_column content \ 12 | --response_column summary \ 13 | --model_name_or_path THUDM/chatglm2-6b \ 14 | --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \ 15 | --output_dir ./output/$CHECKPOINT \ 16 | --overwrite_output_dir \ 17 | --max_source_length 64 \ 18 | --max_target_length 64 \ 19 | --per_device_eval_batch_size 1 \ 20 | --predict_with_generate \ 21 | --pre_seq_len $PRE_SEQ_LEN \ 22 | --quantization_bit 4 23 | 
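The evaluation scripts above locate a P-tuning checkpoint by joining `CHECKPOINT` and `STEP` into `./output/$CHECKPOINT/checkpoint-$STEP`. The sketch below reproduces that naming in Python. It assumes the run name follows the `adgen-<model>-pt-<PRE_SEQ_LEN>-<LR>` pattern used by the training scripts in this repository, and that checkpoints are written every `save_steps` steps up to `max_steps`. The helper names are hypothetical and do not exist in the repository.

```python
def run_name(model: str, pre_seq_len: int, lr: str) -> str:
    """Compose the run directory name (hypothetical helper mirroring the shell scripts)."""
    return f"adgen-{model}-pt-{pre_seq_len}-{lr}"


def checkpoint_dir(checkpoint: str, step: int, output_root: str = "./output") -> str:
    """Mirror the shell expression ./output/$CHECKPOINT/checkpoint-$STEP."""
    return f"{output_root}/{checkpoint}/checkpoint-{step}"


def saved_steps(max_steps: int, save_steps: int) -> list:
    """Steps at which a checkpoint is expected, assuming one save every save_steps steps."""
    return list(range(save_steps, max_steps + 1, save_steps))


name = run_name("chatglm2-6b", 128, "2e-2")
print(checkpoint_dir(name, 3000))  # ./output/adgen-chatglm2-6b-pt-128-2e-2/checkpoint-3000
print(saved_steps(3000, 1000))     # [1000, 2000, 3000]
```

With `max_steps 3000` and `save_steps 1000`, `STEP=3000` therefore points at the last checkpoint of the run; `STEP=1000` or `STEP=2000` would select an earlier one.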
-------------------------------------------------------------------------------- /chatglm_model_v1/train.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | LR=2e-2 3 | 4 | CUDA_VISIBLE_DEVICES=0 python3 main.py \ 5 | --do_train \ 6 | --train_file AdvertiseGen/train.json \ 7 | --validation_file AdvertiseGen/dev.json \ 8 | --prompt_column content \ 9 | --response_column summary \ 10 | --overwrite_cache \ 11 | --model_name_or_path THUDM/chatglm-6b \ 12 | --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ 13 | --overwrite_output_dir \ 14 | --max_source_length 64 \ 15 | --max_target_length 64 \ 16 | --per_device_train_batch_size 1 \ 17 | --per_device_eval_batch_size 1 \ 18 | --gradient_accumulation_steps 16 \ 19 | --predict_with_generate \ 20 | --max_steps 3000 \ 21 | --logging_steps 10 \ 22 | --save_steps 1000 \ 23 | --learning_rate $LR \ 24 | --pre_seq_len $PRE_SEQ_LEN \ 25 | --quantization_bit 4 26 | 27 | -------------------------------------------------------------------------------- /chatglm_model_v1/train_chat.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | LR=1e-2 3 | 4 | CUDA_VISIBLE_DEVICES=0 python3 main.py \ 5 | --do_train \ 6 | --train_file $CHAT_TRAIN_DATA \ 7 | --validation_file $CHAT_VAL_DATA \ 8 | --prompt_column prompt \ 9 | --response_column response \ 10 | --history_column history \ 11 | --overwrite_cache \ 12 | --model_name_or_path THUDM/chatglm-6b \ 13 | --output_dir $CHECKPOINT_NAME \ 14 | --overwrite_output_dir \ 15 | --max_source_length 256 \ 16 | --max_target_length 256 \ 17 | --per_device_train_batch_size 1 \ 18 | --per_device_eval_batch_size 1 \ 19 | --gradient_accumulation_steps 16 \ 20 | --predict_with_generate \ 21 | --max_steps 3000 \ 22 | --logging_steps 10 \ 23 | --save_steps 1000 \ 24 | --learning_rate $LR \ 25 | --pre_seq_len $PRE_SEQ_LEN \ 26 | --quantization_bit 4 27 | 28 | 
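The two training scripts above pair `--per_device_train_batch_size 1` with `--gradient_accumulation_steps 16` on a single GPU. As a rough sketch of what that implies, assuming the usual data-parallel convention that the number of samples per optimizer step is the product of per-device batch size, accumulation steps, and GPU count (the function names are illustrative only):

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_gpus: int = 1) -> int:
    """Samples contributing to one optimizer update under plain data parallelism."""
    return per_device_batch * grad_accum_steps * num_gpus


def samples_seen(max_steps: int, effective_batch: int) -> int:
    """Total training samples consumed after max_steps optimizer updates."""
    return max_steps * effective_batch


single_gpu = effective_batch_size(per_device_batch=1, grad_accum_steps=16, num_gpus=1)
print(single_gpu)                      # 16
print(samples_seen(3000, single_gpu))  # 48000

# For comparison: per-device batch 32, no accumulation, 8 GPUs,
# the values used by scripts/train_ptuning.sh in this repository.
print(effective_batch_size(32, 1, 8))  # 256
```

Gradient accumulation trades wall-clock time for memory: the single-GPU setting reaches a batch of 16 while holding only one sample's activations at a time.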
-------------------------------------------------------------------------------- /chatglm_model_v1/ds_train_finetune.sh: -------------------------------------------------------------------------------- 1 | 2 | LR=1e-4 3 | 4 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 5 | 6 | deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \ 7 | --deepspeed deepspeed.json \ 8 | --do_train \ 9 | --train_file AdvertiseGen/train.json \ 10 | --test_file AdvertiseGen/dev.json \ 11 | --prompt_column content \ 12 | --response_column summary \ 13 | --overwrite_cache \ 14 | --model_name_or_path THUDM/chatglm-6b \ 15 | --output_dir ./output/adgen-chatglm-6b-ft-$LR \ 16 | --overwrite_output_dir \ 17 | --max_source_length 64 \ 18 | --max_target_length 64 \ 19 | --per_device_train_batch_size 4 \ 20 | --per_device_eval_batch_size 1 \ 21 | --gradient_accumulation_steps 1 \ 22 | --predict_with_generate \ 23 | --max_steps 5000 \ 24 | --logging_steps 10 \ 25 | --save_steps 1000 \ 26 | --learning_rate $LR \ 27 | --fp16 28 | 29 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | LR=2e-2 3 | NUM_GPUS=1 4 | 5 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \ 6 | --do_train \ 7 | --train_file AdvertiseGen/train.json \ 8 | --validation_file AdvertiseGen/dev.json \ 9 | --preprocessing_num_workers 10 \ 10 | --prompt_column content \ 11 | --response_column summary \ 12 | --overwrite_cache \ 13 | --model_name_or_path THUDM/chatglm2-6b \ 14 | --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \ 15 | --overwrite_output_dir \ 16 | --max_source_length 64 \ 17 | --max_target_length 128 \ 18 | --per_device_train_batch_size 1 \ 19 | --per_device_eval_batch_size 1 \ 20 | --gradient_accumulation_steps 16 \ 21 | --predict_with_generate \ 22 | --max_steps 3000 \ 23 | --logging_steps 10 \ 24 | --save_steps 1000 \ 25 | --learning_rate 
$LR \ 26 | --pre_seq_len $PRE_SEQ_LEN \ 27 | --quantization_bit 4 28 | 29 | -------------------------------------------------------------------------------- /scripts/train_chat.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | LR=1e-2 3 | NUM_GPUS=1 4 | 5 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \ 6 | --do_train \ 7 | --train_file $CHAT_TRAIN_DATA \ 8 | --validation_file $CHAT_VAL_DATA \ 9 | --preprocessing_num_workers 10 \ 10 | --prompt_column prompt \ 11 | --response_column response \ 12 | --history_column history \ 13 | --overwrite_cache \ 14 | --model_name_or_path THUDM/chatglm2-6b \ 15 | --output_dir $CHECKPOINT_NAME \ 16 | --overwrite_output_dir \ 17 | --max_source_length 256 \ 18 | --max_target_length 256 \ 19 | --per_device_train_batch_size 1 \ 20 | --per_device_eval_batch_size 1 \ 21 | --gradient_accumulation_steps 16 \ 22 | --predict_with_generate \ 23 | --max_steps 3000 \ 24 | --logging_steps 10 \ 25 | --save_steps 1000 \ 26 | --learning_rate $LR \ 27 | --pre_seq_len $PRE_SEQ_LEN \ 28 | --quantization_bit 4 29 | 30 | -------------------------------------------------------------------------------- /scripts/train_ptuning.sh: -------------------------------------------------------------------------------- 1 | TASK_NAME=default_task 2 | PRE_SEQ_LEN=128 3 | LR=2e-2 4 | 5 | CHAT_TRAIN_DATA=./data/train.json 6 | CHAT_VAL_DATA=./data/dev.json 7 | 8 | MODEL_NAME_OR_PATH=pre-trained-lm/chatglm-6b 9 | 10 | NUM_GPUS=8 11 | 12 | MODEL_VERSION=v1 # v1 or v2 13 | 14 | 15 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 16 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS chatglm_model_$MODEL_VERSION/run_ptuning.py \ 17 | --do_train \ 18 | --train_file $CHAT_TRAIN_DATA \ 19 | --validation_file $CHAT_VAL_DATA \ 20 | --prompt_column input \ 21 | --response_column output \ 22 | --model_name_or_path $MODEL_NAME_OR_PATH \ 23 | --output_dir 
output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ 24 | --overwrite_output_dir \ 25 | --max_source_length 256 \ 26 | --max_target_length 256 \ 27 | --per_device_train_batch_size 32 \ 28 | --per_device_eval_batch_size 32 \ 29 | --gradient_accumulation_steps 1 \ 30 | --predict_with_generate \ 31 | --max_steps 9000 \ 32 | --logging_steps 10 \ 33 | --save_steps 1000 \ 34 | --learning_rate $LR \ 35 | --pre_seq_len $PRE_SEQ_LEN \ 36 | --task_name $TASK_NAME \ 37 | --base_cache_dir ./.cache/ 38 | # --quantization_bit 4 -------------------------------------------------------------------------------- /scripts/train_lora.sh: -------------------------------------------------------------------------------- 1 | TASK_NAME=default_task 2 | PEFT_TYPE=lora 3 | LORA_DIM=8 4 | LR=2e-2 5 | NUM_GPUS=1 6 | TRAIN_DATA=./data/train.json 7 | EVAL_DATA=./data/dev.json 8 | 9 | MODEL_VERSION=v2 # matches THUDM/chatglm2-6b; run_peft.py is listed only under chatglm_model_v2 10 | 11 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS chatglm_model_$MODEL_VERSION/run_peft.py \ 12 | --do_train \ 13 | --train_file $TRAIN_DATA \ 14 | --validation_file $EVAL_DATA \ 15 | --preprocessing_num_workers 10 \ 16 | --prompt_column content \ 17 | --response_column summary \ 18 | --model_name_or_path THUDM/chatglm2-6b \ 19 | --output_dir output/adgen-chatglm2-6b-$PEFT_TYPE-$LORA_DIM-$LR \ 20 | --overwrite_output_dir \ 21 | --max_source_length 256 \ 22 | --max_target_length 256 \ 23 | --per_device_train_batch_size 32 \ 24 | --per_device_eval_batch_size 32 \ 25 | --gradient_accumulation_steps 1 \ 26 | --predict_with_generate \ 27 | --max_steps 9000 \ 28 | --logging_steps 10 \ 29 | --save_steps 1000 \ 30 | --learning_rate $LR \ 31 | --peft_type $PEFT_TYPE \ 32 | --lora_dim $LORA_DIM \ 33 | --task_name $TASK_NAME \ 34 | --base_cache_dir ./cache 35 | # --quantization_bit 4 36 | # --overwrite_cache \ 37 | 38 | -------------------------------------------------------------------------------- /scripts/ds_train_ptuning.sh: -------------------------------------------------------------------------------- 1 | 
TASK_NAME=default_task 2 | PRE_SEQ_LEN=128 3 | LR=1e-4 4 | 5 | CHAT_TRAIN_DATA=data/train.json 6 | CHAT_VAL_DATA=data/dev.json 7 | 8 | MODEL_NAME_OR_PATH=pre-trained-lm/chatglm-6b 9 | 10 | NUM_GPUS=8 11 | 12 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 13 | 14 | MODEL_VERSION=v1 # v1 or v2 15 | 16 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 17 | deepspeed --num_gpus=$NUM_GPUS --master_port $MASTER_PORT chatglm_model_$MODEL_VERSION/run_ptuning.py \ 18 | --deepspeed deepspeed/deepspeed.json \ 19 | --do_train \ 20 | --train_file $CHAT_TRAIN_DATA \ 21 | --test_file $CHAT_VAL_DATA \ 22 | --prompt_column input \ 23 | --response_column output \ 24 | --model_name_or_path $MODEL_NAME_OR_PATH \ 25 | --output_dir ./output/deepspeed/adgen-chatglm-6b-ft-$LR \ 26 | --overwrite_output_dir \ 27 | --max_source_length 256 \ 28 | --max_target_length 256 \ 29 | --per_device_train_batch_size 32 \ 30 | --per_device_eval_batch_size 32 \ 31 | --gradient_accumulation_steps 1 \ 32 | --predict_with_generate \ 33 | --max_steps 9000 \ 34 | --logging_steps 10 \ 35 | --save_steps 1000 \ 36 | --learning_rate $LR \ 37 | --task_name $TASK_NAME \ 38 | --base_cache_dir ./.cache \ 39 | --fp16 40 | # --overwrite_cache \ -------------------------------------------------------------------------------- /scripts/ds_train_peft.sh: -------------------------------------------------------------------------------- 1 | TASK_NAME=default_task 2 | # PRE_SEQ_LEN=128 3 | PEFT_TYPE=lora 4 | LORA_DIM=8 5 | LR=1e-4 6 | 7 | CHAT_TRAIN_DATA=./data/train.json 8 | CHAT_VAL_DATA=./data/dev.json 9 | 10 | MODEL_NAME_OR_PATH=./pre-trained-lm/chatglm-6b 11 | 12 | NUM_GPUS=8 13 | 14 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 15 | 16 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 17 | deepspeed --num_gpus=$NUM_GPUS --master_port $MASTER_PORT chatglm_model_v1/run_peft.py \ 18 | --deepspeed deepspeed/deepspeed.json \ 19 | --do_train \ 20 | --train_file $CHAT_TRAIN_DATA \ 21 | --test_file $CHAT_VAL_DATA \ 22 | --prompt_column input \ 
23 | --response_column output \ 24 | --model_name_or_path $MODEL_NAME_OR_PATH \ 25 | --output_dir ./output/deepspeed/chatglm-6b-$TASK_NAME-$PEFT_TYPE-$LORA_DIM-$LR \ 26 | --overwrite_output_dir \ 27 | --max_source_length 256 \ 28 | --max_target_length 1024 \ 29 | --per_device_train_batch_size 32 \ 30 | --per_device_eval_batch_size 32 \ 31 | --gradient_accumulation_steps 1 \ 32 | --predict_with_generate \ 33 | --max_steps 9000 \ 34 | --logging_steps 10 \ 35 | --save_steps 1000 \ 36 | --learning_rate $LR \ 37 | --peft_type $PEFT_TYPE \ 38 | --lora_dim $LORA_DIM \ 39 | --task_name $TASK_NAME \ 40 | --base_cache_dir ./.cache/ \ 41 | --fp16 42 | # --overwrite_cache \ -------------------------------------------------------------------------------- /application/cli_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import signal 4 | from transformers import AutoTokenizer, AutoModel 5 | import readline 6 | 7 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) 8 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda() 9 | # Multi-GPU support: use the following two lines in place of the line above, and set num_gpus to your actual number of GPUs 10 | # from utils import load_model_on_gpus 11 | # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2) 12 | model = model.eval() 13 | 14 | os_name = platform.system() 15 | clear_command = 'cls' if os_name == 'Windows' else 'clear' 16 | stop_stream = False 17 | 18 | 19 | def build_prompt(history): 20 | prompt = "欢迎使用 ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序" 21 | for query, response in history: 22 | prompt += f"\n\n用户:{query}" 23 | prompt += f"\n\nChatGLM2-6B:{response}" 24 | return prompt 25 | 26 | 27 | def signal_handler(signal, frame): 28 | global stop_stream 29 | stop_stream = True 30 | 31 | 32 | def main(): 33 | past_key_values, history = None, [] 34 | global stop_stream 35 | print("欢迎使用 ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") 36 | while True: 
37 | query = input("\n用户:") 38 | if query.strip() == "stop": 39 | break 40 | if query.strip() == "clear": 41 | past_key_values, history = None, [] 42 | os.system(clear_command) 43 | print("欢迎使用 ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") 44 | continue 45 | print("\nChatGLM:", end="") 46 | current_length = 0 47 | for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, 48 | past_key_values=past_key_values, 49 | return_past_key_values=True): 50 | if stop_stream: 51 | stop_stream = False 52 | break 53 | else: 54 | print(response[current_length:], end="", flush=True) 55 | current_length = len(response) 56 | print("") 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Tuple, Union, Optional 3 | 4 | from torch.nn import Module 5 | from transformers import AutoModel 6 | 7 | 8 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: 9 | # transformer.word_embeddings counts as 1 layer 10 | # transformer.final_layernorm and lm_head together count as 1 layer 11 | # transformer.layers counts as 28 layers 12 | # 30 layers in total are distributed across num_gpus GPUs 13 | num_trans_layers = 28 14 | per_gpu_layers = 30 / num_gpus 15 | 16 | # bugfix: on Linux, the weight and input passed to torch.embedding end up on different devices, causing a RuntimeError 17 | # on Windows, model.device is set to transformer.word_embeddings.device 18 | # on Linux, model.device is set to lm_head.device 19 | # when chat or stream_chat is called, input_ids is moved to model.device 20 | # if transformer.word_embeddings.device differs from model.device, a RuntimeError is raised 21 | # so transformer.word_embeddings, transformer.final_layernorm, and lm_head are all placed on the first GPU 22 | # this file comes from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py 23 | # only this part is slightly modified to support ChatGLM2 24 | device_map = { 25 | 'transformer.embedding.word_embeddings': 0, 26 | 'transformer.encoder.final_layernorm': 0, 27 | 'transformer.output_layer': 0, 28 | 
'transformer.rotary_pos_emb': 0, 29 | 'lm_head': 0 30 | } 31 | 32 | used = 2 33 | gpu_target = 0 34 | for i in range(num_trans_layers): 35 | if used >= per_gpu_layers: 36 | gpu_target += 1 37 | used = 0 38 | assert gpu_target < num_gpus 39 | device_map[f'transformer.encoder.layers.{i}'] = gpu_target 40 | used += 1 41 | 42 | return device_map 43 | 44 | 45 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2, 46 | device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module: 47 | if num_gpus < 2 and device_map is None: 48 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda() 49 | else: 50 | from accelerate import dispatch_model 51 | 52 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half() 53 | 54 | if device_map is None: 55 | device_map = auto_configure_device_map(num_gpus) 56 | 57 | model = dispatch_model(model, device_map=device_map) 58 | 59 | return model 60 | -------------------------------------------------------------------------------- /application/api.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request 2 | from transformers import AutoTokenizer, AutoModel 3 | import uvicorn, json, datetime 4 | import torch 5 | 6 | DEVICE = "cuda" 7 | DEVICE_ID = "0" 8 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE 9 | 10 | 11 | def torch_gc(): 12 | if torch.cuda.is_available(): 13 | with torch.cuda.device(CUDA_DEVICE): 14 | torch.cuda.empty_cache() 15 | torch.cuda.ipc_collect() 16 | 17 | 18 | app = FastAPI() 19 | 20 | 21 | @app.post("/") 22 | async def create_item(request: Request): 23 | global model, tokenizer 24 | json_post_raw = await request.json() 25 | json_post = json.dumps(json_post_raw) 26 | json_post_list = json.loads(json_post) 27 | prompt = json_post_list.get('prompt') 28 | history = json_post_list.get('history') 29 | max_length = 
json_post_list.get('max_length') 30 | top_p = json_post_list.get('top_p') 31 | temperature = json_post_list.get('temperature') 32 | response, history = model.chat(tokenizer, 33 | prompt, 34 | history=history, 35 | max_length=max_length if max_length else 2048, 36 | top_p=top_p if top_p else 0.7, 37 | temperature=temperature if temperature else 0.95) 38 | now = datetime.datetime.now() 39 | time = now.strftime("%Y-%m-%d %H:%M:%S") 40 | answer = { 41 | "response": response, 42 | "history": history, 43 | "status": 200, 44 | "time": time 45 | } 46 | log = "[" + time + "] " + 'prompt:"' + prompt + '", response:"' + repr(response) + '"' 47 | print(log) 48 | torch_gc() 49 | return answer 50 | 51 | 52 | if __name__ == '__main__': 53 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) 54 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda() 55 | # Multi-GPU support: use the following three lines in place of the two lines above, and set num_gpus to your actual number of GPUs (load_model_on_gpus is defined in utils.py and must be imported first) 56 | # model_path = "THUDM/chatglm2-6b" 57 | # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 58 | # model = load_model_on_gpus(model_path, num_gpus=2) 59 | model.eval() 60 | uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) 61 | -------------------------------------------------------------------------------- /application/web_demo2.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoTokenizer 2 | import streamlit as st 3 | 4 | 5 | st.set_page_config( 6 | page_title="ChatGLM2-6b 演示", 7 | page_icon=":robot:", 8 | layout='wide' 9 | ) 10 | 11 | 12 | @st.cache_resource 13 | def get_model(): 14 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) 15 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda() 16 | # Multi-GPU support: use the following two lines in place of the line above, and set num_gpus to your actual number of GPUs 17 | # from utils import load_model_on_gpus 18 | # model = load_model_on_gpus("THUDM/chatglm2-6b", 
num_gpus=2) 19 | model = model.eval() 20 | return tokenizer, model 21 | 22 | 23 | tokenizer, model = get_model() 24 | 25 | st.title("ChatGLM2-6B") 26 | 27 | max_length = st.sidebar.slider( 28 | 'max_length', 0, 32768, 8192, step=1 29 | ) 30 | top_p = st.sidebar.slider( 31 | 'top_p', 0.0, 1.0, 0.8, step=0.01 32 | ) 33 | temperature = st.sidebar.slider( 34 | 'temperature', 0.0, 1.0, 0.8, step=0.01 35 | ) 36 | 37 | if 'history' not in st.session_state: 38 | st.session_state.history = [] 39 | 40 | if 'past_key_values' not in st.session_state: 41 | st.session_state.past_key_values = None 42 | 43 | for i, (query, response) in enumerate(st.session_state.history): 44 | with st.chat_message(name="user", avatar="user"): 45 | st.markdown(query) 46 | with st.chat_message(name="assistant", avatar="assistant"): 47 | st.markdown(response) 48 | with st.chat_message(name="user", avatar="user"): 49 | input_placeholder = st.empty() 50 | with st.chat_message(name="assistant", avatar="assistant"): 51 | message_placeholder = st.empty() 52 | 53 | prompt_text = st.text_area(label="用户命令输入", 54 | height=100, 55 | placeholder="请在这儿输入您的命令") 56 | 57 | button = st.button("发送", key="predict") 58 | 59 | if button: 60 | input_placeholder.markdown(prompt_text) 61 | history, past_key_values = st.session_state.history, st.session_state.past_key_values 62 | for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history, 63 | past_key_values=past_key_values, 64 | max_length=max_length, top_p=top_p, 65 | temperature=temperature, 66 | return_past_key_values=True): 67 | message_placeholder.markdown(response) 68 | 69 | st.session_state.history = history 70 | st.session_state.past_key_values = past_key_values 71 | -------------------------------------------------------------------------------- /evaluation/evaluate_ceval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import re 4 | import json 5 | import torch 
6 | import torch.utils.data 7 | from transformers import AutoTokenizer, AutoModel 8 | from tqdm import tqdm 9 | 10 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) 11 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).bfloat16().cuda() 12 | 13 | choices = ["A", "B", "C", "D"] 14 | choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in choices] 15 | 16 | 17 | def build_prompt(text): 18 | return "[Round {}]\n\n问:{}\n\n答:".format(1, text) 19 | 20 | 21 | extraction_prompt = '综上所述,ABCD中正确的选项是:' 22 | 23 | accuracy_dict, count_dict = {}, {} 24 | with torch.no_grad(): 25 | for entry in glob.glob("./CEval/val/**/*.jsonl", recursive=True): 26 | dataset = [] 27 | with open(entry, encoding='utf-8') as file: 28 | for line in file: 29 | dataset.append(json.loads(line)) 30 | correct = 0 31 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=8) 32 | for batch in tqdm(dataloader): 33 | texts = batch["inputs_pretokenized"] 34 | queries = [build_prompt(query) for query in texts] 35 | inputs = tokenizer(queries, padding=True, return_tensors="pt", truncation=True, max_length=2048).to('cuda') 36 | outputs = model.generate(**inputs, do_sample=False, max_new_tokens=512) 37 | intermediate_outputs = [] 38 | for idx in range(len(outputs)): 39 | output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):] 40 | response = tokenizer.decode(output) 41 | intermediate_outputs.append(response) 42 | answer_texts = [text + intermediate + "\n" + extraction_prompt for text, intermediate in 43 | zip(texts, intermediate_outputs)] 44 | input_tokens = [build_prompt(answer_text) for answer_text in answer_texts] 45 | inputs = tokenizer(input_tokens, padding=True, return_tensors="pt", truncation=True, max_length=2048).to('cuda') 46 | outputs = model(**inputs, return_last_logit=True) 47 | logits = outputs.logits[:, -1] 48 | logits = logits[:, choice_tokens] 49 | preds = logits.argmax(dim=-1) 50 | 
correct += (preds.cpu() == batch["label"]).sum().item() 51 | accuracy = correct / len(dataset) 52 | print(entry, accuracy) 53 | accuracy_dict[entry] = accuracy 54 | count_dict[entry] = len(dataset) 55 | 56 | acc_total, count_total = 0.0, 0 57 | for key in accuracy_dict: 58 | acc_total += accuracy_dict[key] * count_dict[key] 59 | count_total += count_dict[key] 60 | print(acc_total / count_total) -------------------------------------------------------------------------------- /MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | The ChatGLM-6B License 2 | 3 | 一、定义 4 | 5 | “许可方”是指分发其软件的 ChatGLM2-6B 模型团队。 6 | 7 | “软件”是指根据本许可提供的 ChatGLM2-6B 模型参数。 8 | 9 | 2. 许可授予 10 | 11 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可,仅用于您的非商业研究目的。 12 | 13 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 14 | 15 | 3.限制 16 | 17 | 您不得出于任何商业、军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 18 | 19 | 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。 20 | 21 | 4.免责声明 22 | 23 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。 24 | 25 | 5. 责任限制 26 | 27 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 28 | 29 | 6.争议解决 30 | 31 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 32 | 33 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 glm-130b@googlegroups.com 与我们联系。 34 | 35 | 1. Definitions 36 | 37 | “Licensor” means the ChatGLM2-6B Model Team that distributes its Software. 38 | 39 | “Software” means the ChatGLM2-6B model parameters made available under this license. 40 | 41 | 2. License Grant 42 | 43 | Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. 
44 | 45 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 46 | 47 | 3. Restriction 48 | 49 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. 50 | 51 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. 52 | 53 | 4. Disclaimer 54 | 55 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 56 | 57 | 5. Limitation of Liability 58 | 59 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 60 | 61 | 6. Dispute Resolution 62 | 63 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. 64 | 65 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. 
66 | -------------------------------------------------------------------------------- /chatglm_model_v2/trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. 17 | """ 18 | import os 19 | from typing import Optional 20 | from transformers import Trainer 21 | 22 | import torch 23 | from transformers.modeling_utils import PreTrainedModel, unwrap_model 24 | from transformers.utils import logging 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | WEIGHTS_NAME = "pytorch_model.bin" 29 | TRAINING_ARGS_NAME = "training_args.bin" 30 | 31 | 32 | class PrefixTrainer(Trainer): 33 | def __init__(self, *args, save_changed=False, **kwargs): 34 | self.save_changed = save_changed 35 | super().__init__(*args, **kwargs) 36 | 37 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 38 | # If we are executing this function, we are the process zero, so we don't check for that. 39 | output_dir = output_dir if output_dir is not None else self.args.output_dir 40 | os.makedirs(output_dir, exist_ok=True) 41 | logger.info(f"Saving model checkpoint to {output_dir}") 42 | # Save a trained model and configuration using `save_pretrained()`. 
43 | # They can then be reloaded using `from_pretrained()` 44 | if not isinstance(self.model, PreTrainedModel): 45 | if isinstance(unwrap_model(self.model), PreTrainedModel): 46 | if state_dict is None: 47 | state_dict = self.model.state_dict() 48 | unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict) 49 | else: 50 | logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") 51 | if state_dict is None: 52 | state_dict = self.model.state_dict() 53 | torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) 54 | else: 55 | if self.save_changed: 56 | print("Saving PrefixEncoder") 57 | state_dict = self.model.state_dict() 58 | filtered_state_dict = {} 59 | for k, v in self.model.named_parameters(): 60 | if v.requires_grad: 61 | filtered_state_dict[k] = state_dict[k] 62 | self.model.save_pretrained(output_dir, state_dict=filtered_state_dict) 63 | else: 64 | print("Saving the whole model") 65 | self.model.save_pretrained(output_dir, state_dict=state_dict) 66 | if self.tokenizer is not None: 67 | self.tokenizer.save_pretrained(output_dir) 68 | 69 | # Good practice: save your training arguments together with the trained model 70 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 71 | -------------------------------------------------------------------------------- /application/web_demo.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoTokenizer 2 | import gradio as gr 3 | import mdtex2html 4 | from utils import load_model_on_gpus 5 | 6 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) 7 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda() 8 | # 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量 9 | # from utils import load_model_on_gpus 10 | # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2) 11 | model = model.eval() 12 | 13 | """Override 
Chatbot.postprocess"""
 14 |
 15 |
 16 | def postprocess(self, y):
 17 |     if y is None:
 18 |         return []
 19 |     for i, (message, response) in enumerate(y):
 20 |         y[i] = (
 21 |             None if message is None else mdtex2html.convert((message)),
 22 |             None if response is None else mdtex2html.convert(response),
 23 |         )
 24 |     return y
 25 |
 26 |
 27 | gr.Chatbot.postprocess = postprocess
 28 |
 29 |
 30 | def parse_text(text):
 31 |     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
 32 |     lines = text.split("\n")
 33 |     lines = [line for line in lines if line != ""]
 34 |     count = 0
 35 |     for i, line in enumerate(lines):
 36 |         if "```" in line:
 37 |             count += 1
 38 |             items = line.split('`')
 39 |             if count % 2 == 1:
 40 |                 lines[i] = f'<pre><code class="language-{items[-1]}">'
 41 |             else:
 42 |                 lines[i] = f'<br></code></pre>'
 43 |         else:
 44 |             if i > 0:
 45 |                 if count % 2 == 1:
 46 |                     line = line.replace("`", "\`")
 47 |                     line = line.replace("<", "&lt;")
 48 |                     line = line.replace(">", "&gt;")
 49 |                     line = line.replace(" ", "&nbsp;")
 50 |                     line = line.replace("*", "&ast;")
 51 |                     line = line.replace("_", "&lowbar;")
 52 |                     line = line.replace("-", "&#45;")
 53 |                     line = line.replace(".", "&#46;")
 54 |                     line = line.replace("!", "&#33;")
 55 |                     line = line.replace("(", "&#40;")
 56 |                     line = line.replace(")", "&#41;")
 57 |                     line = line.replace("$", "&#36;")
 58 |                 lines[i] = "<br>"+line
 59 |     text = "".join(lines)
 60 |     return text
 61 |
 62 |
 63 | def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
 64 |     chatbot.append((parse_text(input), ""))
 65 |     for response, history, past_key_values in model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
 66 |                                                                 return_past_key_values=True,
 67 |                                                                 max_length=max_length, top_p=top_p,
 68 |                                                                 temperature=temperature):
 69 |         chatbot[-1] = (parse_text(input), parse_text(response))
 70 |
 71 |         yield chatbot, history, past_key_values
 72 |
 73 |
 74 | def reset_user_input():
 75 |     return gr.update(value='')
 76 |
 77 |
 78 | def reset_state():
 79 |     return [], [], None
 80 |
 81 |
 82 | with gr.Blocks() as demo:
 83 |     gr.HTML("""<h1 align="center">ChatGLM2-6B</h1>""")
 84 |
 85 |     chatbot = gr.Chatbot()
 86 |     with gr.Row():
 87 |         with gr.Column(scale=4):
 88 |             with gr.Column(scale=12):
 89 |                 user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
 90 |                     container=False)
 91 |             with gr.Column(min_width=32, scale=1):
 92 |                 submitBtn = gr.Button("Submit", variant="primary")
 93 |         with gr.Column(scale=1):
 94 |             emptyBtn = gr.Button("Clear History")
 95 |             max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
 96 |             top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
 97 |             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
 98 |
 99 |     history = gr.State([])
100 |     past_key_values = gr.State(None)
101 |
102 |     submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
103 |                     [chatbot, history, past_key_values], show_progress=True)
104 |     submitBtn.click(reset_user_input, [], [user_input])
105 |
106 |     emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
107 |
108 | demo.queue().launch(share=False, inbrowser=True)
109 |
--------------------------------------------------------------------------------
/chatglm_model_v2/README.md:
--------------------------------------------------------------------------------
  1 | # ChatGLM2-6B-PT
  2 | This repository implements fine-tuning of the ChatGLM2-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the number of parameters that need to be fine-tuned to 0.1% of the original; combined with model quantization, gradient checkpointing, and similar techniques, it can run with as little as 7 GB of GPU memory.
  3 |
  4 | The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertisement generation) dataset as an example to show how to use the code.
  5 |
  6 | ## Dependencies
  7 | In addition to the dependencies of ChatGLM2-6B, fine-tuning requires the following packages:
  8 | ```
  9 | pip install rouge_chinese nltk jieba datasets
 10 | ```
 11 | ## Usage
 12 |
 13 | ### Download the dataset
 14 | The ADGEN task is to generate a piece of advertising copy (`summary`) from the input (`content`).
 15 |
 16 | ```json
 17 | {
 18 |     "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
 19 |     "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
 20 | }
 21 | ```
 22 |
 23 | Download the preprocessed ADGEN dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and place the extracted `AdvertiseGen` directory in this directory.
 24 |
 25 | ### Training
 26 |
 27 | #### P-Tuning v2
 28 |
 29 | Run the following command to start training:
 30 | ```shell
 31 | bash train.sh
 32 | ```
 33 | `PRE_SEQ_LEN` and `LR` in `train.sh` are the soft prompt length and the training learning rate, respectively; tune them for the best results. P-Tuning v2 freezes all of the original model's parameters. Set `quantization_bit` to choose the quantization level of the original model; if the option is omitted, the model is loaded in FP16 precision.
 34 |
 35 | With the default configuration `quantization_bit=4`, `per_device_train_batch_size=1`, and `gradient_accumulation_steps=16`, the INT4 model parameters are frozen and each training iteration runs 16 accumulated forward and backward passes with a batch size of 1, equivalent to a total batch size of 16. In this setting as little as 6.7 GB of GPU memory is needed. To train more efficiently at the same effective batch size, increase `per_device_train_batch_size` while keeping the product of the two values unchanged; this uses more GPU memory, so adjust according to your hardware.
 36 |
 37 | If you want to [load the model from a local path](../README.md#从本地加载模型), change `THUDM/chatglm2-6b` in `train.sh` to your local model path.
 38 |
 39 | #### Finetune
 40 |
 41 | For full-parameter fine-tuning, install [DeepSpeed](https://github.com/microsoft/DeepSpeed) and then run the following command:
 42 |
 43 | ```shell
 44 | bash ds_train_finetune.sh
 45 | ```
 46 |
 47 | ### Inference
 48 |
 49 | During P-Tuning v2 training, only the PrefixEncoder parameters are saved, so inference must load both the original ChatGLM2-6B model and the PrefixEncoder weights. Specify the following arguments in `evaluate.sh`:
 50 |
 51 | ```shell
 52 | --model_name_or_path THUDM/chatglm2-6b
 53 | --ptuning_checkpoint $CHECKPOINT_PATH
 54 | ```
 55 |
 56 | For a full-parameter fine-tuned checkpoint, set only `model_name_or_path`, as before:
 57 |
 58 | ```shell
 59 | --model_name_or_path $CHECKPOINT_PATH
 60 | ```
 61 |
 62 | The evaluation metrics are Chinese ROUGE score and BLEU-4. The generated results are saved in
 63 | `./output/adgen-chatglm2-6b-pt-128-2e-2/generated_predictions.txt`.
 64 |
 65 | ### Examples
 66 | #### Example 1
 67 | * Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
 68 | * Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
 69 | * Output[before fine-tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
 70 | * Output[after fine-tuning]:
这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
 71 |
 72 | #### Example 2
 73 |
 74 | * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
 75 | * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
 76 | * Output[before fine-tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
 77 | * Output[after fine-tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
 78 |
 79 |
 80 | ## Model deployment
 81 | First, load the tokenizer:
 82 |
 83 | ```python
 84 | from transformers import AutoConfig, AutoModel, AutoTokenizer
 85 |
 86 | # Load the tokenizer
 87 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
 88 | ```
 89 |
 90 | 1. To load a P-Tuning checkpoint:
 91 |
 92 | ```python
 93 | config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, pre_seq_len=128)
 94 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", config=config, trust_remote_code=True)
 95 | prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
 96 | new_prefix_state_dict = {}
 97 | for k, v in prefix_state_dict.items():
 98 |     if k.startswith("transformer.prefix_encoder."):
 99 |         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
100 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
101 | ```
102 | Note that you may need to change `pre_seq_len` to the value actually used in training. If you [load the model from a local path](../README.md#从本地加载模型), change `THUDM/chatglm2-6b` to the local model path (not the checkpoint path). The snippet also requires `import os` and `import torch`.
103 |
104 | 2. To load a full-parameter fine-tuned checkpoint, load the whole checkpoint directly:
105 |
106 | ```python
107 | model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
108 | ```
109 |
110 | Afterwards, you can quantize the model if needed, or use it directly:
111 |
112 | ```python
113 | # Comment out the following line if you don't use quantization
114 | model = model.quantize(4)
115 | model = model.cuda()
116 | model = model.eval()
117 |
118 | response, history = model.chat(tokenizer, "你好", history=[])
119 | ```
120 |
121 | You can also directly run the [web demo](./web_demo.py), which supports loading a P-Tuning v2 checkpoint:
122 | ```shell
123 | bash web_demo.sh
124 | ```
125 | You may need to modify [web_demo.sh](./web_demo.sh) to match your actual checkpoint.
126 |
127 | ## Using your own dataset
128 | In `train.sh` and `evaluate.sh`, set `train_file`, `validation_file`, and `test_file` to the paths of your own JSON-format dataset, and set `prompt_column` and `response_column` to the keys of the input text and output text in the JSON file. You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths in your dataset.
129 |
130 | ## Dialogue dataset
131 |
132 | To fine-tune the model on multi-turn dialogue data, provide the chat history. For example, the following is the training data for a three-turn dialogue:
133 |
134 | ```json lines
135 | {"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
136 | {"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
137 | {"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
138 | ```
139 |
140 | During training, set `--history_column` to the key of the chat history in the data (`history` in this example); the chat history is then concatenated automatically. Note that content exceeding the input length `max_source_length` is truncated.
141 |
142 | See the following command for reference:
143 |
144 | ```shell
145 | bash train_chat.sh
146 | ```
147 |
148 | ## Citation
149 |
150 | ```
151 | @inproceedings{liu2022p,
152 |   title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
153 |   author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
154 |
  booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
155 |   pages={61--68},
156 |   year={2022}
157 | }
158 | ```
159 |
160 |
161 |
162 |
--------------------------------------------------------------------------------
/chatglm_model_v1/web_demo.py:
--------------------------------------------------------------------------------
  1 | import os, sys
  2 |
  3 | import gradio as gr
  4 | import mdtex2html
  5 |
  6 | import torch
  7 | import transformers
  8 | from transformers import (
  9 |     AutoConfig,
 10 |     AutoModel,
 11 |     AutoTokenizer,
 12 |     AutoTokenizer,
 13 |     DataCollatorForSeq2Seq,
 14 |     HfArgumentParser,
 15 |     Seq2SeqTrainingArguments,
 16 |     set_seed,
 17 | )
 18 |
 19 | from arguments import ModelArguments, DataTrainingArguments
 20 |
 21 |
 22 | model = None
 23 | tokenizer = None
 24 |
 25 | """Override Chatbot.postprocess"""
 26 |
 27 |
 28 | def postprocess(self, y):
 29 |     if y is None:
 30 |         return []
 31 |     for i, (message, response) in enumerate(y):
 32 |         y[i] = (
 33 |             None if message is None else mdtex2html.convert((message)),
 34 |             None if response is None else mdtex2html.convert(response),
 35 |         )
 36 |     return y
 37 |
 38 |
 39 | gr.Chatbot.postprocess = postprocess
 40 |
 41 |
 42 | def parse_text(text):
 43 |     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
 44 |     lines = text.split("\n")
 45 |     lines = [line for line in lines if line != ""]
 46 |     count = 0
 47 |     for i, line in enumerate(lines):
 48 |         if "```" in line:
 49 |             count += 1
 50 |             items = line.split('`')
 51 |             if count % 2 == 1:
 52 |                 lines[i] = f'<pre><code class="language-{items[-1]}">'
 53 |             else:
 54 |                 lines[i] = f'<br></code></pre>'
 55 |         else:
 56 |             if i > 0:
 57 |                 if count % 2 == 1:
 58 |                     line = line.replace("`", "\`")
 59 |                     line = line.replace("<", "&lt;")
 60 |                     line = line.replace(">", "&gt;")
 61 |                     line = line.replace(" ", "&nbsp;")
 62 |                     line = line.replace("*", "&ast;")
 63 |                     line = line.replace("_", "&lowbar;")
 64 |                     line = line.replace("-", "&#45;")
 65 |                     line = line.replace(".", "&#46;")
 66 |                     line = line.replace("!", "&#33;")
 67 |                     line = line.replace("(", "&#40;")
 68 |                     line = line.replace(")", "&#41;")
 69 |                     line = line.replace("$", "&#36;")
 70 |                 lines[i] = "<br>"+line
 71 |     text = "".join(lines)
 72 |     return text
 73 |
 74 |
 75 | def predict(input, chatbot, max_length, top_p, temperature, history):
 76 |     chatbot.append((parse_text(input), ""))
 77 |     for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
 78 |                                                temperature=temperature):
 79 |         chatbot[-1] = (parse_text(input), parse_text(response))
 80 |
 81 |         yield chatbot, history
 82 |
 83 |
 84 | def reset_user_input():
 85 |     return gr.update(value='')
 86 |
 87 |
 88 | def reset_state():
 89 |     return [], []
 90 |
 91 |
 92 | with gr.Blocks() as demo:
 93 |     gr.HTML("""<h1 align="center">ChatGLM</h1>
""") 94 | 95 | chatbot = gr.Chatbot() 96 | with gr.Row(): 97 | with gr.Column(scale=4): 98 | with gr.Column(scale=12): 99 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( 100 | container=False) 101 | with gr.Column(min_width=32, scale=1): 102 | submitBtn = gr.Button("Submit", variant="primary") 103 | with gr.Column(scale=1): 104 | emptyBtn = gr.Button("Clear History") 105 | max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) 106 | top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) 107 | temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) 108 | 109 | history = gr.State([]) 110 | 111 | submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], 112 | show_progress=True) 113 | submitBtn.click(reset_user_input, [], [user_input]) 114 | 115 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) 116 | 117 | 118 | 119 | def main(): 120 | global model, tokenizer 121 | 122 | parser = HfArgumentParser(( 123 | ModelArguments)) 124 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 125 | # If we pass only one argument to the script and it's the path to a json file, 126 | # let's parse it to get our arguments. 
127 | model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0] 128 | else: 129 | model_args = parser.parse_args_into_dataclasses()[0] 130 | 131 | tokenizer = AutoTokenizer.from_pretrained( 132 | model_args.model_name_or_path, trust_remote_code=True) 133 | config = AutoConfig.from_pretrained( 134 | model_args.model_name_or_path, trust_remote_code=True) 135 | 136 | config.pre_seq_len = model_args.pre_seq_len 137 | config.prefix_projection = model_args.prefix_projection 138 | 139 | if model_args.ptuning_checkpoint is not None: 140 | print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}") 141 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 142 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) 143 | new_prefix_state_dict = {} 144 | for k, v in prefix_state_dict.items(): 145 | if k.startswith("transformer.prefix_encoder."): 146 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v 147 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) 148 | else: 149 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 150 | 151 | if model_args.quantization_bit is not None: 152 | print(f"Quantized to {model_args.quantization_bit} bit") 153 | model = model.quantize(model_args.quantization_bit) 154 | 155 | if model_args.pre_seq_len is not None: 156 | # P-tuning v2 157 | model = model.half().cuda() 158 | model.transformer.prefix_encoder.float().cuda() 159 | 160 | model = model.eval() 161 | demo.queue().launch(share=False, inbrowser=True) 162 | 163 | 164 | 165 | if __name__ == "__main__": 166 | main() -------------------------------------------------------------------------------- /application/openai_api.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Implements API for ChatGLM2-6B in OpenAI's 
format. (https://platform.openai.com/docs/api-reference/chat) 3 | # Usage: python openai_api.py 4 | # Visit http://localhost:8000/docs for documents. 5 | 6 | 7 | import time 8 | import torch 9 | import uvicorn 10 | from pydantic import BaseModel, Field 11 | from fastapi import FastAPI, HTTPException 12 | from fastapi.middleware.cors import CORSMiddleware 13 | from contextlib import asynccontextmanager 14 | from typing import Any, Dict, List, Literal, Optional, Union 15 | from transformers import AutoTokenizer, AutoModel 16 | from sse_starlette.sse import ServerSentEvent, EventSourceResponse 17 | 18 | 19 | @asynccontextmanager 20 | async def lifespan(app: FastAPI): # collects GPU memory 21 | yield 22 | if torch.cuda.is_available(): 23 | torch.cuda.empty_cache() 24 | torch.cuda.ipc_collect() 25 | 26 | 27 | app = FastAPI(lifespan=lifespan) 28 | 29 | app.add_middleware( 30 | CORSMiddleware, 31 | allow_origins=["*"], 32 | allow_credentials=True, 33 | allow_methods=["*"], 34 | allow_headers=["*"], 35 | ) 36 | 37 | class ModelCard(BaseModel): 38 | id: str 39 | object: str = "model" 40 | created: int = Field(default_factory=lambda: int(time.time())) 41 | owned_by: str = "owner" 42 | root: Optional[str] = None 43 | parent: Optional[str] = None 44 | permission: Optional[list] = None 45 | 46 | 47 | class ModelList(BaseModel): 48 | object: str = "list" 49 | data: List[ModelCard] = [] 50 | 51 | 52 | class ChatMessage(BaseModel): 53 | role: Literal["user", "assistant", "system"] 54 | content: str 55 | 56 | 57 | class DeltaMessage(BaseModel): 58 | role: Optional[Literal["user", "assistant", "system"]] = None 59 | content: Optional[str] = None 60 | 61 | 62 | class ChatCompletionRequest(BaseModel): 63 | model: str 64 | messages: List[ChatMessage] 65 | temperature: Optional[float] = None 66 | top_p: Optional[float] = None 67 | max_length: Optional[int] = None 68 | stream: Optional[bool] = False 69 | 70 | 71 | class ChatCompletionResponseChoice(BaseModel): 72 | index: int 73 | 
message: ChatMessage 74 | finish_reason: Literal["stop", "length"] 75 | 76 | 77 | class ChatCompletionResponseStreamChoice(BaseModel): 78 | index: int 79 | delta: DeltaMessage 80 | finish_reason: Optional[Literal["stop", "length"]] 81 | 82 | 83 | class ChatCompletionResponse(BaseModel): 84 | model: str 85 | object: Literal["chat.completion", "chat.completion.chunk"] 86 | choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] 87 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 88 | 89 | 90 | @app.get("/v1/models", response_model=ModelList) 91 | async def list_models(): 92 | global model_args 93 | model_card = ModelCard(id="gpt-3.5-turbo") 94 | return ModelList(data=[model_card]) 95 | 96 | 97 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 98 | async def create_chat_completion(request: ChatCompletionRequest): 99 | global model, tokenizer 100 | 101 | if request.messages[-1].role != "user": 102 | raise HTTPException(status_code=400, detail="Invalid request") 103 | query = request.messages[-1].content 104 | 105 | prev_messages = request.messages[:-1] 106 | if len(prev_messages) > 0 and prev_messages[0].role == "system": 107 | query = prev_messages.pop(0).content + query 108 | 109 | history = [] 110 | if len(prev_messages) % 2 == 0: 111 | for i in range(0, len(prev_messages), 2): 112 | if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant": 113 | history.append([prev_messages[i].content, prev_messages[i+1].content]) 114 | 115 | if request.stream: 116 | generate = predict(query, history, request.model) 117 | return EventSourceResponse(generate, media_type="text/event-stream") 118 | 119 | response, _ = model.chat(tokenizer, query, history=history) 120 | choice_data = ChatCompletionResponseChoice( 121 | index=0, 122 | message=ChatMessage(role="assistant", content=response), 123 | finish_reason="stop" 124 | ) 125 | 126 | return 
ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion") 127 | 128 | 129 | async def predict(query: str, history: List[List[str]], model_id: str): 130 | global model, tokenizer 131 | 132 | choice_data = ChatCompletionResponseStreamChoice( 133 | index=0, 134 | delta=DeltaMessage(role="assistant"), 135 | finish_reason=None 136 | ) 137 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 138 | yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) 139 | 140 | current_length = 0 141 | 142 | for new_response, _ in model.stream_chat(tokenizer, query, history): 143 | if len(new_response) == current_length: 144 | continue 145 | 146 | new_text = new_response[current_length:] 147 | current_length = len(new_response) 148 | 149 | choice_data = ChatCompletionResponseStreamChoice( 150 | index=0, 151 | delta=DeltaMessage(content=new_text), 152 | finish_reason=None 153 | ) 154 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 155 | yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) 156 | 157 | 158 | choice_data = ChatCompletionResponseStreamChoice( 159 | index=0, 160 | delta=DeltaMessage(), 161 | finish_reason="stop" 162 | ) 163 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 164 | yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) 165 | yield '[DONE]' 166 | 167 | 168 | 169 | if __name__ == "__main__": 170 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) 171 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda() 172 | # 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量 173 | # from utils import load_model_on_gpus 174 | # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2) 175 | model.eval() 176 | 177 | uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) 178 | 
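
The message handling in `create_chat_completion` above can be tried in isolation, without loading the model. The following is a minimal sketch, assuming plain dicts in place of the `ChatMessage` models; `split_messages` is a hypothetical helper name that does not exist in the repository, and raising `ValueError` stands in for the HTTP 400 response.

```python
from typing import Dict, List, Tuple


def split_messages(messages: List[Dict[str, str]]) -> Tuple[str, List[List[str]]]:
    """Hypothetical helper mirroring the message handling in create_chat_completion."""
    # The final message must be the user's new query.
    if not messages or messages[-1]["role"] != "user":
        raise ValueError("the last message must come from the user")
    query = messages[-1]["content"]

    prev = list(messages[:-1])
    # A leading system message is folded into the query text.
    if prev and prev[0]["role"] == "system":
        query = prev.pop(0)["content"] + query

    history: List[List[str]] = []
    # Only an even number of earlier messages can form (user, assistant) pairs.
    if len(prev) % 2 == 0:
        for i in range(0, len(prev), 2):
            if prev[i]["role"] == "user" and prev[i + 1]["role"] == "assistant":
                history.append([prev[i]["content"], prev[i + 1]["content"]])
    return query, history


if __name__ == "__main__":
    demo_messages = [
        {"role": "system", "content": "Be brief. "},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user", "content": "What is P-Tuning?"},
    ]
    print(split_messages(demo_messages))
```

As in the server code, an odd number of earlier messages yields an empty history rather than an error, so a client that wants context preserved should send complete user/assistant pairs.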
--------------------------------------------------------------------------------
/chatglm_model_v2/web_demo.py:
--------------------------------------------------------------------------------
  1 | import os, sys
  2 |
  3 | import gradio as gr
  4 | import mdtex2html
  5 |
  6 | import torch
  7 | import transformers
  8 | from transformers import (
  9 |     AutoConfig,
 10 |     AutoModel,
 11 |     AutoTokenizer,
 12 |     AutoTokenizer,
 13 |     DataCollatorForSeq2Seq,
 14 |     HfArgumentParser,
 15 |     Seq2SeqTrainingArguments,
 16 |     set_seed,
 17 | )
 18 |
 19 | from arguments import ModelArguments, DataTrainingArguments
 20 |
 21 |
 22 | model = None
 23 | tokenizer = None
 24 |
 25 | """Override Chatbot.postprocess"""
 26 |
 27 |
 28 | def postprocess(self, y):
 29 |     if y is None:
 30 |         return []
 31 |     for i, (message, response) in enumerate(y):
 32 |         y[i] = (
 33 |             None if message is None else mdtex2html.convert((message)),
 34 |             None if response is None else mdtex2html.convert(response),
 35 |         )
 36 |     return y
 37 |
 38 |
 39 | gr.Chatbot.postprocess = postprocess
 40 |
 41 |
 42 | def parse_text(text):
 43 |     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
 44 |     lines = text.split("\n")
 45 |     lines = [line for line in lines if line != ""]
 46 |     count = 0
 47 |     for i, line in enumerate(lines):
 48 |         if "```" in line:
 49 |             count += 1
 50 |             items = line.split('`')
 51 |             if count % 2 == 1:
 52 |                 lines[i] = f'<pre><code class="language-{items[-1]}">'
 53 |             else:
 54 |                 lines[i] = f'<br></code></pre>'
 55 |         else:
 56 |             if i > 0:
 57 |                 if count % 2 == 1:
 58 |                     line = line.replace("`", "\`")
 59 |                     line = line.replace("<", "&lt;")
 60 |                     line = line.replace(">", "&gt;")
 61 |                     line = line.replace(" ", "&nbsp;")
 62 |                     line = line.replace("*", "&ast;")
 63 |                     line = line.replace("_", "&lowbar;")
 64 |                     line = line.replace("-", "&#45;")
 65 |                     line = line.replace(".", "&#46;")
 66 |                     line = line.replace("!", "&#33;")
 67 |                     line = line.replace("(", "&#40;")
 68 |                     line = line.replace(")", "&#41;")
 69 |                     line = line.replace("$", "&#36;")
 70 |                 lines[i] = "<br>"+line
 71 |     text = "".join(lines)
 72 |     return text
 73 |
 74 |
 75 | def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
 76 |     chatbot.append((parse_text(input), ""))
 77 |     for response, history, past_key_values in model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
 78 |                                                                 return_past_key_values=True,
 79 |                                                                 max_length=max_length, top_p=top_p,
 80 |                                                                 temperature=temperature):
 81 |         chatbot[-1] = (parse_text(input), parse_text(response))
 82 |
 83 |         yield chatbot, history, past_key_values
 84 |
 85 |
 86 | def reset_user_input():
 87 |     return gr.update(value='')
 88 |
 89 |
 90 | def reset_state():
 91 |     return [], [], None
 92 |
 93 |
 94 | with gr.Blocks() as demo:
 95 |     gr.HTML("""<h1 align="center">ChatGLM2-6B</h1>
""") 96 | 97 | chatbot = gr.Chatbot() 98 | with gr.Row(): 99 | with gr.Column(scale=4): 100 | with gr.Column(scale=12): 101 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( 102 | container=False) 103 | with gr.Column(min_width=32, scale=1): 104 | submitBtn = gr.Button("Submit", variant="primary") 105 | with gr.Column(scale=1): 106 | emptyBtn = gr.Button("Clear History") 107 | max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True) 108 | top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True) 109 | temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) 110 | 111 | history = gr.State([]) 112 | past_key_values = gr.State(None) 113 | 114 | submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values], 115 | [chatbot, history, past_key_values], show_progress=True) 116 | submitBtn.click(reset_user_input, [], [user_input]) 117 | 118 | emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True) 119 | 120 | 121 | def main(): 122 | global model, tokenizer 123 | 124 | parser = HfArgumentParser(( 125 | ModelArguments)) 126 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 127 | # If we pass only one argument to the script and it's the path to a json file, 128 | # let's parse it to get our arguments. 
        model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        model_args = parser.parse_args_into_dataclasses()[0]

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True)

    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection

    if model_args.ptuning_checkpoint is not None:
        print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}")
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)

    if model_args.quantization_bit is not None:
        print(f"Quantized to {model_args.quantization_bit} bit")
        model = model.quantize(model_args.quantization_bit)
    model = model.cuda()
    if model_args.pre_seq_len is not None:
        # P-tuning v2
        model.transformer.prefix_encoder.float()

    model = model.eval()
    demo.queue().launch(share=False, inbrowser=True)



if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/chatglm_model_v1/README.md:
--------------------------------------------------------------------------------
# ChatGLM-6B-PT
This repository implements fine-tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the number of parameters to be tuned to 0.1% of the original; combined with model quantization, gradient checkpointing, and similar methods, it runs with as little as 7GB of GPU memory.

The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertisement generation) dataset as an example to show how to use the code.

*Read this in [English](README_en.md).*

## Software Dependencies
Fine-tuning requires `transformers` version 4.27.1. In addition to the ChatGLM-6B dependencies, install the following:
```
pip install rouge_chinese nltk jieba datasets
```
## Usage

### Download the dataset
The ADGEN task is to generate a piece of advertising copy (summary) from the input (content).

```json
{
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
```

Download the preprocessed ADGEN dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and place the extracted `AdvertiseGen` directory in this directory.

### Training

#### P-Tuning v2

Run the following command to train:
```shell
bash train.sh
```
`PRE_SEQ_LEN` and `LR` in `train.sh` are the soft prompt length and the training learning rate, respectively; tune them for the best results. P-Tuning v2 freezes all model parameters. The quantization level of the original model can be changed by adjusting `quantization_bit`; without this option the model is loaded in FP16 precision.

Under the default configuration `quantization_bit=4`, `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the INT4 model parameters are frozen, and one training iteration performs 16 accumulated forward and backward passes with a batch size of 1, equivalent to a total batch size of 16. This needs as little as 6.7G of GPU memory. To improve training efficiency at the same total batch size, increase `per_device_train_batch_size` while keeping the product of the two values unchanged; this also consumes more GPU memory, so adjust it to your hardware.

If you want to [load the model locally](../README_en.md#load-the-model-locally), change `THUDM/chatglm-6b` in `train.sh` to your local model path.

#### Finetune

For full-parameter fine-tuning, install [Deepspeed](https://github.com/microsoft/DeepSpeed) and then run:

```shell
bash ds_train_finetune.sh
```

### Inference

P-Tuning v2 training saves only the PrefixEncoder parameters, so inference must load both the original ChatGLM-6B model and the PrefixEncoder weights. Specify the following arguments in `evaluate.sh`:

```shell
--model_name_or_path THUDM/chatglm-6b
--ptuning_checkpoint $CHECKPOINT_PATH
```

Old checkpoints saved with full parameters are still supported; set `model_name_or_path` as before:

```shell
--model_name_or_path $CHECKPOINT_PATH
```

The evaluation metrics are Chinese Rouge score and BLEU-4. The generated results are saved in
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`.

### Examples
#### Example 1
* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
* Output[before tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
* Output[after tuning]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。

#### Example 2

* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
* Output[before tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
* Output[after tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。

### Evaluation Results

|               | Finetune | P-tuning v2 | LoRA  |
| ------------- | -------- | ----------- | ----- |
| BLEU-4        | 8.01     | 8.10        | 7.62  |
| Rouge-1       | 31.23    | 31.12       | 30.60 |
| Rouge-2       | 7.36     | 7.11        | 6.96  |
| Rouge-l       | 25.08    | 24.97       | 24.80 |
| Training Loss | 3.00     | 3.74        | 3.32  |



#### Experiment Settings

```
max_source_length=64
max_target_length=64
max_steps=3000
```

##### P-tuning v2

```
pre_seq_len=128
learning_rate=2e-2
quantization_bit=4
per_device_train_batch_size=16
gradient_accumulation_steps=1
```

##### Finetune

```
learning_rate=1e-4
fp16
num_gpus=4
per_device_train_batch_size=4
gradient_accumulation_steps=1
```

##### LoRA

The implementation used is [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)

```
learning_rate=5e-4
per_device_train_batch_size=16
gradient_accumulation_steps=1
```

## Model Deployment
First, load the tokenizer:

```python
import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
```

1. To load a new checkpoint (containing only the PrefixEncoder parameters):

```python
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
```
Note that you may need to change `pre_seq_len` to the value you actually trained with. If you [load the model locally](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B), change `THUDM/chatglm-6b` to the local model path (not the checkpoint path).

2. To load an old checkpoint (containing both the ChatGLM-6B and PrefixEncoder parameters), or a checkpoint from full-parameter fine-tuning, load the entire checkpoint directly:

```python
model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
```

You can then quantize the model if needed, or use it directly:

```python
# Comment out the following line if you don't use quantization
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

response, history = model.chat(tokenizer, "你好", history=[])
```

**[23/04/19]** You can also directly run the [web demo](./web_demo.py), which supports loading P-Tuning v2 checkpoints:
```shell
bash web_demo.sh
```
You may need to edit [web_demo.sh](./web_demo.sh) to match your actual checkpoint.

## Use Your Own Dataset
Change `train_file`, `validation_file`, and `test_file` in `train.sh` and `evaluate.sh` to the paths of your own JSON-format datasets, and change `prompt_column` and `response_column` to the keys in the JSON file that hold the input text and output text. You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths in your dataset.

## Dialogue Dataset

To fine-tune the model on multi-turn dialogue data, you can provide the chat history. For example, the following is the training data for a three-turn dialogue:

```json lines
{"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
{"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
{"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
```

During training, set `--history_column` to the key of the chat history in the data (`history` in this example); the chat history is then concatenated automatically. Note that content exceeding the input length `max_source_length` is truncated.

For reference, run:

```shell
bash train_chat.sh
```

## Citation

```
@inproceedings{liu2022p,
  title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
  author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
  booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
  pages={61--68},
  year={2022}
}
```

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ChatGLM2-Tuning


## 1. Introduction

[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second generation of the open-source Chinese-English bilingual chat model [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B). It keeps the strengths of the first generation, such as fluent dialogue and a low deployment threshold, and further improves the model with stronger performance, longer inputs, more efficient deployment, and a more open license. As a result, ChatGLM2-6B reached the top of the C-Eval leaderboard.

This project supports fine-tuning both **ChatGLM-6B** and **ChatGLM2-6B**. Full-parameter fine-tuning is available, as are the following optimization techniques:
- PEFT parameter-efficient training: P-tuning, Prompt-tuning, Prefix-tuning, LoRA;
- DeepSpeed ZeRO training;
- Quantization-aware training and inference deployment.

---

Development progress:
- Code debugging ✅
- Full-parameter training ✅
- Parameter-efficient training ✅
- Quantization-aware training ✅
- Instruction tuning ✅
- Multi-turn dialogue ✅

---

## 2. Getting Started
### 2.1 Environment Setup
First, clone this repository:
```shell
git clone https://github.com/wjn1996/ChatGLM2-Tuning
cd ChatGLM2-Tuning
```

Install the dependencies:
```
pip install -r requirements.txt
```

### 2.2 Dataset Preparation

##### (1) Using a custom instruction-tuning dataset

An instruction-tuning dataset contains a task instruction (instruction) together with the task's input (input) and output (output). During training, the loss is computed only on the output.

Example of the dataset format:
```json
{
    "instruction": "请为下面的评论的情感类别进行分类,候选为【积极】和【消极】",
    "input": "《消失的她》这部电影很好看,但是我觉得大多数人看完后都会emo",
    "output": "消极"
}
```



##### (2) Using a custom multi-turn dialogue dataset

Multi-turn dialogue datasets can be trained in two modes, in-the-loop and session:
- **in-the-loop**: a multi-turn dialogue is split by turn into multiple samples, each treated as independent during training; the loss on the response is computed given the dialogue history (history) and the current prompt. ChatGLM2-6B uses this mode by default for multi-turn training.

```json
{
    "prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
    "response": "用电脑能读数据流吗?水温多少",
    "history": []
}
{
    "prompt": "95",
    "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?",
    "history": [
        ["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]
    ]
}
{
    "prompt": "是的。上下水管都好的",
    "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!",
    "history": [
        ["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"],
        ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]
    ]
}
```
> Taking the dialogue above as an example: in the in-the-loop setting, data processing turns one multi-turn dialogue into 3 independent samples, each a sequence containing the dialogue history, the current prompt, and the output response.
- **session**: the whole multi-turn dialogue is treated as one sample, and the loss is computed over all tokens (or over the output of each turn).

```json
{
    "prompt": [
        "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
        "95",
        "是的。上下水管都好的"
    ],
    "response": [
        "用电脑能读数据流吗?水温多少",
        "上下水管温差怎么样啊?空气是不是都排干净了呢?",
        "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!"
    ]
}
```
> For the same dialogue, only one sample is generated: the prompt and response of each turn are concatenated, and all turns are joined into a sequence of the form "Q1 A1 Q2 A2 ...".


##### (3) Obtaining open-source evaluation datasets

TODO


### 2.3 Model Training

Training uses a causal LM objective. The forward pass computes the loss only on the designated tokens; for the instruction, dialogue history, input, and padding, setting the label to -100 excludes them from the loss.

##### (1) P-tuning training
```bash
TASK_NAME=default_task # task name
PRE_SEQ_LEN=128 # number of prefix tokens
LR=1e-4 # learning rate

CHAT_TRAIN_DATA=data/train.json
CHAT_VAL_DATA=data/dev.json

MODEL_NAME_OR_PATH=pre-trained-lm/chatglm-6b

NUM_GPUS=8

MASTER_PORT=$(shuf -n 1 -i 10000-65535)

MODEL_VERSION=v1 # v1: initialize from ChatGLM-6B; v2: initialize from ChatGLM2-6B

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
deepspeed --num_gpus=$NUM_GPUS --master_port $MASTER_PORT chatglm_model_$MODEL_VERSION/run_ptuning.py \
    --deepspeed deepspeed/deepspeed.json \
    --do_train \
    --train_file $CHAT_TRAIN_DATA \
    --test_file $CHAT_VAL_DATA \
    --prompt_column input \
    --response_column output \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --output_dir ./output/deepspeed/adgen-chatglm-6b-ft-$LR \
    --overwrite_output_dir \
    --max_source_length 256 \
    --max_target_length 256 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 1 \
    --predict_with_generate \
    --max_steps 9000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --task_name $TASK_NAME \
    --base_cache_dir ./.cache \
    --fp16
    # --overwrite_cache \
```

Reference script: scripts/ds_train_ptuning.sh

##### (2) LoRA training
```bash
TASK_NAME=default_task # task name
# PRE_SEQ_LEN=128
PEFT_TYPE=lora # parameter-efficient method
LORA_DIM=8 # LoRA rank
LR=1e-4 # learning rate

CHAT_TRAIN_DATA=./data/train.json
CHAT_VAL_DATA=./data/dev.json

MODEL_NAME_OR_PATH=./pre-trained-lm/chatglm-6b

NUM_GPUS=8

MASTER_PORT=$(shuf -n 1 -i 10000-65535)

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
deepspeed --num_gpus=$NUM_GPUS --master_port $MASTER_PORT chatglm_model_v1/run_peft.py \
    --deepspeed deepspeed/deepspeed.json \
    --do_train \
    --train_file $CHAT_TRAIN_DATA \
    --test_file $CHAT_VAL_DATA \
    --prompt_column input \
    --response_column output \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --output_dir ./output/deepspeed/chatglm-6b-$TASK_NAME-$PEFT_TYPE-$LORA_DIM-$LR \
    --overwrite_output_dir \
    --max_source_length 256 \
    --max_target_length 1024 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 1 \
    --predict_with_generate \
    --max_steps 9000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --peft_type $PEFT_TYPE \
    --lora_dim $LORA_DIM \
    --task_name $TASK_NAME \
    --base_cache_dir ./.cache/ \
    --fp16
    # --overwrite_cache \
```

Reference script: scripts/ds_train_peft.sh

To use INT4 quantization-aware training, add the argument

> --quantization_bit 4

### 2.4 Model Inference and Deployment

#### API deployment
Deployment file "api.py":
```python
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    # Release cached GPU memory after each request.
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()  # FP16
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
```

Run:
> python3 api.py

Calling the API (for example, a POST request to port 8000):
```bash
curl -X POST "http://127.0.0.1:8000" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": []}'
```
#### Quantization
To load the model with FP16 + INT8 quantization:

`model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()`

For more deployment options, see https://github.com/THUDM/ChatGLM-6B

--------------------------------------------------------------------------------
/chatglm_model_v1/README_en.md:
--------------------------------------------------------------------------------
# ChatGLM-6B-PT
This repository implements tuning of the ChatGLM-6B model based on [P-Tuning
v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the amount of parameters that need to be optimized to 0.1% of the full fine-tuning, and then through model quantization, Gradient Checkpoint and other methods, it only needs a minimum of 7GB of video memory to run. 3 | 4 | The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code. 5 | 6 | ## Software dependencies 7 | Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required 8 | ``` 9 | pip install rouge_chinese nltk jieba datasets 10 | ``` 11 | ## Instructions 12 | 13 | ### Download the dataset 14 | The task of the ADGEN dataset is to generate an advertisement word (summary) based on the input (content). 15 | 16 | ```json 17 | { 18 | "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳", 19 | "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。" 20 | } 21 | ``` 22 | 23 | From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory. 24 | 25 | ### Training 26 | 27 | #### P-Tuning v2 28 | 29 | Run the following commands for training: 30 | ```shell 31 | bash train.sh 32 | ``` 33 | `PRE_SEQ_LEN` and `LR` in `train.sh` are soft prompt length and training learning rate respectively, which can be adjusted to achieve the best results. The P-Tuning-v2 method will freeze all model parameters, and the quantization level of the original model can be adjusted by adjusting `quantization_bit`. If this option is not added, it will be loaded with FP16 precision. 
34 | 35 | Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and a training iteration will perform 16 cumulative forward and backward propagations with a batch size of 1, which is equivalent to the total batch size of 16, and only 6.7G GPU memory is required at this time with `quantization_bit=4`. If you want to improve the training efficiency under the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but it will also bring more GPU memory consumption, please adjust it according to the actual situation. 36 | 37 | If you want to [load the model locally](../README_en.md#load-the-model-locally), you can change `THUDM/chatglm-6b` in `train.sh` to your local model path. 38 | 39 | #### Finetune 40 | To finetune the full parameters, you need to install [Deepspeed](https://github.com/microsoft/DeepSpeed), and then run the following command: 41 | 42 | ```shell 43 | bash ds_train_finetune.sh 44 | ``` 45 | 46 | ### Inference 47 | 48 | During P-tuning v2 training, the model only saves the parameters of the PrefixEncoder part, so the original ChatGLM-6B model and the weight of the PrefixEncoder need to be loaded at the same time during inference, and the arguments need to be specified in `evaluate.sh`: 49 | 50 | ```shell 51 | --model_name_or_path THUDM/chatglm-6b 52 | --ptuning_checkpoint $CHECKPOINT_PATH 53 | ``` 54 | 55 | It is still compatible with the old version of Checkpoint saved with full parameters, just set `model_name_or_path` as before: 56 | 57 | ```shell 58 | --model_name_or_path $CHECKPOINT_PATH 59 | ``` 60 | 61 | The evaluation indicators are Chinese Rouge score and BLEU-4. The generated results are saved in 62 | `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`. 
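
As a minimal sketch of why both sets of weights are needed at inference: a P-Tuning v2 checkpoint stores its entries under the `transformer.prefix_encoder.` key prefix, and loading keeps only those entries with the prefix stripped. The helper name `extract_prefix_encoder_weights` and the toy key names below are illustrative assumptions, with plain floats standing in for tensors; only the key prefix comes from this repository's loading code.

```python
PREFIX = "transformer.prefix_encoder."


def extract_prefix_encoder_weights(state_dict):
    """Keep only PrefixEncoder entries and strip the module prefix from their keys."""
    return {
        key[len(PREFIX):]: value
        for key, value in state_dict.items()
        if key.startswith(PREFIX)
    }


# Toy stand-in for a loaded checkpoint; the key names are made up for
# illustration and floats replace real tensors.
toy_checkpoint = {
    "transformer.prefix_encoder.embedding.weight": 0.1,
    "transformer.word_embeddings.weight": 0.2,
}
prefix_weights = extract_prefix_encoder_weights(toy_checkpoint)
print(prefix_weights)
```

The filtered dictionary is what would be handed to the prefix encoder's `load_state_dict`, while the base model weights come from `--model_name_or_path`.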
63 | 64 | ### Example 65 | #### Example 1 66 | * Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞 67 | * Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。 68 | * Output[before tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。 69 | * Output[after tuning]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。 70 | 71 | #### Example 2 72 | 73 | * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领 74 | * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。 75 | * Output[before tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。 76 | * Output[after tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。 77 | 78 | ### evaluation result 79 | 80 | | | Finetune | P-tuning v2 | LoRA | 81 | | ------------- | ----------- | ----- | ------------- | 82 | | BLEU-4 | 8.01 | 8.10 | 7.62 | 83 | | Rouge-1 | 31.23 | 31.12 | 30.60 | 84 | | Rouge-2 | 7.36 | 7.11 | 6.96 | 85 | | Rouge-l | 25.08 | 24.97 | 24.80 | 86 | | Training Loss | 3.00 | 3.74 | 3.32 | 87 | 88 | #### Experiment Settings 89 | 90 | ``` 91 | max_source_length=64 92 | max_target_length=64 93 | max_steps=3000 94 | ``` 95 | 96 | ##### P-tuning v2 97 | 98 | ``` 99 | pre_seq_len=128 100 | learning_rate=2e-2 101 | quantization_bit=4 102 | per_device_train_batch_size=16 103 | gradient_accumulation_steps=1 104 | ``` 105 | 106 | ##### Finetune 107 | 108 | ``` 109 | learning_rate=1e-4 110 | fp16 111 | num_gpus=4 112 | per_device_train_batch_size=4 113 | gradient_accumulation_steps=1 114 | ``` 115 | 116 | ##### LoRA 117 | 118 | The implementation uses 
[simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)

```
learning_rate=5e-4
per_device_train_batch_size=16
gradient_accumulation_steps=1
```

## Model Deployment
First load the tokenizer:

```python
import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
```

1. To load a new checkpoint (containing only the PrefixEncoder parameters):

```python
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
```
Note that you may need to change `pre_seq_len` to the value you actually trained with. If you [load the model locally](../README_en.md#load-the-model-locally), change `THUDM/chatglm-6b` to the local model path (not the checkpoint path).

2. To load an old checkpoint (containing both the ChatGLM-6B and PrefixEncoder parameters), or a checkpoint from full-parameter fine-tuning, load the entire checkpoint directly:

```python
model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
```

You can then quantize the model if needed, or use it directly:

```python
# Comment out the following line if you don't use quantization
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

response, history = model.chat(tokenizer, "Hello", history=[])
```

**[23/04/19]** You can also directly run the [web demo](./web_demo.py), which supports loading P-Tuning v2 checkpoints:
```shell
bash web_demo.sh
```
You may need to edit [web_demo.sh](./web_demo.sh) to match your actual checkpoint.

## Use your own dataset
Change `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to the paths of your own JSON-format datasets, and change `prompt_column` and `response_column` to the keys in the JSON file that hold the input text and output text.
You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths in your own dataset.

## Dialog Dataset

To train the model on multi-turn dialogue data, you can provide the chat history. For example, the following is the training data for a three-turn dialogue:

```json lines
{"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
{"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
{"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
```

During training, set `--history_column` to the key of the chat history in the data (`history` in this example); the chat history is then concatenated automatically. Note that content exceeding the input length `max_source_length` will be truncated.
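
To make the record layout concrete, here is a minimal sketch that expands one conversation into per-turn training records in the format shown above, where each record's `history` holds all earlier prompt/response pairs. The function name `expand_dialogue` and the sample turns are illustrative assumptions, not part of this repository.

```python
import json


def expand_dialogue(turns):
    """Turn [(prompt, response), ...] into one record per turn with accumulated history."""
    records = []
    history = []
    for prompt, response in turns:
        records.append({
            "prompt": prompt,
            "response": response,
            # Copy the earlier turns so later appends do not alter this record.
            "history": [list(pair) for pair in history],
        })
        history.append((prompt, response))
    return records


# Hypothetical two-turn conversation.
turns = [("Hi", "Hello! How can I help?"), ("What is 2 + 2?", "4")]
records = expand_dialogue(turns)
for record in records:
    # One JSON object per line, matching the JSON-lines layout of the training file.
    print(json.dumps(record, ensure_ascii=False))
```

Each printed line can be written to the file passed as `train_file`, with `--history_column history` selecting the accumulated turns.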
189 | 190 | You can refer to the following instructions: 191 | 192 | ```shell 193 | bash train_chat.sh 194 | ``` 195 | 196 | ## Citation 197 | 198 | ``` 199 | @inproceedings{liu2022p, 200 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, 201 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, 202 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, 203 | pages={61--68}, 204 | year={2022} 205 | } 206 | ``` -------------------------------------------------------------------------------- /chatglm_model_v1/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
9 | """ 10 | 11 | model_name_or_path: str = field( 12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 13 | ) 14 | ptuning_checkpoint: str = field( 15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} 16 | ) 17 | config_name: Optional[str] = field( 18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 19 | ) 20 | tokenizer_name: Optional[str] = field( 21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 22 | ) 23 | cache_dir: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 26 | ) 27 | use_fast_tokenizer: bool = field( 28 | default=True, 29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 30 | ) 31 | model_revision: str = field( 32 | default="main", 33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 34 | ) 35 | use_auth_token: bool = field( 36 | default=False, 37 | metadata={ 38 | "help": ( 39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 40 | "with private models)." 41 | ) 42 | }, 43 | ) 44 | resize_position_embeddings: Optional[bool] = field( 45 | default=None, 46 | metadata={ 47 | "help": ( 48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 49 | "the model's position embeddings." 
50 | ) 51 | }, 52 | ) 53 | quantization_bit: Optional[int] = field( 54 | default=None 55 | ) 56 | pre_seq_len: Optional[int] = field( 57 | default=None 58 | ) 59 | prefix_projection: bool = field( 60 | default=False 61 | ) 62 | 63 | @dataclass 64 | class PeftArguments: 65 | peft_type: Optional[str] = field( 66 | default=None, metadata={"help": "The kind of parameter-efficient learning."} 67 | ) 68 | prompt_tuning_initial_text: Optional[str] = field( 69 | default=None, metadata={"help": "The initial token list of the soft prompt-tuning"} 70 | ) 71 | lora_dim: Optional[int] = field( 72 | default=8, metadata={"help": "The dimension of LoRA"} 73 | ) 74 | 75 | 76 | @dataclass 77 | class DataTrainingArguments: 78 | """ 79 | Arguments pertaining to what data we are going to input our model for training and eval. 80 | """ 81 | 82 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 83 | 84 | dataset_name: Optional[str] = field( 85 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 86 | ) 87 | dataset_config_name: Optional[str] = field( 88 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 89 | ) 90 | prompt_column: Optional[str] = field( 91 | default=None, 92 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 93 | ) 94 | response_column: Optional[str] = field( 95 | default=None, 96 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 97 | ) 98 | history_column: Optional[str] = field( 99 | default=None, 100 | metadata={"help": "The name of the column in the datasets containing the history of chat."}, 101 | ) 102 | train_file: Optional[str] = field( 103 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 104 | ) 105 | validation_file: Optional[str] = field( 106 | default=None, 
107 | metadata={ 108 | "help": ( 109 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 110 | ) 111 | }, 112 | ) 113 | test_file: Optional[str] = field( 114 | default=None, 115 | metadata={ 116 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 117 | }, 118 | ) 119 | task_name: Optional[str] = field( 120 | default="default_task", 121 | metadata={ 122 | "help": "The task name." 123 | }, 124 | ) 125 | overwrite_cache: bool = field( 126 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 127 | ) 128 | preprocessing_num_workers: Optional[int] = field( 129 | default=None, 130 | metadata={"help": "The number of processes to use for the preprocessing."}, 131 | ) 132 | max_source_length: Optional[int] = field( 133 | default=1024, 134 | metadata={ 135 | "help": ( 136 | "The maximum total input sequence length after tokenization. Sequences longer " 137 | "than this will be truncated, sequences shorter will be padded." 138 | ) 139 | }, 140 | ) 141 | max_target_length: Optional[int] = field( 142 | default=128, 143 | metadata={ 144 | "help": ( 145 | "The maximum total sequence length for target text after tokenization. Sequences longer " 146 | "than this will be truncated, sequences shorter will be padded." 147 | ) 148 | }, 149 | ) 150 | val_max_target_length: Optional[int] = field( 151 | default=None, 152 | metadata={ 153 | "help": ( 154 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 155 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. " 156 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 157 | "during ``evaluate`` and ``predict``." 
158 | ) 159 | }, 160 | ) 161 | pad_to_max_length: bool = field( 162 | default=False, 163 | metadata={ 164 | "help": ( 165 | "Whether to pad all samples to model maximum sentence length. " 166 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 167 | "efficient on GPU but very bad for TPU." 168 | ) 169 | }, 170 | ) 171 | max_train_samples: Optional[int] = field( 172 | default=None, 173 | metadata={ 174 | "help": ( 175 | "For debugging purposes or quicker training, truncate the number of training examples to this " 176 | "value if set." 177 | ) 178 | }, 179 | ) 180 | max_eval_samples: Optional[int] = field( 181 | default=None, 182 | metadata={ 183 | "help": ( 184 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 185 | "value if set." 186 | ) 187 | }, 188 | ) 189 | max_predict_samples: Optional[int] = field( 190 | default=None, 191 | metadata={ 192 | "help": ( 193 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 194 | "value if set." 195 | ) 196 | }, 197 | ) 198 | num_beams: Optional[int] = field( 199 | default=None, 200 | metadata={ 201 | "help": ( 202 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 203 | "which is used during ``evaluate`` and ``predict``." 204 | ) 205 | }, 206 | ) 207 | ignore_pad_token_for_loss: bool = field( 208 | default=True, 209 | metadata={ 210 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 211 | }, 212 | ) 213 | source_prefix: Optional[str] = field( 214 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 215 | ) 216 | 217 | forced_bos_token: Optional[str] = field( 218 | default=None, 219 | metadata={ 220 | "help": ( 221 | "The token to force as the first generated token after the decoder_start_token_id." 
222 | " Useful for multilingual models like mBART where the first generated token " 223 | "needs to be the target language token (Usually it is the target language token)." 224 | ) 225 | }, 226 | ) 227 | 228 | base_cache_dir: Optional[str] = field( 229 | default=None, 230 | metadata={ 231 | "help": ( 232 | "Path to the cache directory." 233 | ) 234 | }, 235 | ) 236 | 237 | 238 | 239 | def __post_init__(self): 240 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: 241 | raise ValueError("Need either a dataset name or a training/validation/test file.") 242 | else: 243 | if self.train_file is not None: 244 | extension = self.train_file.split(".")[-1] 245 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 246 | if self.validation_file is not None: 247 | extension = self.validation_file.split(".")[-1] 248 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 249 | if self.val_max_target_length is None: 250 | self.val_max_target_length = self.max_target_length 251 | 252 | -------------------------------------------------------------------------------- /chatglm_model_v2/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
9 | """ 10 | 11 | model_name_or_path: str = field( 12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 13 | ) 14 | ptuning_checkpoint: Optional[str] = field( 15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} 16 | ) 17 | config_name: Optional[str] = field( 18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 19 | ) 20 | tokenizer_name: Optional[str] = field( 21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 22 | ) 23 | cache_dir: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 26 | ) 27 | use_fast_tokenizer: bool = field( 28 | default=True, 29 | metadata={"help": "Whether to use a fast tokenizer (backed by the tokenizers library) or not."}, 30 | ) 31 | model_revision: str = field( 32 | default="main", 33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 34 | ) 35 | use_auth_token: bool = field( 36 | default=False, 37 | metadata={ 38 | "help": ( 39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 40 | "with private models)." 41 | ) 42 | }, 43 | ) 44 | resize_position_embeddings: Optional[bool] = field( 45 | default=None, 46 | metadata={ 47 | "help": ( 48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 49 | "the model's position embeddings." 
50 | ) 51 | }, 52 | ) 53 | quantization_bit: Optional[int] = field( 54 | default=None 55 | ) 56 | pre_seq_len: Optional[int] = field( 57 | default=None 58 | ) 59 | prefix_projection: bool = field( 60 | default=False 61 | ) 62 | 63 | @dataclass 64 | class PeftArguments: 65 | peft_type: Optional[str] = field( 66 | default=None, metadata={"help": "The kind of parameter-efficient learning."} 67 | ) 68 | prompt_tuning_initial_text: Optional[str] = field( 69 | default=None, metadata={"help": "The initial token list of the soft prompt-tuning"} 70 | ) 71 | lora_dim: Optional[int] = field( 72 | default=8, metadata={"help": "The dimension of LoRA"} 73 | ) 74 | 75 | 76 | @dataclass 77 | class DataTrainingArguments: 78 | """ 79 | Arguments pertaining to what data we are going to input our model for training and eval. 80 | """ 81 | 82 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 83 | 84 | dataset_name: Optional[str] = field( 85 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 86 | ) 87 | dataset_config_name: Optional[str] = field( 88 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 89 | ) 90 | prompt_column: Optional[str] = field( 91 | default=None, 92 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 93 | ) 94 | response_column: Optional[str] = field( 95 | default=None, 96 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 97 | ) 98 | history_column: Optional[str] = field( 99 | default=None, 100 | metadata={"help": "The name of the column in the datasets containing the history of chat."}, 101 | ) 102 | train_file: Optional[str] = field( 103 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 104 | ) 105 | validation_file: Optional[str] = field( 106 | default=None, 
107 | metadata={ 108 | "help": ( 109 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 110 | ) 111 | }, 112 | ) 113 | test_file: Optional[str] = field( 114 | default=None, 115 | metadata={ 116 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 117 | }, 118 | ) 119 | task_name: Optional[str] = field( 120 | default="default_task", 121 | metadata={ 122 | "help": "The task name." 123 | }, 124 | ) 125 | overwrite_cache: bool = field( 126 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 127 | ) 128 | preprocessing_num_workers: Optional[int] = field( 129 | default=None, 130 | metadata={"help": "The number of processes to use for the preprocessing."}, 131 | ) 132 | max_source_length: Optional[int] = field( 133 | default=1024, 134 | metadata={ 135 | "help": ( 136 | "The maximum total input sequence length after tokenization. Sequences longer " 137 | "than this will be truncated, sequences shorter will be padded." 138 | ) 139 | }, 140 | ) 141 | max_target_length: Optional[int] = field( 142 | default=128, 143 | metadata={ 144 | "help": ( 145 | "The maximum total sequence length for target text after tokenization. Sequences longer " 146 | "than this will be truncated, sequences shorter will be padded." 147 | ) 148 | }, 149 | ) 150 | val_max_target_length: Optional[int] = field( 151 | default=None, 152 | metadata={ 153 | "help": ( 154 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 155 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. " 156 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 157 | "during ``evaluate`` and ``predict``." 
158 | ) 159 | }, 160 | ) 161 | pad_to_max_length: bool = field( 162 | default=False, 163 | metadata={ 164 | "help": ( 165 | "Whether to pad all samples to model maximum sentence length. " 166 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 167 | "efficient on GPU but very bad for TPU." 168 | ) 169 | }, 170 | ) 171 | max_train_samples: Optional[int] = field( 172 | default=None, 173 | metadata={ 174 | "help": ( 175 | "For debugging purposes or quicker training, truncate the number of training examples to this " 176 | "value if set." 177 | ) 178 | }, 179 | ) 180 | max_eval_samples: Optional[int] = field( 181 | default=None, 182 | metadata={ 183 | "help": ( 184 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 185 | "value if set." 186 | ) 187 | }, 188 | ) 189 | max_predict_samples: Optional[int] = field( 190 | default=None, 191 | metadata={ 192 | "help": ( 193 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 194 | "value if set." 195 | ) 196 | }, 197 | ) 198 | num_beams: Optional[int] = field( 199 | default=None, 200 | metadata={ 201 | "help": ( 202 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 203 | "which is used during ``evaluate`` and ``predict``." 204 | ) 205 | }, 206 | ) 207 | ignore_pad_token_for_loss: bool = field( 208 | default=True, 209 | metadata={ 210 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 211 | }, 212 | ) 213 | source_prefix: Optional[str] = field( 214 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 215 | ) 216 | 217 | forced_bos_token: Optional[str] = field( 218 | default=None, 219 | metadata={ 220 | "help": ( 221 | "The token to force as the first generated token after the decoder_start_token_id." 
222 | " Useful for multilingual models like mBART where the first generated token " 223 | "needs to be the target language token (Usually it is the target language token)." 224 | ) 225 | }, 226 | ) 227 | 228 | base_cache_dir: Optional[str] = field( 229 | default=None, 230 | metadata={ 231 | "help": ( 232 | "Path to the cache directory." 233 | ) 234 | }, 235 | ) 236 | 237 | 238 | 239 | def __post_init__(self): 240 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: 241 | raise ValueError("Need either a dataset name or a training/validation/test file.") 242 | else: 243 | if self.train_file is not None: 244 | extension = self.train_file.split(".")[-1] 245 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 246 | if self.validation_file is not None: 247 | extension = self.validation_file.split(".")[-1] 248 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 249 | if self.val_max_target_length is None: 250 | self.val_max_target_length = self.max_target_length 251 | 252 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /chatglm_model_v1/trainer_seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.utils.data import Dataset 20 | 21 | from transformers.deepspeed import is_deepspeed_zero3_enabled 22 | from trainer import Trainer 23 | from transformers.trainer_utils import PredictionOutput 24 | from transformers.utils import logging 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | class Seq2SeqTrainer(Trainer): 31 | def evaluate( 32 | self, 33 | eval_dataset: Optional[Dataset] = None, 34 | ignore_keys: Optional[List[str]] = None, 35 | metric_key_prefix: str = "eval", 36 | **gen_kwargs 37 | ) -> Dict[str, float]: 38 | """ 39 | Run evaluation and returns metrics. 40 | 41 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 42 | (pass it to the init `compute_metrics` argument). 43 | 44 | You can also subclass and override this method to inject custom behavior. 45 | 46 | Args: 47 | eval_dataset (`Dataset`, *optional*): 48 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns 49 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` 50 | method. 51 | ignore_keys (`List[str]`, *optional*): 52 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 53 | gathering predictions. 54 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 55 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 56 | "eval_bleu" if the prefix is `"eval"` (default) 57 | max_length (`int`, *optional*): 58 | The maximum target length to use when predicting with the generate method. 59 | num_beams (`int`, *optional*): 60 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 61 | beam search. 
62 | gen_kwargs: 63 | Additional `generate` specific kwargs. 64 | 65 | Returns: 66 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 67 | dictionary also contains the epoch number which comes from the training state. 68 | """ 69 | 70 | gen_kwargs = gen_kwargs.copy() 71 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 72 | gen_kwargs["max_length"] = self.args.generation_max_length 73 | gen_kwargs["num_beams"] = ( 74 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 75 | ) 76 | self._gen_kwargs = gen_kwargs 77 | 78 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 79 | 80 | def predict( 81 | self, 82 | test_dataset: Dataset, 83 | ignore_keys: Optional[List[str]] = None, 84 | metric_key_prefix: str = "test", 85 | **gen_kwargs 86 | ) -> PredictionOutput: 87 | """ 88 | Run prediction and return predictions and potential metrics. 89 | 90 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method 91 | will also return metrics, like in `evaluate()`. 92 | 93 | Args: 94 | test_dataset (`Dataset`): 95 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the 96 | `model.forward()` method are automatically removed. Has to implement the method `__len__`. 97 | ignore_keys (`List[str]`, *optional*): 98 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 99 | gathering predictions. 100 | metric_key_prefix (`str`, *optional*, defaults to `"test"`): 101 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 102 | "test_bleu" if the prefix is `"test"` (default) 103 | max_length (`int`, *optional*): 104 | The maximum target length to use when predicting with the generate method. 
105 | num_beams (`int`, *optional*): 106 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 107 | beam search. 108 | gen_kwargs: 109 | Additional `generate` specific kwargs. 110 | 111 | 112 | 113 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic 114 | padding in a token classification task) the predictions will be padded (on the right) to allow for 115 | concatenation into one array. The padding index is -100. 116 | 117 | 118 | 119 | Returns: *NamedTuple* A namedtuple with the following keys: 120 | 121 | - predictions (`np.ndarray`): The predictions on `test_dataset`. 122 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). 123 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained 124 | labels). 125 | """ 126 | 127 | gen_kwargs = gen_kwargs.copy() 128 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 129 | gen_kwargs["max_length"] = self.args.generation_max_length 130 | gen_kwargs["num_beams"] = ( 131 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 132 | ) 133 | self._gen_kwargs = gen_kwargs 134 | 135 | 136 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 137 | 138 | def prediction_step( 139 | self, 140 | model: nn.Module, 141 | inputs: Dict[str, Union[torch.Tensor, Any]], 142 | prediction_loss_only: bool, 143 | ignore_keys: Optional[List[str]] = None, 144 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 145 | """ 146 | Perform an evaluation step on `model` using `inputs`. 147 | 148 | Subclass and override to inject custom behavior. 149 | 150 | Args: 151 | model (`nn.Module`): 152 | The model to evaluate. 
153 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 154 | The inputs and targets of the model. 155 | 156 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 157 | argument `labels`. Check your model's documentation for all accepted arguments. 158 | prediction_loss_only (`bool`): 159 | Whether or not to return the loss only. 160 | 161 | Return: 162 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 163 | labels (each being optional). 164 | """ 165 | 166 | if not self.args.predict_with_generate or prediction_loss_only: 167 | return super().prediction_step( 168 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 169 | ) 170 | 171 | has_labels = "labels" in inputs 172 | inputs = self._prepare_inputs(inputs) 173 | 174 | # XXX: adapt synced_gpus for fairscale as well 175 | gen_kwargs = self._gen_kwargs.copy() 176 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 177 | gen_kwargs["max_length"] = self.model.config.max_length 178 | gen_kwargs["num_beams"] = ( 179 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 180 | ) 181 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 182 | gen_kwargs["synced_gpus"] = ( 183 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 184 | ) 185 | 186 | if "attention_mask" in inputs: 187 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 188 | if "position_ids" in inputs: 189 | gen_kwargs["position_ids"] = inputs.get("position_ids", None) 190 | if "global_attention_mask" in inputs: 191 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 192 | 193 | # prepare generation inputs 194 | # some encoder-decoder models can have varying encoder's and thus 195 | # varying model input names 196 | if 
hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 197 | generation_inputs = inputs[self.model.encoder.main_input_name] 198 | else: 199 | generation_inputs = inputs[self.model.main_input_name] 200 | 201 | gen_kwargs["input_ids"] = generation_inputs 202 | generated_tokens = self.model.generate(**gen_kwargs) 203 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] 204 | 205 | # in case the batch is shorter than max length, the output should be padded 206 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: 207 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 208 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( 209 | gen_kwargs["max_new_tokens"] + 1 210 | ): 211 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) 212 | 213 | loss = None 214 | 215 | if self.args.prediction_loss_only: 216 | return (loss, None, None) 217 | 218 | if has_labels: 219 | labels = inputs["labels"] 220 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: 221 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 222 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( 223 | gen_kwargs["max_new_tokens"] + 1 224 | ): 225 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) 226 | else: 227 | labels = None 228 | 229 | return (loss, generated_tokens, labels) 230 | 231 | def _pad_tensors_to_max_len(self, tensor, max_length): 232 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): 233 | # If PAD token is not defined at least EOS token has to be defined 234 | pad_token_id = ( 235 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 236 | ) 237 | else: 238 | if 
self.model.config.pad_token_id is not None: 239 | pad_token_id = self.model.config.pad_token_id 240 | else: 241 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") 242 | 243 | padded_tensor = pad_token_id * torch.ones( 244 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 245 | ) 246 | padded_tensor[:, : tensor.shape[-1]] = tensor 247 | return padded_tensor 248 | -------------------------------------------------------------------------------- /chatglm_model_v2/trainer_seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
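As a reading aid for `prediction_step` and `_pad_tensors_to_max_len` in `trainer_seq2seq.py`, here is a minimal sketch of the two post-processing steps applied to generated tokens: dropping the prompt prefix, then right-padding to a fixed length. It uses plain Python lists instead of tensors so it runs without `torch`. The function name `strip_prompt_and_pad` and the sample token ids are hypothetical and are not part of the repository.

```python
from typing import List


def strip_prompt_and_pad(
    sequences: List[List[int]],
    prompt_len: int,
    max_length: int,
    pad_token_id: int,
) -> List[List[int]]:
    """Drop the first `prompt_len` tokens of each row, then right-pad to `max_length`.

    Rows that are already at least `max_length` long are left unchanged,
    mirroring the "pad only if shorter" check in the trainer.
    """
    # Keep only the newly generated part (the trainer slices off the input ids).
    completions = [seq[prompt_len:] for seq in sequences]
    # A negative repeat count yields an empty list, so long rows are not truncated.
    return [row + [pad_token_id] * (max_length - len(row)) for row in completions]


# Hypothetical batch: a 3-token prompt followed by 2 generated tokens per row.
batch = [[1, 2, 3, 40, 41], [1, 2, 3, 50, 51]]
out = strip_prompt_and_pad(batch, prompt_len=3, max_length=4, pad_token_id=0)
print(out)  # [[40, 41, 0, 0], [50, 51, 0, 0]]
```

Padding every batch to the same width is what lets the evaluation loop concatenate predictions from different batches into one array.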
14 | 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.utils.data import Dataset 20 | 21 | from transformers.deepspeed import is_deepspeed_zero3_enabled 22 | from trainer import PrefixTrainer 23 | from transformers.trainer_utils import PredictionOutput 24 | from transformers.utils import logging 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | class Seq2SeqTrainer(PrefixTrainer): 31 | def evaluate( 32 | self, 33 | eval_dataset: Optional[Dataset] = None, 34 | ignore_keys: Optional[List[str]] = None, 35 | metric_key_prefix: str = "eval", 36 | **gen_kwargs 37 | ) -> Dict[str, float]: 38 | """ 39 | Run evaluation and returns metrics. 40 | 41 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 42 | (pass it to the init `compute_metrics` argument). 43 | 44 | You can also subclass and override this method to inject custom behavior. 45 | 46 | Args: 47 | eval_dataset (`Dataset`, *optional*): 48 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns 49 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` 50 | method. 51 | ignore_keys (`List[str]`, *optional*): 52 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 53 | gathering predictions. 54 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 55 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 56 | "eval_bleu" if the prefix is `"eval"` (default) 57 | max_length (`int`, *optional*): 58 | The maximum target length to use when predicting with the generate method. 59 | num_beams (`int`, *optional*): 60 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 61 | beam search. 
62 | gen_kwargs: 63 | Additional `generate` specific kwargs. 64 | 65 | Returns: 66 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 67 | dictionary also contains the epoch number which comes from the training state. 68 | """ 69 | 70 | gen_kwargs = gen_kwargs.copy() 71 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 72 | gen_kwargs["max_length"] = self.args.generation_max_length 73 | gen_kwargs["num_beams"] = ( 74 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 75 | ) 76 | self._gen_kwargs = gen_kwargs 77 | 78 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 79 | 80 | def predict( 81 | self, 82 | test_dataset: Dataset, 83 | ignore_keys: Optional[List[str]] = None, 84 | metric_key_prefix: str = "test", 85 | **gen_kwargs 86 | ) -> PredictionOutput: 87 | """ 88 | Run prediction and return predictions and potential metrics. 89 | 90 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method 91 | will also return metrics, like in `evaluate()`. 92 | 93 | Args: 94 | test_dataset (`Dataset`): 95 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the 96 | `model.forward()` method are automatically removed. Has to implement the method `__len__` 97 | ignore_keys (`List[str]`, *optional*): 98 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 99 | gathering predictions. 100 | metric_key_prefix (`str`, *optional*, defaults to `"test"`): 101 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 102 | "test_bleu" if the prefix is `"test"` (default) 103 | max_length (`int`, *optional*): 104 | The maximum target length to use when predicting with the generate method.
105 | num_beams (`int`, *optional*): 106 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 107 | beam search. 108 | gen_kwargs: 109 | Additional `generate` specific kwargs. 110 | 111 | 112 | 113 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic 114 | padding in a token classification task) the predictions will be padded (on the right) to allow for 115 | concatenation into one array. The padding index is -100. 116 | 117 | 118 | 119 | Returns: *NamedTuple* A namedtuple with the following keys: 120 | 121 | - predictions (`np.ndarray`): The predictions on `test_dataset`. 122 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). 123 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained 124 | labels). 125 | """ 126 | 127 | gen_kwargs = gen_kwargs.copy() 128 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 129 | gen_kwargs["max_length"] = self.args.generation_max_length 130 | gen_kwargs["num_beams"] = ( 131 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 132 | ) 133 | self._gen_kwargs = gen_kwargs 134 | 135 | 136 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 137 | 138 | def prediction_step( 139 | self, 140 | model: nn.Module, 141 | inputs: Dict[str, Union[torch.Tensor, Any]], 142 | prediction_loss_only: bool, 143 | ignore_keys: Optional[List[str]] = None, 144 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 145 | """ 146 | Perform an evaluation step on `model` using `inputs`. 147 | 148 | Subclass and override to inject custom behavior. 149 | 150 | Args: 151 | model (`nn.Module`): 152 | The model to evaluate. 
153 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 154 | The inputs and targets of the model. 155 | 156 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 157 | argument `labels`. Check your model's documentation for all accepted arguments. 158 | prediction_loss_only (`bool`): 159 | Whether or not to return the loss only. 160 | 161 | Return: 162 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 163 | labels (each being optional). 164 | """ 165 | 166 | if not self.args.predict_with_generate or prediction_loss_only: 167 | return super().prediction_step( 168 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 169 | ) 170 | 171 | has_labels = "labels" in inputs 172 | inputs = self._prepare_inputs(inputs) 173 | 174 | # XXX: adapt synced_gpus for fairscale as well 175 | gen_kwargs = self._gen_kwargs.copy() 176 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 177 | gen_kwargs["max_length"] = self.model.config.max_length 178 | gen_kwargs["num_beams"] = ( 179 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 180 | ) 181 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 182 | gen_kwargs["synced_gpus"] = ( 183 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 184 | ) 185 | 186 | if "attention_mask" in inputs: 187 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 188 | if "position_ids" in inputs: 189 | gen_kwargs["position_ids"] = inputs.get("position_ids", None) 190 | if "global_attention_mask" in inputs: 191 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 192 | 193 | # prepare generation inputs 194 | # some encoder-decoder models can have varying encoder's and thus 195 | # varying model input names 196 | if 
hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 197 | generation_inputs = inputs[self.model.encoder.main_input_name] 198 | else: 199 | generation_inputs = inputs[self.model.main_input_name] 200 | 201 | gen_kwargs["input_ids"] = generation_inputs 202 | generated_tokens = self.model.generate(**gen_kwargs) 203 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] 204 | 205 | # in case the batch is shorter than max length, the output should be padded 206 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: 207 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 208 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( 209 | gen_kwargs["max_new_tokens"] + 1 210 | ): 211 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) 212 | 213 | loss = None 214 | 215 | if self.args.prediction_loss_only: 216 | return (loss, None, None) 217 | 218 | if has_labels: 219 | labels = inputs["labels"] 220 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: 221 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 222 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( 223 | gen_kwargs["max_new_tokens"] + 1 224 | ): 225 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) 226 | else: 227 | labels = None 228 | 229 | return (loss, generated_tokens, labels) 230 | 231 | def _pad_tensors_to_max_len(self, tensor, max_length): 232 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): 233 | # If PAD token is not defined at least EOS token has to be defined 234 | pad_token_id = ( 235 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 236 | ) 237 | else: 238 | if 
self.model.config.pad_token_id is not None: 239 | pad_token_id = self.model.config.pad_token_id 240 | else: 241 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") 242 | 243 | padded_tensor = pad_token_id * torch.ones( 244 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 245 | ) 246 | padded_tensor[:, : tensor.shape[-1]] = tensor 247 | return padded_tensor 248 | -------------------------------------------------------------------------------- /chatglm_model_v2/run_ptuning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
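For orientation, the training-time preprocessing in this script (`preprocess_function_train`) concatenates prompt ids, answer ids and an EOS token, pads to a fixed length, and masks prompt and padding positions in the labels so they are ignored by the loss. The sketch below reproduces that logic on plain lists. The function name `build_example` and all token ids are hypothetical, and it assumes, as the script does, that the sequence fits within `max_seq_length`.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # label value that the loss function skips


def build_example(
    a_ids: List[int],
    b_ids: List[int],
    eos_id: int,
    pad_id: int,
    max_seq_length: int,
    ignore_pad: bool = True,
) -> Tuple[List[int], List[int]]:
    """Build (input_ids, labels) for one prompt/answer pair."""
    context_length = len(a_ids)
    input_ids = a_ids + b_ids + [eos_id]
    # Prompt positions get the pad id first, so they are masked together with padding.
    labels = [pad_id] * context_length + b_ids + [eos_id]

    pad_len = max_seq_length - len(input_ids)
    input_ids = input_ids + [pad_id] * pad_len
    labels = labels + [pad_id] * pad_len

    if ignore_pad:
        labels = [(t if t != pad_id else IGNORE_INDEX) for t in labels]
    return input_ids, labels


# Hypothetical ids: 3 prompt tokens, 2 answer tokens, EOS = 2, PAD = 0.
input_ids, labels = build_example([11, 12, 13], [21, 22], eos_id=2, pad_id=0, max_seq_length=8)
print(input_ids)  # [11, 12, 13, 21, 22, 2, 0, 0]
print(labels)     # [-100, -100, -100, 21, 22, 2, -100, -100]
```

Only the answer tokens and the EOS token keep their ids in `labels`, so the loss is computed on the response alone.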
20 | 21 | import logging 22 | import os 23 | import sys 24 | import json 25 | 26 | import numpy as np 27 | from datasets import load_dataset 28 | import jieba 29 | from rouge_chinese import Rouge 30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 31 | import torch 32 | 33 | import transformers 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModel, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | Seq2SeqTrainingArguments, 41 | set_seed, 42 | ) 43 | from trainer_seq2seq import Seq2SeqTrainer 44 | 45 | from arguments import ModelArguments, DataTrainingArguments 46 | 47 | logger = logging.getLogger(__name__) 48 | 49 | def main(): 50 | # Load the model, training and data argument configuration 51 | # (parsed from a JSON file or from the command line) 52 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 53 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 54 | # If we pass only one argument to the script and it's the path to a json file, 55 | # let's parse it to get our arguments. 56 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 57 | else: 58 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 59 | 60 | # Setup logging 61 | logging.basicConfig( 62 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 63 | datefmt="%m/%d/%Y %H:%M:%S", 64 | handlers=[logging.StreamHandler(sys.stdout)], 65 | ) 66 | 67 | if training_args.should_log: 68 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
69 | transformers.utils.logging.set_verbosity_info() 70 | 71 | log_level = training_args.get_process_log_level() 72 | logger.setLevel(log_level) 73 | # datasets.utils.logging.set_verbosity(log_level) 74 | transformers.utils.logging.set_verbosity(log_level) 75 | transformers.utils.logging.enable_default_handler() 76 | transformers.utils.logging.enable_explicit_format() 77 | 78 | # Log on each process the small summary: 79 | logger.warning( 80 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 81 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 82 | ) 83 | logger.info(f"Training/evaluation parameters {training_args}") 84 | 85 | # Set seed before initializing model. 86 | set_seed(training_args.seed) 87 | 88 | # Load dataset 89 | data_files = {} 90 | if data_args.train_file is not None: 91 | data_files["train"] = data_args.train_file 92 | extension = data_args.train_file.split(".")[-1] 93 | if data_args.validation_file is not None: 94 | data_files["validation"] = data_args.validation_file 95 | extension = data_args.validation_file.split(".")[-1] 96 | if data_args.test_file is not None: 97 | data_files["test"] = data_args.test_file 98 | extension = data_args.test_file.split(".")[-1] 99 | 100 | # Load the files as a Hugging Face dataset 101 | raw_datasets = load_dataset( 102 | extension, 103 | data_files=data_files, 104 | cache_dir=model_args.cache_dir, 105 | use_auth_token=True if model_args.use_auth_token else None, 106 | ) 107 | 108 | # Load pretrained model and tokenizer 109 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 110 | config.pre_seq_len = model_args.pre_seq_len 111 | config.prefix_projection = model_args.prefix_projection 112 | 113 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 114 | 115 | if model_args.ptuning_checkpoint is not None: 116 | # Evaluation 117 | #
Loading extra state dict of prefix encoder 118 | # At inference time, only the P-tuning v2 parameters need to be loaded; they are then combined with the original model parameters 119 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 120 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) 121 | new_prefix_state_dict = {} 122 | # Copy the P-tuning v2 parameters into the prefix encoder of the original model. 123 | for k, v in prefix_state_dict.items(): 124 | if k.startswith("transformer.prefix_encoder."): 125 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v 126 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) 127 | else: 128 | # Load the original model directly 129 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 130 | 131 | if model_args.quantization_bit is not None: 132 | print(f"Quantized to {model_args.quantization_bit} bit") 133 | model = model.quantize(model_args.quantization_bit) 134 | if model_args.pre_seq_len is not None: 135 | # P-tuning v2 136 | model = model.half() 137 | model.transformer.prefix_encoder.float() 138 | else: 139 | # Finetune 140 | model = model.float() 141 | 142 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 143 | 144 | # Preprocessing the datasets. 145 | # We need to tokenize inputs and targets. 146 | if training_args.do_train: 147 | column_names = raw_datasets["train"].column_names 148 | elif training_args.do_eval: 149 | column_names = raw_datasets["validation"].column_names 150 | elif training_args.do_predict: 151 | column_names = raw_datasets["test"].column_names 152 | else: 153 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 154 | return 155 | 156 | # Get the column names for input/target. 157 | prompt_column = data_args.prompt_column 158 | response_column = data_args.response_column 159 | history_column = data_args.history_column 160 | 161 | # Temporarily set max_target_length for training.
162 | max_target_length = data_args.max_target_length 163 | 164 | def preprocess_function_eval(examples): 165 | inputs, targets = [], [] 166 | for i in range(len(examples[prompt_column])): 167 | if examples[prompt_column][i] and examples[response_column][i]: 168 | query = examples[prompt_column][i] 169 | history = examples[history_column][i] if history_column is not None else None 170 | prompt = tokenizer.build_prompt(query, history) 171 | inputs.append(prompt) 172 | targets.append(examples[response_column][i]) 173 | 174 | inputs = [prefix + inp for inp in inputs] 175 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) 176 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) 177 | 178 | if data_args.ignore_pad_token_for_loss: 179 | labels["input_ids"] = [ 180 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 181 | ] 182 | model_inputs["labels"] = labels["input_ids"] 183 | 184 | return model_inputs 185 | 186 | def preprocess_function_train(examples): 187 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1 188 | 189 | model_inputs = { 190 | "input_ids": [], 191 | "labels": [], 192 | } 193 | for i in range(len(examples[prompt_column])): 194 | if examples[prompt_column][i] and examples[response_column][i]: 195 | query, answer = examples[prompt_column][i], examples[response_column][i] 196 | 197 | history = examples[history_column][i] if history_column is not None else None 198 | prompt = tokenizer.build_prompt(query, history) 199 | 200 | prompt = prefix + prompt 201 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True, 202 | max_length=data_args.max_source_length) 203 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True, 204 | max_length=data_args.max_target_length) 205 | 206 | context_length = len(a_ids) 207 | input_ids = a_ids + b_ids + 
[tokenizer.eos_token_id] 208 | labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id] 209 | 210 | pad_len = max_seq_length - len(input_ids) 211 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len 212 | labels = labels + [tokenizer.pad_token_id] * pad_len 213 | if data_args.ignore_pad_token_for_loss: 214 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] 215 | 216 | model_inputs["input_ids"].append(input_ids) 217 | model_inputs["labels"].append(labels) 218 | 219 | return model_inputs 220 | 221 | def print_dataset_example(example): 222 | print("input_ids", example["input_ids"]) 223 | print("inputs", tokenizer.decode(example["input_ids"])) 224 | print("label_ids", example["labels"]) 225 | print("labels", tokenizer.decode(example["labels"])) 226 | 227 | base_cache_dir = os.path.join(data_args.base_cache_dir, data_args.task_name) 228 | if training_args.local_rank <= 0 and not os.path.exists(base_cache_dir): 229 | os.makedirs(base_cache_dir) 230 | 231 | if training_args.do_train: 232 | if "train" not in raw_datasets: 233 | raise ValueError("--do_train requires a train dataset") 234 | train_dataset = raw_datasets["train"] 235 | if data_args.max_train_samples is not None: 236 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 237 | train_dataset = train_dataset.select(range(max_train_samples)) 238 | with training_args.main_process_first(desc="train dataset map pre-processing"): 239 | train_dataset = train_dataset.map( 240 | preprocess_function_train, 241 | batched=True, 242 | num_proc=data_args.preprocessing_num_workers, 243 | remove_columns=column_names, 244 | load_from_cache_file=not data_args.overwrite_cache, 245 | desc="Running tokenizer on train dataset", 246 | cache_file_name=os.path.join(base_cache_dir, "train.arrow") 247 | ) 248 | print_dataset_example(train_dataset[0]) 249 | 250 | if training_args.do_eval: 251 | max_target_length = data_args.val_max_target_length 252 | if
"validation" not in raw_datasets: 253 | raise ValueError("--do_eval requires a validation dataset") 254 | eval_dataset = raw_datasets["validation"] 255 | if data_args.max_eval_samples is not None: 256 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 257 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 258 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 259 | eval_dataset = eval_dataset.map( 260 | preprocess_function_eval, 261 | batched=True, 262 | num_proc=data_args.preprocessing_num_workers, 263 | remove_columns=column_names, 264 | load_from_cache_file=not data_args.overwrite_cache, 265 | desc="Running tokenizer on validation dataset", 266 | cache_file_name=os.path.join(base_cache_dir, "eval.arrow") 267 | ) 268 | print_dataset_example(eval_dataset[0]) 269 | 270 | if training_args.do_predict: 271 | max_target_length = data_args.val_max_target_length 272 | if "test" not in raw_datasets: 273 | raise ValueError("--do_predict requires a test dataset") 274 | predict_dataset = raw_datasets["test"] 275 | if data_args.max_predict_samples is not None: 276 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 277 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 278 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 279 | predict_dataset = predict_dataset.map( 280 | preprocess_function_eval, 281 | batched=True, 282 | num_proc=data_args.preprocessing_num_workers, 283 | remove_columns=column_names, 284 | load_from_cache_file=not data_args.overwrite_cache, 285 | desc="Running tokenizer on prediction dataset", 286 | cache_file_name=os.path.join(base_cache_dir, "predict.arrow") 287 | ) 288 | print_dataset_example(predict_dataset[0]) 289 | 290 | # Data collator 291 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 292 | data_collator = DataCollatorForSeq2Seq( 293 | tokenizer, 294 
| model=model, 295 | label_pad_token_id=label_pad_token_id, 296 | pad_to_multiple_of=None, 297 | padding=False 298 | ) 299 | 300 | # Metric 301 | def compute_metrics(eval_preds): 302 | preds, labels = eval_preds 303 | if isinstance(preds, tuple): 304 | preds = preds[0] 305 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 306 | if data_args.ignore_pad_token_for_loss: 307 | # Replace -100 in the labels as we can't decode them. 308 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 309 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 310 | 311 | score_dict = { 312 | "rouge-1": [], 313 | "rouge-2": [], 314 | "rouge-l": [], 315 | "bleu-4": [] 316 | } 317 | for pred, label in zip(decoded_preds, decoded_labels): 318 | hypothesis = list(jieba.cut(pred)) 319 | reference = list(jieba.cut(label)) 320 | rouge = Rouge() 321 | scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) 322 | result = scores[0] 323 | 324 | for k, v in result.items(): 325 | score_dict[k].append(round(v["f"] * 100, 4)) 326 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 327 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 328 | 329 | for k, v in score_dict.items(): 330 | score_dict[k] = float(np.mean(v)) 331 | return score_dict 332 | 333 | # Override the decoding parameters of Seq2SeqTrainer 334 | training_args.generation_max_length = ( 335 | training_args.generation_max_length 336 | if training_args.generation_max_length is not None 337 | else data_args.val_max_target_length 338 | ) 339 | training_args.generation_num_beams = ( 340 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 341 | ) 342 | # Initialize our Trainer 343 | trainer = Seq2SeqTrainer( 344 | model=model, 345 | args=training_args, 346 | train_dataset=train_dataset if training_args.do_train else None, 347 | eval_dataset=eval_dataset if 
training_args.do_eval else None, 348 | tokenizer=tokenizer, 349 | data_collator=data_collator, 350 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 351 | save_changed=model_args.pre_seq_len is not None 352 | ) 353 | 354 | # Training 355 | if training_args.do_train: 356 | checkpoint = None 357 | if training_args.resume_from_checkpoint is not None: 358 | checkpoint = training_args.resume_from_checkpoint 359 | # elif last_checkpoint is not None: 360 | # checkpoint = last_checkpoint 361 | model.gradient_checkpointing_enable() 362 | model.enable_input_require_grads() 363 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 364 | # trainer.save_model() # Saves the tokenizer too for easy upload 365 | 366 | metrics = train_result.metrics 367 | max_train_samples = ( 368 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 369 | ) 370 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 371 | 372 | trainer.log_metrics("train", metrics) 373 | trainer.save_metrics("train", metrics) 374 | trainer.save_state() 375 | 376 | # Evaluation 377 | results = {} 378 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1 379 | if training_args.do_eval: 380 | logger.info("*** Evaluate ***") 381 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95) 382 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 383 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 384 | 385 | trainer.log_metrics("eval", metrics) 386 | trainer.save_metrics("eval", metrics) 387 | 388 | if training_args.do_predict: 389 | logger.info("*** Predict ***") 390 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95) 391 | metrics = 
predict_results.metrics 392 | max_predict_samples = ( 393 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 394 | ) 395 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 396 | 397 | trainer.log_metrics("predict", metrics) 398 | trainer.save_metrics("predict", metrics) 399 | 400 | if trainer.is_world_process_zero(): 401 | if training_args.predict_with_generate: 402 | predictions = tokenizer.batch_decode( 403 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 404 | ) 405 | predictions = [pred.strip() for pred in predictions] 406 | labels = tokenizer.batch_decode( 407 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True 408 | ) 409 | labels = [label.strip() for label in labels] 410 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 411 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 412 | for p, l in zip(predictions, labels): 413 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False) 414 | writer.write(f"{res}\n") 415 | return results 416 | 417 | 418 | def _mp_fn(index): 419 | # For xla_spawn (TPUs) 420 | main() 421 | 422 | 423 | if __name__ == "__main__": 424 | main() 425 | -------------------------------------------------------------------------------- /chatglm_model_v1/run_ptuning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | import json 25 | 26 | import numpy as np 27 | from datasets import load_dataset 28 | import jieba 29 | from rouge_chinese import Rouge 30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 31 | import torch 32 | 33 | import transformers 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModel, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | Seq2SeqTrainingArguments, 41 | set_seed, 42 | ) 43 | from trainer_seq2seq import Seq2SeqTrainer 44 | 45 | from arguments import ModelArguments, DataTrainingArguments 46 | 47 | logger = logging.getLogger(__name__) 48 | 49 | def main(): 50 | 51 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 52 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 53 | # If we pass only one argument to the script and it's the path to a json file, 54 | # let's parse it to get our arguments. 
55 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 56 | else: 57 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 58 | 59 | # Setup logging 60 | logging.basicConfig( 61 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 62 | datefmt="%m/%d/%Y %H:%M:%S", 63 | handlers=[logging.StreamHandler(sys.stdout)], 64 | ) 65 | 66 | if training_args.should_log: 67 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 68 | transformers.utils.logging.set_verbosity_info() 69 | 70 | log_level = training_args.get_process_log_level() 71 | logger.setLevel(log_level) 72 | # datasets.utils.logging.set_verbosity(log_level) 73 | transformers.utils.logging.set_verbosity(log_level) 74 | transformers.utils.logging.enable_default_handler() 75 | transformers.utils.logging.enable_explicit_format() 76 | 77 | # Log on each process the small summary: 78 | logger.warning( 79 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 80 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 81 | ) 82 | logger.info(f"Training/evaluation parameters {training_args}") 83 | 84 | # Set seed before initializing model.
85 | set_seed(training_args.seed) 86 | 87 | # Load dataset 88 | data_files = {} 89 | if data_args.train_file is not None: 90 | data_files["train"] = data_args.train_file 91 | extension = data_args.train_file.split(".")[-1] 92 | if data_args.validation_file is not None: 93 | data_files["validation"] = data_args.validation_file 94 | extension = data_args.validation_file.split(".")[-1] 95 | if data_args.test_file is not None: 96 | data_files["test"] = data_args.test_file 97 | extension = data_args.test_file.split(".")[-1] 98 | 99 | raw_datasets = load_dataset( 100 | extension, 101 | data_files=data_files, 102 | cache_dir=model_args.cache_dir, 103 | use_auth_token=True if model_args.use_auth_token else None, 104 | ) 105 | 106 | # Load pretrained model and tokenizer 107 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 108 | config.pre_seq_len = model_args.pre_seq_len 109 | config.prefix_projection = model_args.prefix_projection 110 | 111 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 112 | 113 | if model_args.ptuning_checkpoint is not None: 114 | # Evaluation 115 | # Loading extra state dict of prefix encoder 116 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 117 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) 118 | new_prefix_state_dict = {} 119 | for k, v in prefix_state_dict.items(): 120 | if k.startswith("transformer.prefix_encoder."): 121 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v 122 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) 123 | else: 124 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 125 | 126 | if model_args.quantization_bit is not None: 127 | print(f"Quantized to {model_args.quantization_bit} bit") 128 | model = 
model.quantize(model_args.quantization_bit) 129 | if model_args.pre_seq_len is not None: 130 | # P-tuning v2 131 | model = model.half() 132 | model.transformer.prefix_encoder.float() 133 | else: 134 | # Finetune 135 | model = model.float() 136 | 137 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 138 | 139 | # Preprocessing the datasets. 140 | # We need to tokenize inputs and targets. 141 | if training_args.do_train: 142 | column_names = raw_datasets["train"].column_names 143 | elif training_args.do_eval: 144 | column_names = raw_datasets["validation"].column_names 145 | elif training_args.do_predict: 146 | column_names = raw_datasets["test"].column_names 147 | else: 148 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 149 | return 150 | 151 | # Get the column names for input/target. 152 | prompt_column = data_args.prompt_column 153 | response_column = data_args.response_column 154 | history_column = data_args.history_column 155 | 156 | # Temporarily set max_target_length for training. 
157 | max_target_length = data_args.max_target_length 158 | 159 | def preprocess_function_eval(examples): 160 | inputs, targets = [], [] 161 | for i in range(len(examples[prompt_column])): 162 | if examples[prompt_column][i] and examples[response_column][i]: 163 | query = examples[prompt_column][i] 164 | if history_column is None or len(examples[history_column][i]) == 0: 165 | prompt = query 166 | else: 167 | prompt = "" 168 | history = examples[history_column][i] 169 | for turn_idx, (old_query, response) in enumerate(history): 170 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) 171 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 172 | inputs.append(prompt) 173 | targets.append(examples[response_column][i]) 174 | 175 | inputs = [prefix + inp for inp in inputs] 176 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) 177 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) 178 | 179 | if data_args.ignore_pad_token_for_loss: 180 | labels["input_ids"] = [ 181 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 182 | ] 183 | model_inputs["labels"] = labels["input_ids"] 184 | 185 | return model_inputs 186 | 187 | def preprocess_function_train(examples): 188 | max_seq_length = data_args.max_source_length + data_args.max_target_length 189 | 190 | model_inputs = { 191 | "input_ids": [], 192 | "labels": [], 193 | } 194 | for i in range(len(examples[prompt_column])): 195 | if examples[prompt_column][i] and examples[response_column][i]: 196 | query, answer = examples[prompt_column][i], examples[response_column][i] 197 | 198 | if history_column is None: 199 | prompt = query 200 | else: 201 | prompt = "" 202 | history = examples[history_column][i] 203 | for turn_idx, (old_query, response) in enumerate(history): 204 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) 205 | prompt 
+= "[Round {}]\n问:{}\n答:".format(len(history), query) 206 | 207 | prompt = prefix + prompt 208 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 209 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False) 210 | 211 | if len(a_ids) > data_args.max_source_length - 1: 212 | a_ids = a_ids[: data_args.max_source_length - 1] 213 | 214 | if len(b_ids) > data_args.max_target_length - 2: 215 | b_ids = b_ids[: data_args.max_target_length - 2] 216 | 217 | input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) 218 | 219 | context_length = input_ids.index(tokenizer.bos_token_id) 220 | mask_position = context_length - 1 221 | labels = [-100] * context_length + input_ids[mask_position+1:] 222 | 223 | pad_len = max_seq_length - len(input_ids) 224 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len 225 | labels = labels + [tokenizer.pad_token_id] * pad_len 226 | if data_args.ignore_pad_token_for_loss: 227 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] 228 | 229 | model_inputs["input_ids"].append(input_ids) 230 | model_inputs["labels"].append(labels) 231 | 232 | return model_inputs 233 | 234 | def print_dataset_example(example): 235 | print("input_ids",example["input_ids"]) 236 | print("inputs", tokenizer.decode(example["input_ids"])) 237 | print("label_ids", example["labels"]) 238 | print("labels", tokenizer.decode(example["labels"])) 239 | 240 | base_cache_dir = os.path.join(data_args.base_cache_dir, data_args.task_name) 241 | if training_args.local_rank <= 0 and not os.path.exists(base_cache_dir): 242 | os.makedirs(base_cache_dir) 243 | 244 | if training_args.do_train: 245 | if "train" not in raw_datasets: 246 | raise ValueError("--do_train requires a train dataset") 247 | train_dataset = raw_datasets["train"] 248 | if data_args.max_train_samples is not None: 249 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 250 | train_dataset = train_dataset.select(range(max_train_samples)) 
251 | with training_args.main_process_first(desc="train dataset map pre-processing"): 252 | train_dataset = train_dataset.map( 253 | preprocess_function_train, 254 | batched=True, 255 | num_proc=data_args.preprocessing_num_workers, 256 | remove_columns=column_names, 257 | load_from_cache_file=not data_args.overwrite_cache, 258 | desc="Running tokenizer on train dataset", 259 | cache_file_name=os.path.join(base_cache_dir, "train.arrow") 260 | ) 261 | print_dataset_example(train_dataset[0]) 262 | 263 | if training_args.do_eval: 264 | max_target_length = data_args.val_max_target_length 265 | if "validation" not in raw_datasets: 266 | raise ValueError("--do_eval requires a validation dataset") 267 | eval_dataset = raw_datasets["validation"] 268 | if data_args.max_eval_samples is not None: 269 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 270 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 271 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 272 | eval_dataset = eval_dataset.map( 273 | preprocess_function_eval, 274 | batched=True, 275 | num_proc=data_args.preprocessing_num_workers, 276 | remove_columns=column_names, 277 | load_from_cache_file=not data_args.overwrite_cache, 278 | desc="Running tokenizer on validation dataset", 279 | cache_file_name=os.path.join(base_cache_dir, "eval.arrow") 280 | ) 281 | print_dataset_example(eval_dataset[0]) 282 | 283 | if training_args.do_predict: 284 | max_target_length = data_args.val_max_target_length 285 | if "test" not in raw_datasets: 286 | raise ValueError("--do_predict requires a test dataset") 287 | predict_dataset = raw_datasets["test"] 288 | if data_args.max_predict_samples is not None: 289 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 290 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 291 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 292 | 
predict_dataset = predict_dataset.map( 293 | preprocess_function_eval, 294 | batched=True, 295 | num_proc=data_args.preprocessing_num_workers, 296 | remove_columns=column_names, 297 | load_from_cache_file=not data_args.overwrite_cache, 298 | desc="Running tokenizer on prediction dataset", 299 | cache_file_name=os.path.join(base_cache_dir, "predict.arrow") 300 | ) 301 | print_dataset_example(predict_dataset[0]) 302 | 303 | # Data collator 304 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 305 | data_collator = DataCollatorForSeq2Seq( 306 | tokenizer, 307 | model=model, 308 | label_pad_token_id=label_pad_token_id, 309 | pad_to_multiple_of=None, 310 | padding=False 311 | ) 312 | 313 | # Metric 314 | def compute_metrics(eval_preds): 315 | preds, labels = eval_preds 316 | if isinstance(preds, tuple): 317 | preds = preds[0] 318 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 319 | if data_args.ignore_pad_token_for_loss: 320 | # Replace -100 in the labels as we can't decode them. 
321 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 322 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 323 | 324 | score_dict = { 325 | "rouge-1": [], 326 | "rouge-2": [], 327 | "rouge-l": [], 328 | "bleu-4": [] 329 | } 330 | for pred, label in zip(decoded_preds, decoded_labels): 331 | hypothesis = list(jieba.cut(pred)) 332 | reference = list(jieba.cut(label)) 333 | rouge = Rouge() 334 | scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) 335 | result = scores[0] 336 | 337 | for k, v in result.items(): 338 | score_dict[k].append(round(v["f"] * 100, 4)) 339 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 340 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 341 | 342 | for k, v in score_dict.items(): 343 | score_dict[k] = float(np.mean(v)) 344 | return score_dict 345 | 346 | # Override the decoding parameters of Seq2SeqTrainer 347 | training_args.generation_max_length = ( 348 | training_args.generation_max_length 349 | if training_args.generation_max_length is not None 350 | else data_args.val_max_target_length 351 | ) 352 | training_args.generation_num_beams = ( 353 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 354 | ) 355 | # Initialize our Trainer 356 | trainer = Seq2SeqTrainer( 357 | model=model, 358 | args=training_args, 359 | train_dataset=train_dataset if training_args.do_train else None, 360 | eval_dataset=eval_dataset if training_args.do_eval else None, 361 | tokenizer=tokenizer, 362 | data_collator=data_collator, 363 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 364 | save_prefixencoder=model_args.pre_seq_len is not None 365 | ) 366 | 367 | # Training 368 | if training_args.do_train: 369 | checkpoint = None 370 | if training_args.resume_from_checkpoint is not None: 371 | checkpoint = training_args.resume_from_checkpoint 372 | # elif 
last_checkpoint is not None: 373 | # checkpoint = last_checkpoint 374 | model.gradient_checkpointing_enable() 375 | model.enable_input_require_grads() 376 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 377 | # trainer.save_model() # Saves the tokenizer too for easy upload 378 | 379 | metrics = train_result.metrics 380 | max_train_samples = ( 381 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 382 | ) 383 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 384 | 385 | trainer.log_metrics("train", metrics) 386 | trainer.save_metrics("train", metrics) 387 | trainer.save_state() 388 | 389 | # Evaluation 390 | results = {} 391 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1 392 | if training_args.do_eval: 393 | logger.info("*** Evaluate ***") 394 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95) 395 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 396 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 397 | 398 | trainer.log_metrics("eval", metrics) 399 | trainer.save_metrics("eval", metrics) 400 | 401 | if training_args.do_predict: 402 | logger.info("*** Predict ***") 403 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95) 404 | metrics = predict_results.metrics 405 | max_predict_samples = ( 406 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 407 | ) 408 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 409 | 410 | trainer.log_metrics("predict", metrics) 411 | trainer.save_metrics("predict", metrics) 412 | 413 | if trainer.is_world_process_zero(): 414 | if training_args.predict_with_generate: 415 | predictions = 
tokenizer.batch_decode( 416 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 417 | ) 418 | predictions = [pred.strip() for pred in predictions] 419 | labels = tokenizer.batch_decode( 420 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True 421 | ) 422 | labels = [label.strip() for label in labels] 423 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 424 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 425 | for p, l in zip(predictions, labels): 426 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False) 427 | writer.write(f"{res}\n") 428 | return results 429 | 430 | 431 | def _mp_fn(index): 432 | # For xla_spawn (TPUs) 433 | main() 434 | 435 | 436 | if __name__ == "__main__": 437 | main() 438 | -------------------------------------------------------------------------------- /chatglm_model_v2/run_peft.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
20 | 21 | import logging 22 | import os 23 | import sys 24 | import json 25 | 26 | import numpy as np 27 | from datasets import load_dataset 28 | import jieba 29 | from rouge_chinese import Rouge 30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 31 | import torch 32 | 33 | import transformers 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModel, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | Seq2SeqTrainingArguments, 41 | set_seed, 42 | ) 43 | 44 | # PEFT configuration classes 45 | from peft import ( 46 | LoraConfig, 47 | PrefixTuningConfig, 48 | PromptEncoderConfig, 49 | PromptEncoderReparameterizationType, 50 | PromptTuningConfig, 51 | PromptTuningInit, 52 | TaskType, 53 | get_peft_model, 54 | ) 55 | 56 | from trainer_seq2seq import Seq2SeqTrainer 57 | 58 | from arguments import ModelArguments, DataTrainingArguments, PeftArguments 59 | 60 | logger = logging.getLogger(__name__) 61 | 62 | def main(): 63 | # Load the model, data, training, and PEFT arguments. 64 | # They come from a single JSON file or from command-line flags. 65 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, PeftArguments)) 66 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 67 | # If we pass only one argument to the script and it's the path to a json file, 68 | # let's parse it to get our arguments. 69 | model_args, data_args, training_args, peft_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 70 | else: 71 | model_args, data_args, training_args, peft_args = parser.parse_args_into_dataclasses() 72 | 73 | # Setup logging 74 | logging.basicConfig( 75 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 76 | datefmt="%m/%d/%Y %H:%M:%S", 77 | handlers=[logging.StreamHandler(sys.stdout)], 78 | ) 79 | 80 | if training_args.should_log: 81 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
82 | transformers.utils.logging.set_verbosity_info() 83 | 84 | log_level = training_args.get_process_log_level() 85 | logger.setLevel(log_level) 86 | # datasets.utils.logging.set_verbosity(log_level) 87 | transformers.utils.logging.set_verbosity(log_level) 88 | transformers.utils.logging.enable_default_handler() 89 | transformers.utils.logging.enable_explicit_format() 90 | 91 | # Log on each process the small summary: 92 | logger.warning( 93 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 94 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 95 | ) 96 | logger.info(f"Training/evaluation parameters {training_args}") 97 | 98 | # Set seed before initializing model. 99 | set_seed(training_args.seed) 100 | 101 | # Load dataset 102 | data_files = {} 103 | if data_args.train_file is not None: 104 | data_files["train"] = data_args.train_file 105 | extension = data_args.train_file.split(".")[-1] 106 | if data_args.validation_file is not None: 107 | data_files["validation"] = data_args.validation_file 108 | extension = data_args.validation_file.split(".")[-1] 109 | if data_args.test_file is not None: 110 | data_files["test"] = data_args.test_file 111 | extension = data_args.test_file.split(".")[-1] 112 | 113 | # Load the data files as a Hugging Face dataset 114 | raw_datasets = load_dataset( 115 | extension, 116 | data_files=data_files, 117 | cache_dir=model_args.cache_dir, 118 | use_auth_token=True if model_args.use_auth_token else None, 119 | ) 120 | 121 | # Load pretrained model and tokenizer 122 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 123 | config.pre_seq_len = model_args.pre_seq_len 124 | config.prefix_projection = model_args.prefix_projection 125 | 126 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 127 | 128 | model = AutoModel.from_pretrained(model_args.model_name_or_path,
config=config, trust_remote_code=True) 129 | 130 | # Apply PEFT 131 | if peft_args.peft_type is not None: 132 | if peft_args.peft_type == "lora": 133 | peft_config = LoraConfig( 134 | task_type=TaskType.CAUSAL_LM, 135 | inference_mode=False, 136 | r=peft_args.lora_dim, 137 | lora_alpha=32, 138 | lora_dropout=0.1 139 | ) 140 | elif peft_args.peft_type == "ptuning": 141 | peft_config = PromptEncoderConfig( 142 | task_type=TaskType.CAUSAL_LM, 143 | inference_mode=False, 144 | encoder_reparameterization_type=PromptEncoderReparameterizationType.MLP, # use an MLP to reparameterize the prompt by default 145 | encoder_num_layers=2, 146 | encoder_dropout=0.1, 147 | num_virtual_tokens=8 # number of soft prompt tokens 148 | ) 149 | elif peft_args.peft_type == "prefix": 150 | peft_config = PrefixTuningConfig( 151 | task_type=TaskType.CAUSAL_LM, 152 | inference_mode=False, 153 | num_virtual_tokens=4, 154 | num_attention_heads=12, 155 | num_layers=48, 156 | encoder_hidden_size=768, 157 | token_dim=1536, 158 | ) 159 | elif peft_args.peft_type == "prompt": 160 | peft_config = PromptTuningConfig( 161 | task_type=TaskType.CAUSAL_LM, 162 | inference_mode=False, 163 | prompt_tuning_init=PromptTuningInit.TEXT if peft_args.prompt_tuning_initial_text is not None else PromptTuningInit.RANDOM 164 | ) 165 | elif peft_args.peft_type == "adalora": 166 | raise NotImplementedError("AdaLoRA support is still under development") 167 | else: 168 | raise NotImplementedError("peft_type must be one of the supported parameter-efficient tuning methods: lora, ptuning, prefix, prompt") 169 | 170 | logger.info("Using {} as the PEFT type; wrapping the model ...".format(peft_args.peft_type)) 171 | model = get_peft_model(model, peft_config=peft_config) 172 | model.print_trainable_parameters() # prints the trainable vs. total parameter counts to stdout (returns None) 173 | 174 | if model_args.quantization_bit is not None: 175 | print(f"Quantized to {model_args.quantization_bit} bit") 176 | model = model.quantize(model_args.quantization_bit) 177 | 178 | 179 | # if model_args.pre_seq_len is not None: 180 | # # P-tuning v2 181 | #
model = model.half() # enable half-precision mode 182 | # model.transformer.prefix_encoder.float() 183 | # else: 184 | # # Finetune 185 | # model = model.float() 186 | model = model.float() 187 | 188 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 189 | 190 | # Preprocessing the datasets. 191 | # We need to tokenize inputs and targets. 192 | if training_args.do_train: 193 | column_names = raw_datasets["train"].column_names 194 | elif training_args.do_eval: 195 | column_names = raw_datasets["validation"].column_names 196 | elif training_args.do_predict: 197 | column_names = raw_datasets["test"].column_names 198 | else: 199 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 200 | return 201 | 202 | # Get the column names for input/target. 203 | prompt_column = data_args.prompt_column 204 | response_column = data_args.response_column 205 | history_column = data_args.history_column 206 | 207 | # Temporarily set max_target_length for training.
208 | max_target_length = data_args.max_target_length 209 | 210 | def preprocess_function_eval(examples): 211 | inputs, targets = [], [] 212 | for i in range(len(examples[prompt_column])): 213 | if examples[prompt_column][i] and examples[response_column][i]: 214 | query = examples[prompt_column][i] 215 | history = examples[history_column][i] if history_column is not None else None 216 | prompt = tokenizer.build_prompt(query, history) 217 | inputs.append(prompt) 218 | targets.append(examples[response_column][i]) 219 | 220 | inputs = [prefix + inp for inp in inputs] 221 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) 222 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) 223 | 224 | if data_args.ignore_pad_token_for_loss: 225 | labels["input_ids"] = [ 226 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 227 | ] 228 | model_inputs["labels"] = labels["input_ids"] 229 | 230 | return model_inputs 231 | 232 | def preprocess_function_train(examples): 233 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1 234 | 235 | model_inputs = { 236 | "input_ids": [], 237 | "labels": [], 238 | } 239 | for i in range(len(examples[prompt_column])): 240 | if examples[prompt_column][i] and examples[response_column][i]: 241 | query, answer = examples[prompt_column][i], examples[response_column][i] 242 | 243 | history = examples[history_column][i] if history_column is not None else None 244 | prompt = tokenizer.build_prompt(query, history) 245 | 246 | prompt = prefix + prompt 247 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True, 248 | max_length=data_args.max_source_length) 249 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True, 250 | max_length=data_args.max_target_length) 251 | 252 | context_length = len(a_ids) 253 | input_ids = a_ids + b_ids + 
[tokenizer.eos_token_id] 254 | labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id] 255 | 256 | pad_len = max_seq_length - len(input_ids) 257 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len 258 | labels = labels + [tokenizer.pad_token_id] * pad_len 259 | if data_args.ignore_pad_token_for_loss: 260 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] 261 | 262 | model_inputs["input_ids"].append(input_ids) 263 | model_inputs["labels"].append(labels) 264 | 265 | return model_inputs 266 | 267 | def print_dataset_example(example): 268 | print("input_ids", example["input_ids"]) 269 | print("inputs", tokenizer.decode(example["input_ids"])) 270 | print("label_ids", example["labels"]) 271 | print("labels", tokenizer.decode(example["labels"])) 272 | 273 | base_cache_dir = os.path.join(data_args.base_cache_dir, data_args.task_name) 274 | if training_args.local_rank <= 0 and not os.path.exists(base_cache_dir): 275 | os.makedirs(base_cache_dir) 276 | 277 | if training_args.do_train: 278 | if "train" not in raw_datasets: 279 | raise ValueError("--do_train requires a train dataset") 280 | train_dataset = raw_datasets["train"] 281 | if data_args.max_train_samples is not None: 282 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 283 | train_dataset = train_dataset.select(range(max_train_samples)) 284 | with training_args.main_process_first(desc="train dataset map pre-processing"): 285 | train_dataset = train_dataset.map( 286 | preprocess_function_train, 287 | batched=True, 288 | num_proc=data_args.preprocessing_num_workers, 289 | remove_columns=column_names, 290 | load_from_cache_file=not data_args.overwrite_cache, 291 | desc="Running tokenizer on train dataset", 292 | cache_file_name=os.path.join(base_cache_dir, "train.arrow") 293 | ) 294 | print_dataset_example(train_dataset[0]) 295 | 296 | if training_args.do_eval: 297 | max_target_length = data_args.val_max_target_length 298 | if
"validation" not in raw_datasets: 299 | raise ValueError("--do_eval requires a validation dataset") 300 | eval_dataset = raw_datasets["validation"] 301 | if data_args.max_eval_samples is not None: 302 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 303 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 304 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 305 | eval_dataset = eval_dataset.map( 306 | preprocess_function_eval, 307 | batched=True, 308 | num_proc=data_args.preprocessing_num_workers, 309 | remove_columns=column_names, 310 | load_from_cache_file=not data_args.overwrite_cache, 311 | desc="Running tokenizer on validation dataset", 312 | cache_file_name=os.path.join(base_cache_dir, "eval.arrow") 313 | ) 314 | print_dataset_example(eval_dataset[0]) 315 | 316 | if training_args.do_predict: 317 | max_target_length = data_args.val_max_target_length 318 | if "test" not in raw_datasets: 319 | raise ValueError("--do_predict requires a test dataset") 320 | predict_dataset = raw_datasets["test"] 321 | if data_args.max_predict_samples is not None: 322 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 323 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 324 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 325 | predict_dataset = predict_dataset.map( 326 | preprocess_function_eval, 327 | batched=True, 328 | num_proc=data_args.preprocessing_num_workers, 329 | remove_columns=column_names, 330 | load_from_cache_file=not data_args.overwrite_cache, 331 | desc="Running tokenizer on prediction dataset", 332 | cache_file_name=os.path.join(base_cache_dir, "predict.arrow") 333 | ) 334 | print_dataset_example(predict_dataset[0]) 335 | 336 | # Data collator 337 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 338 | data_collator = DataCollatorForSeq2Seq( 339 | tokenizer, 340 
| model=model, 341 | label_pad_token_id=label_pad_token_id, 342 | pad_to_multiple_of=None, 343 | padding=False 344 | ) 345 | 346 | # Metric 347 | def compute_metrics(eval_preds): 348 | preds, labels = eval_preds 349 | if isinstance(preds, tuple): 350 | preds = preds[0] 351 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 352 | if data_args.ignore_pad_token_for_loss: 353 | # Replace -100 in the labels as we can't decode them. 354 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 355 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 356 | 357 | score_dict = { 358 | "rouge-1": [], 359 | "rouge-2": [], 360 | "rouge-l": [], 361 | "bleu-4": [] 362 | } 363 | for pred, label in zip(decoded_preds, decoded_labels): 364 | hypothesis = list(jieba.cut(pred)) 365 | reference = list(jieba.cut(label)) 366 | rouge = Rouge() 367 | scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) 368 | result = scores[0] 369 | 370 | for k, v in result.items(): 371 | score_dict[k].append(round(v["f"] * 100, 4)) 372 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 373 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 374 | 375 | for k, v in score_dict.items(): 376 | score_dict[k] = float(np.mean(v)) 377 | return score_dict 378 | 379 | # Override the decoding parameters of Seq2SeqTrainer 380 | training_args.generation_max_length = ( 381 | training_args.generation_max_length 382 | if training_args.generation_max_length is not None 383 | else data_args.val_max_target_length 384 | ) 385 | training_args.generation_num_beams = ( 386 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 387 | ) 388 | # Initialize our Trainer 389 | trainer = Seq2SeqTrainer( 390 | model=model, 391 | args=training_args, 392 | train_dataset=train_dataset if training_args.do_train else None, 393 | eval_dataset=eval_dataset if 
training_args.do_eval else None, 394 | tokenizer=tokenizer, 395 | data_collator=data_collator, 396 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 397 | save_changed=model_args.pre_seq_len is not None 398 | ) 399 | 400 | # Training 401 | if training_args.do_train: 402 | checkpoint = None 403 | if training_args.resume_from_checkpoint is not None: 404 | checkpoint = training_args.resume_from_checkpoint 405 | # elif last_checkpoint is not None: 406 | # checkpoint = last_checkpoint 407 | model.gradient_checkpointing_enable() 408 | model.enable_input_require_grads() 409 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 410 | # trainer.save_model() # Saves the tokenizer too for easy upload 411 | 412 | metrics = train_result.metrics 413 | max_train_samples = ( 414 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 415 | ) 416 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 417 | 418 | trainer.log_metrics("train", metrics) 419 | trainer.save_metrics("train", metrics) 420 | trainer.save_state() 421 | 422 | # Evaluation 423 | results = {} 424 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1 425 | if training_args.do_eval: 426 | logger.info("*** Evaluate ***") 427 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95) 428 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 429 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 430 | 431 | trainer.log_metrics("eval", metrics) 432 | trainer.save_metrics("eval", metrics) 433 | 434 | if training_args.do_predict: 435 | logger.info("*** Predict ***") 436 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95) 437 | metrics = 
predict_results.metrics 438 | max_predict_samples = ( 439 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 440 | ) 441 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 442 | 443 | trainer.log_metrics("predict", metrics) 444 | trainer.save_metrics("predict", metrics) 445 | 446 | if trainer.is_world_process_zero(): 447 | if training_args.predict_with_generate: 448 | predictions = tokenizer.batch_decode( 449 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 450 | ) 451 | predictions = [pred.strip() for pred in predictions] 452 | labels = tokenizer.batch_decode( 453 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True 454 | ) 455 | labels = [label.strip() for label in labels] 456 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 457 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 458 | for p, l in zip(predictions, labels): 459 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False) 460 | writer.write(f"{res}\n") 461 | return results 462 | 463 | 464 | def _mp_fn(index): 465 | # For xla_spawn (TPUs) 466 | main() 467 | 468 | 469 | if __name__ == "__main__": 470 | main() 471 | --------------------------------------------------------------------------------