├── .DS_Store
├── resources
│   ├── math.png
│   ├── cli-demo.png
│   ├── web-demo.gif
│   ├── web-demo.png
│   ├── wechat.jpg
│   ├── knowledge.png
│   ├── long-context.png
│   └── WECHAT.md
├── requirements.txt
├── scripts
│   ├── web_demo.sh
│   ├── train.sh
│   ├── train_chat.sh
│   ├── train_ptuning.sh
│   ├── train_lora.sh
│   ├── ds_train_ptuning.sh
│   └── ds_train_peft.sh
├── chatglm_model_v1
│   ├── web_demo.sh
│   ├── deepspeed.json
│   ├── evaluate_finetune.sh
│   ├── evaluate.sh
│   ├── train.sh
│   ├── train_chat.sh
│   ├── ds_train_finetune.sh
│   ├── web_demo.py
│   ├── README.md
│   ├── README_en.md
│   ├── arguments.py
│   ├── trainer_seq2seq.py
│   └── run_ptuning.py
├── evaluation
│   ├── README.md
│   └── evaluate_ceval.py
├── deepspeed
│   └── deepspeed.json
├── chatglm_model_v2
│   ├── evaluate_finetune.sh
│   ├── evaluate.sh
│   ├── trainer.py
│   ├── README.md
│   ├── web_demo.py
│   ├── arguments.py
│   ├── trainer_seq2seq.py
│   ├── run_ptuning.py
│   └── run_peft.py
├── FAQ.md
├── application
│   ├── cli_demo.py
│   ├── api.py
│   ├── web_demo2.py
│   ├── web_demo.py
│   └── openai_api.py
├── utils.py
├── MODEL_LICENSE
├── README.md
└── LICENSE
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/.DS_Store
--------------------------------------------------------------------------------
/resources/math.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/math.png
--------------------------------------------------------------------------------
/resources/cli-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/cli-demo.png
--------------------------------------------------------------------------------
/resources/web-demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/web-demo.gif
--------------------------------------------------------------------------------
/resources/web-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/web-demo.png
--------------------------------------------------------------------------------
/resources/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/wechat.jpg
--------------------------------------------------------------------------------
/resources/knowledge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/knowledge.png
--------------------------------------------------------------------------------
/resources/long-context.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjn1996/ChatGLM2-Tuning/HEAD/resources/long-context.png
--------------------------------------------------------------------------------
/resources/WECHAT.md:
--------------------------------------------------------------------------------
1 | [QR code image: resources/wechat.jpg]
2 |
3 | Scan the QR code to follow the official account and join the "ChatGLM Discussion Group"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | protobuf
2 | transformers==4.30.2
3 | cpm_kernels
4 | torch>=2.0
5 | gradio
6 | mdtex2html
7 | sentencepiece
8 | accelerate
9 | sse-starlette
10 | streamlit>=1.24.0
11 | rouge_chinese
12 | nltk
13 | jieba
14 | datasets
15 | peft
--------------------------------------------------------------------------------
/scripts/web_demo.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 |
3 | CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
4 | --model_name_or_path THUDM/chatglm2-6b \
5 | --ptuning_checkpoint output/adgen-chatglm2-6b-pt-128-2e-2/checkpoint-3000 \
6 | --pre_seq_len $PRE_SEQ_LEN
7 |
8 |
--------------------------------------------------------------------------------
/chatglm_model_v1/web_demo.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 |
3 | CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
4 | --model_name_or_path THUDM/chatglm-6b \
5 | --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
6 | --pre_seq_len $PRE_SEQ_LEN
7 |
8 |
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 | First, download the processed C-Eval dataset from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/e84444333b6d434ea7b0), extract it into the `evaluation` directory, and then run
2 |
3 | ```shell
4 | cd evaluation
5 | python evaluate_ceval.py
6 | ```
7 |
8 | This script runs prediction on the C-Eval validation set and prints the accuracy. To get results on the test set, change `./CEval/val/**/*.jsonl` in the code to `./CEval/test/**/*.jsonl`, save the results in the format required by C-Eval, and submit them on the [official website](https://cevalbenchmark.com/). A sketch of such a conversion is given below.
9 |
10 | The reported results were obtained with an internal parallel evaluation framework, so the numbers may fluctuate slightly.
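11 |
12 | Below is a minimal sketch (not part of the evaluation script above) of how per-subject predictions might be collected into a single JSON file for submission. The example path, the way the subject name is derived from the file name, and the exact submission schema are assumptions; check them against the official C-Eval instructions before submitting.
13 |
14 | ```python
15 | import json
16 | import os
17 | from collections import defaultdict
18 |
19 | # Hypothetical input: a mapping from each test file to its predicted answer
20 | # letters ("A"/"B"/"C"/"D"), in the same order as the questions in that file.
21 | predictions = {"./CEval/test/Accountant.jsonl": ["A", "C", "B"]}
22 |
23 | submission = defaultdict(dict)
24 | for path, answers in predictions.items():
25 |     subject = os.path.splitext(os.path.basename(path))[0]  # adjust to the official subject naming
26 |     for idx, answer in enumerate(answers):
27 |         submission[subject][str(idx)] = answer
28 |
29 | with open("ceval_submission.json", "w", encoding="utf-8") as f:
30 |     json.dump(submission, f, ensure_ascii=False, indent=2)
31 | ```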
--------------------------------------------------------------------------------
/deepspeed/deepspeed.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": "auto",
3 | "zero_allow_untested_optimizer": true,
4 | "fp16": {
5 | "enabled": "auto",
6 | "loss_scale": 0,
7 | "initial_scale_power": 16,
8 | "loss_scale_window": 1000,
9 | "hysteresis": 2,
10 | "min_loss_scale": 1
11 | },
12 | "zero_optimization": {
13 | "stage": 2,
14 | "allgather_partitions": true,
15 | "allgather_bucket_size": 5e8,
16 | "overlap_comm": false,
17 | "reduce_scatter": true,
18 | "reduce_bucket_size": 5e8,
19 | "contiguous_gradients" : true
20 | }
21 | }
--------------------------------------------------------------------------------
/chatglm_model_v1/deepspeed.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": "auto",
3 | "zero_allow_untested_optimizer": true,
4 | "fp16": {
5 | "enabled": "auto",
6 | "loss_scale": 0,
7 | "initial_scale_power": 16,
8 | "loss_scale_window": 1000,
9 | "hysteresis": 2,
10 | "min_loss_scale": 1
11 | },
12 | "zero_optimization": {
13 | "stage": 2,
14 | "allgather_partitions": true,
15 | "allgather_bucket_size": 5e8,
16 | "overlap_comm": false,
17 | "reduce_scatter": true,
18 | "reduce_bucket_size": 5e8,
19 | "contiguous_gradients" : true
20 | }
21 | }
--------------------------------------------------------------------------------
/chatglm_model_v1/evaluate_finetune.sh:
--------------------------------------------------------------------------------
1 | CHECKPOINT=adgen-chatglm-6b-ft-1e-4
2 | STEP=3000
3 |
4 | CUDA_VISIBLE_DEVICES=0 python3 main.py \
5 | --do_predict \
6 | --validation_file AdvertiseGen/dev.json \
7 | --test_file AdvertiseGen/dev.json \
8 | --overwrite_cache \
9 | --prompt_column content \
10 | --response_column summary \
11 | --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \
12 | --output_dir ./output/$CHECKPOINT \
13 | --overwrite_output_dir \
14 | --max_source_length 256 \
15 | --max_target_length 256 \
16 | --per_device_eval_batch_size 1 \
17 | --predict_with_generate \
18 | --fp16_full_eval
19 |
--------------------------------------------------------------------------------
/chatglm_model_v2/evaluate_finetune.sh:
--------------------------------------------------------------------------------
1 | CHECKPOINT=adgen-chatglm2-6b-ft-1e-4
2 | STEP=3000
3 | NUM_GPUS=1
4 |
5 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
6 | --do_predict \
7 | --validation_file AdvertiseGen/dev.json \
8 | --test_file AdvertiseGen/dev.json \
9 | --overwrite_cache \
10 | --prompt_column content \
11 | --response_column summary \
12 | --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \
13 | --output_dir ./output/$CHECKPOINT \
14 | --overwrite_output_dir \
15 | --max_source_length 256 \
16 | --max_target_length 256 \
17 | --per_device_eval_batch_size 1 \
18 | --predict_with_generate \
19 | --fp16_full_eval
20 |
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | ## Q1
2 |
3 | **Loading a quantized model directly on a Mac fails with `clang: error: unsupported option '-fopenmp'`**
4 |
5 | This happens because macOS itself lacks OpenMP; in this state the model still runs, but only on a single core. Install the OpenMP dependency separately to enable OMP on the Mac:
6 |
7 | ```bash
8 | # See `https://mac.r-project.org/openmp/`
9 | ## Assumes gcc (clang) is version 14.x; for other versions see the table provided by the R-Project
10 | curl -O https://mac.r-project.org/openmp/openmp-14.0.6-darwin20-Release.tar.gz
11 | sudo tar fvxz openmp-14.0.6-darwin20-Release.tar.gz -C /
12 | ```
13 | This installs the following files: `/usr/local/lib/libomp.dylib`, `/usr/local/include/ompt.h`, `/usr/local/include/omp.h`, `/usr/local/include/omp-tools.h`.
14 |
15 | > Note: if a previous run of the `ChatGLM2-6B` project failed, it is best to clear the Hugging Face cache first, i.e. by default `rm -rf ${HOME}/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4`. Since this uses the `rm` command, make sure you know exactly what you are deleting.
16 |
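17 | A quick check that the OpenMP runtime is picked up (a sketch; the flags follow the R-Project instructions and assume the libomp files were installed under `/usr/local` as above):
18 |
19 | ```bash
20 | # Compile and run a trivial OpenMP program; it should report more than one thread on a multi-core Mac.
21 | cat > /tmp/omp_check.c <<'EOF'
22 | #include <omp.h>
23 | #include <stdio.h>
24 | int main(void) { printf("%d threads\n", omp_get_max_threads()); return 0; }
25 | EOF
26 | clang -Xclang -fopenmp /tmp/omp_check.c -I/usr/local/include -L/usr/local/lib -lomp -o /tmp/omp_check
27 | /tmp/omp_check
28 | ```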
--------------------------------------------------------------------------------
/chatglm_model_v1/evaluate.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
3 | STEP=3000
4 |
5 | CUDA_VISIBLE_DEVICES=0 python3 main.py \
6 | --do_predict \
7 | --validation_file AdvertiseGen/dev.json \
8 | --test_file AdvertiseGen/dev.json \
9 | --overwrite_cache \
10 | --prompt_column content \
11 | --response_column summary \
12 | --model_name_or_path THUDM/chatglm-6b \
13 | --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
14 | --output_dir ./output/$CHECKPOINT \
15 | --overwrite_output_dir \
16 | --max_source_length 64 \
17 | --max_target_length 64 \
18 | --per_device_eval_batch_size 1 \
19 | --predict_with_generate \
20 | --pre_seq_len $PRE_SEQ_LEN \
21 | --quantization_bit 4
22 |
--------------------------------------------------------------------------------
/chatglm_model_v2/evaluate.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | CHECKPOINT=adgen-chatglm2-6b-pt-128-2e-2
3 | STEP=3000
4 | NUM_GPUS=1
5 |
6 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
7 | --do_predict \
8 | --validation_file AdvertiseGen/dev.json \
9 | --test_file AdvertiseGen/dev.json \
10 | --overwrite_cache \
11 | --prompt_column content \
12 | --response_column summary \
13 | --model_name_or_path THUDM/chatglm2-6b \
14 | --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
15 | --output_dir ./output/$CHECKPOINT \
16 | --overwrite_output_dir \
17 | --max_source_length 64 \
18 | --max_target_length 64 \
19 | --per_device_eval_batch_size 1 \
20 | --predict_with_generate \
21 | --pre_seq_len $PRE_SEQ_LEN \
22 | --quantization_bit 4
23 |
--------------------------------------------------------------------------------
/chatglm_model_v1/train.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | LR=2e-2
3 |
4 | CUDA_VISIBLE_DEVICES=0 python3 main.py \
5 | --do_train \
6 | --train_file AdvertiseGen/train.json \
7 | --validation_file AdvertiseGen/dev.json \
8 | --prompt_column content \
9 | --response_column summary \
10 | --overwrite_cache \
11 | --model_name_or_path THUDM/chatglm-6b \
12 | --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
13 | --overwrite_output_dir \
14 | --max_source_length 64 \
15 | --max_target_length 64 \
16 | --per_device_train_batch_size 1 \
17 | --per_device_eval_batch_size 1 \
18 | --gradient_accumulation_steps 16 \
19 | --predict_with_generate \
20 | --max_steps 3000 \
21 | --logging_steps 10 \
22 | --save_steps 1000 \
23 | --learning_rate $LR \
24 | --pre_seq_len $PRE_SEQ_LEN \
25 | --quantization_bit 4
26 |
27 |
--------------------------------------------------------------------------------
/chatglm_model_v1/train_chat.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | LR=1e-2
3 |
4 | CUDA_VISIBLE_DEVICES=0 python3 main.py \
5 | --do_train \
6 | --train_file $CHAT_TRAIN_DATA \
7 | --validation_file $CHAT_VAL_DATA \
8 | --prompt_column prompt \
9 | --response_column response \
10 | --history_column history \
11 | --overwrite_cache \
12 | --model_name_or_path THUDM/chatglm-6b \
13 | --output_dir $CHECKPOINT_NAME \
14 | --overwrite_output_dir \
15 | --max_source_length 256 \
16 | --max_target_length 256 \
17 | --per_device_train_batch_size 1 \
18 | --per_device_eval_batch_size 1 \
19 | --gradient_accumulation_steps 16 \
20 | --predict_with_generate \
21 | --max_steps 3000 \
22 | --logging_steps 10 \
23 | --save_steps 1000 \
24 | --learning_rate $LR \
25 | --pre_seq_len $PRE_SEQ_LEN \
26 | --quantization_bit 4
27 |
28 |
--------------------------------------------------------------------------------
/chatglm_model_v1/ds_train_finetune.sh:
--------------------------------------------------------------------------------
1 |
2 | LR=1e-4
3 |
4 | MASTER_PORT=$(shuf -n 1 -i 10000-65535)
5 |
6 | deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \
7 | --deepspeed deepspeed.json \
8 | --do_train \
9 | --train_file AdvertiseGen/train.json \
10 | --test_file AdvertiseGen/dev.json \
11 | --prompt_column content \
12 | --response_column summary \
13 | --overwrite_cache \
14 | --model_name_or_path THUDM/chatglm-6b \
15 | --output_dir ./output/adgen-chatglm-6b-ft-$LR \
16 | --overwrite_output_dir \
17 | --max_source_length 64 \
18 | --max_target_length 64 \
19 | --per_device_train_batch_size 4 \
20 | --per_device_eval_batch_size 1 \
21 | --gradient_accumulation_steps 1 \
22 | --predict_with_generate \
23 | --max_steps 5000 \
24 | --logging_steps 10 \
25 | --save_steps 1000 \
26 | --learning_rate $LR \
27 | --fp16
28 |
29 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | LR=2e-2
3 | NUM_GPUS=1
4 |
5 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
6 | --do_train \
7 | --train_file AdvertiseGen/train.json \
8 | --validation_file AdvertiseGen/dev.json \
9 | --preprocessing_num_workers 10 \
10 | --prompt_column content \
11 | --response_column summary \
12 | --overwrite_cache \
13 | --model_name_or_path THUDM/chatglm2-6b \
14 | --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
15 | --overwrite_output_dir \
16 | --max_source_length 64 \
17 | --max_target_length 128 \
18 | --per_device_train_batch_size 1 \
19 | --per_device_eval_batch_size 1 \
20 | --gradient_accumulation_steps 16 \
21 | --predict_with_generate \
22 | --max_steps 3000 \
23 | --logging_steps 10 \
24 | --save_steps 1000 \
25 | --learning_rate $LR \
26 | --pre_seq_len $PRE_SEQ_LEN \
27 | --quantization_bit 4
28 |
29 |
--------------------------------------------------------------------------------
/scripts/train_chat.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | LR=1e-2
3 | NUM_GPUS=1
4 |
5 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
6 | --do_train \
7 | --train_file $CHAT_TRAIN_DATA \
8 | --validation_file $CHAT_VAL_DATA \
9 | --preprocessing_num_workers 10 \
10 | --prompt_column prompt \
11 | --response_column response \
12 | --history_column history \
13 | --overwrite_cache \
14 | --model_name_or_path THUDM/chatglm2-6b \
15 | --output_dir $CHECKPOINT_NAME \
16 | --overwrite_output_dir \
17 | --max_source_length 256 \
18 | --max_target_length 256 \
19 | --per_device_train_batch_size 1 \
20 | --per_device_eval_batch_size 1 \
21 | --gradient_accumulation_steps 16 \
22 | --predict_with_generate \
23 | --max_steps 3000 \
24 | --logging_steps 10 \
25 | --save_steps 1000 \
26 | --learning_rate $LR \
27 | --pre_seq_len $PRE_SEQ_LEN \
28 | --quantization_bit 4
29 |
30 |
--------------------------------------------------------------------------------
/scripts/train_ptuning.sh:
--------------------------------------------------------------------------------
1 | TASK_NAME=default_task
2 | PRE_SEQ_LEN=128
3 | LR=2e-2
4 |
5 | CHAT_TRAIN_DATA=./data/train.json
6 | CHAT_VAL_DATA=./data/dev.json
7 |
8 | MODEL_NAME_OR_PATH=pre-trained-lm/chatglm-6b
9 |
10 | NUM_GPUS=8
11 |
12 | MODEL_VERSION=v1 # v1 or v2
13 |
14 |
15 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
16 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS chatglm_model_$MODEL_VERSION/run_ptuning.py \
17 | --do_train \
18 | --train_file $CHAT_TRAIN_DATA \
19 | --validation_file $CHAT_VAL_DATA \
20 | --prompt_column input \
21 | --response_column output \
22 | --model_name_or_path $MODEL_NAME_OR_PATH \
23 | --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
24 | --overwrite_output_dir \
25 | --max_source_length 256 \
26 | --max_target_length 256 \
27 | --per_device_train_batch_size 32 \
28 | --per_device_eval_batch_size 32 \
29 | --gradient_accumulation_steps 1 \
30 | --predict_with_generate \
31 | --max_steps 9000 \
32 | --logging_steps 10 \
33 | --save_steps 1000 \
34 | --learning_rate $LR \
35 | --pre_seq_len $PRE_SEQ_LEN \
36 | --task_name $TASK_NAME \
37 | --base_cache_dir ./.cache/
38 | # --quantization_bit 4
--------------------------------------------------------------------------------
/scripts/train_lora.sh:
--------------------------------------------------------------------------------
1 | TASK_NAME=default_task
2 | PEFT_TYPE=lora
3 | LORA_DIM=8
4 | LR=2e-2
5 | NUM_GPUS=1
6 | TRAIN_DATA=./data/train.json
7 | EVAL_DATA=./data/dev.json
8 |
9 | MODEL_VERSION=v1 # v1 or v2
10 |
11 | torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS chatglm_model_$MODEL_VERSION/run_peft.py \
12 | --do_train \
13 | --train_file $TRAIN_DATA \
14 | --validation_file $EVAL_DATA \
15 | --preprocessing_num_workers 10 \
16 | --prompt_column content \
17 | --response_column summary \
18 | --model_name_or_path THUDM/chatglm2-6b \
19 | --output_dir output/adgen-chatglm2-6b-$PEFT_TYPE-$LORA_DIM-$LR \
20 | --overwrite_output_dir \
21 | --max_source_length 256 \
22 | --max_target_length 256 \
23 | --per_device_train_batch_size 32 \
24 | --per_device_eval_batch_size 32 \
25 | --gradient_accumulation_steps 1 \
26 | --predict_with_generate \
27 | --max_steps 9000 \
28 | --logging_steps 10 \
29 | --save_steps 1000 \
30 | --learning_rate $LR \
31 | --peft_type $PEFT_TYPE \
32 | --lora_dim $LORA_DIM \
33 | --task_name $TASK_NAME \
34 | --base_cache_dir ./cache
35 | # --quantization_bit 4
36 | # --overwrite_cache \
37 |
38 |
--------------------------------------------------------------------------------
/scripts/ds_train_ptuning.sh:
--------------------------------------------------------------------------------
1 | TASK_NAME=default_task
2 | PRE_SEQ_LEN=128
3 | LR=1e-4
4 |
5 | CHAT_TRAIN_DATA=data/train.json
6 | CHAT_VAL_DATA=data/dev.json
7 |
8 | MODEL_NAME_OR_PATH=pre-trained-lm/chatglm-6b
9 |
10 | NUM_GPUS=8
11 |
12 | MASTER_PORT=$(shuf -n 1 -i 10000-65535)
13 |
14 | MODEL_VERSION=v1 # v1 or v2
15 |
16 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
17 | deepspeed --num_gpus=$NUM_GPUS --master_port $MASTER_PORT chatglm_model_$MODEL_VERSION/run_ptuning.py \
18 | --deepspeed deepspeed/deepspeed.json \
19 | --do_train \
20 | --train_file $CHAT_TRAIN_DATA \
21 | --test_file $CHAT_VAL_DATA \
22 | --prompt_column input \
23 | --response_column output \
24 | --model_name_or_path $MODEL_NAME_OR_PATH \
25 | --output_dir ./output/deepspeed/adgen-chatglm-6b-ft-$LR \
26 | --overwrite_output_dir \
27 | --max_source_length 256 \
28 | --max_target_length 256 \
29 | --per_device_train_batch_size 32 \
30 | --per_device_eval_batch_size 32 \
31 | --gradient_accumulation_steps 1 \
32 | --predict_with_generate \
33 | --max_steps 9000 \
34 | --logging_steps 10 \
35 | --save_steps 1000 \
36 | --learning_rate $LR \
37 | --pre_seq_len $PRE_SEQ_LEN \
38 | --task_name $TASK_NAME \
39 | --base_cache_dir ./.cache \
40 | --fp16
41 | # --overwrite_cache \
--------------------------------------------------------------------------------
/scripts/ds_train_peft.sh:
--------------------------------------------------------------------------------
1 | TASK_NAME=default_task
2 | # PRE_SEQ_LEN=128
3 | PEFT_TYPE=lora
4 | LORA_DIM=8
5 | LR=1e-4
6 |
7 | CHAT_TRAIN_DATA=./data/train.json
8 | CHAT_VAL_DATA=./data/dev.json
9 |
10 | MODEL_NAME_OR_PATH=./pre-trained-lm/chatglm-6b
11 |
12 | NUM_GPUS=8
13 |
14 | MASTER_PORT=$(shuf -n 1 -i 10000-65535)
15 |
16 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
17 | deepspeed --num_gpus=$NUM_GPUS --master_port $MASTER_PORT chatglm_model_v1/run_peft.py \
18 | --deepspeed deepspeed/deepspeed.json \
19 | --do_train \
20 | --train_file $CHAT_TRAIN_DATA \
21 | --test_file $CHAT_VAL_DATA \
22 | --prompt_column input \
23 | --response_column output \
24 | --model_name_or_path $MODEL_NAME_OR_PATH \
25 | --output_dir ./output/deepspeed/chatglm-6b-$TASK_NAME-$PEFT_TYPE-$LORA_DIM-$LR \
26 | --overwrite_output_dir \
27 | --max_source_length 256 \
28 | --max_target_length 1024 \
29 | --per_device_train_batch_size 32 \
30 | --per_device_eval_batch_size 32 \
31 | --gradient_accumulation_steps 1 \
32 | --predict_with_generate \
33 | --max_steps 9000 \
34 | --logging_steps 10 \
35 | --save_steps 1000 \
36 | --learning_rate $LR \
37 | --peft_type $PEFT_TYPE \
38 | --lora_dim $LORA_DIM \
39 | --task_name $TASK_NAME \
40 | --base_cache_dir ./.cache/ \
41 | --fp16
42 | # --overwrite_cache \
--------------------------------------------------------------------------------
/application/cli_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import signal
4 | from transformers import AutoTokenizer, AutoModel
5 | import readline
6 |
7 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
8 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
9 | # 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量
10 | # from utils import load_model_on_gpus
11 | # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
12 | model = model.eval()
13 |
14 | os_name = platform.system()
15 | clear_command = 'cls' if os_name == 'Windows' else 'clear'
16 | stop_stream = False
17 |
18 |
19 | def build_prompt(history):
20 | prompt = "欢迎使用 ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
21 | for query, response in history:
22 | prompt += f"\n\n用户:{query}"
23 | prompt += f"\n\nChatGLM2-6B:{response}"
24 | return prompt
25 |
26 |
27 | def signal_handler(signal, frame):
28 | global stop_stream
29 | stop_stream = True
30 |
31 |
32 | def main():
33 | past_key_values, history = None, []
34 | global stop_stream
35 | print("欢迎使用 ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
36 | while True:
37 | query = input("\n用户:")
38 | if query.strip() == "stop":
39 | break
40 | if query.strip() == "clear":
41 | past_key_values, history = None, []
42 | os.system(clear_command)
43 | print("欢迎使用 ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
44 | continue
45 | print("\nChatGLM:", end="")
46 | current_length = 0
47 | for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history,
48 | past_key_values=past_key_values,
49 | return_past_key_values=True):
50 | if stop_stream:
51 | stop_stream = False
52 | break
53 | else:
54 | print(response[current_length:], end="", flush=True)
55 | current_length = len(response)
56 | print("")
57 |
58 |
59 | if __name__ == "__main__":
60 | main()
61 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Tuple, Union, Optional
3 |
4 | from torch.nn import Module
5 | from transformers import AutoModel
6 |
7 |
8 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
9 | # transformer.word_embeddings 占用1层
10 | # transformer.final_layernorm 和 lm_head 占用1层
11 | # transformer.layers 占用 28 层
12 | # 总共30层分配到num_gpus张卡上
13 | num_trans_layers = 28
14 | per_gpu_layers = 30 / num_gpus
15 |
16 | # bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError
17 | # windows下 model.device 会被设置成 transformer.word_embeddings.device
18 | # linux下 model.device 会被设置成 lm_head.device
19 | # 在调用chat或者stream_chat时,input_ids会被放到model.device上
20 | # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
21 | # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
22 | # 本文件来源于https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
23 | # 仅此处做少许修改以支持ChatGLM2
24 | device_map = {
25 | 'transformer.embedding.word_embeddings': 0,
26 | 'transformer.encoder.final_layernorm': 0,
27 | 'transformer.output_layer': 0,
28 | 'transformer.rotary_pos_emb': 0,
29 | 'lm_head': 0
30 | }
31 |
32 | used = 2
33 | gpu_target = 0
34 | for i in range(num_trans_layers):
35 | if used >= per_gpu_layers:
36 | gpu_target += 1
37 | used = 0
38 | assert gpu_target < num_gpus
39 | device_map[f'transformer.encoder.layers.{i}'] = gpu_target
40 | used += 1
41 |
42 | return device_map
43 |
44 |
45 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
46 | device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
47 | if num_gpus < 2 and device_map is None:
48 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
49 | else:
50 | from accelerate import dispatch_model
51 |
52 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
53 |
54 | if device_map is None:
55 | device_map = auto_configure_device_map(num_gpus)
56 |
57 | model = dispatch_model(model, device_map=device_map)
58 |
59 | return model
60 |
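61 | # Example usage (a sketch; adjust the model path and num_gpus to your setup):
62 | #
63 | #     from transformers import AutoTokenizer
64 | #     from utils import load_model_on_gpus
65 | #
66 | #     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
67 | #     model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2).eval()
68 | #     response, history = model.chat(tokenizer, "你好", history=[])
69 | #     print(response)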
--------------------------------------------------------------------------------
/application/api.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Request
2 | from transformers import AutoTokenizer, AutoModel
3 | import uvicorn, json, datetime
4 | import torch
5 |
6 | DEVICE = "cuda"
7 | DEVICE_ID = "0"
8 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
9 |
10 |
11 | def torch_gc():
12 | if torch.cuda.is_available():
13 | with torch.cuda.device(CUDA_DEVICE):
14 | torch.cuda.empty_cache()
15 | torch.cuda.ipc_collect()
16 |
17 |
18 | app = FastAPI()
19 |
20 |
21 | @app.post("/")
22 | async def create_item(request: Request):
23 | global model, tokenizer
24 | json_post_raw = await request.json()
25 | json_post = json.dumps(json_post_raw)
26 | json_post_list = json.loads(json_post)
27 | prompt = json_post_list.get('prompt')
28 | history = json_post_list.get('history')
29 | max_length = json_post_list.get('max_length')
30 | top_p = json_post_list.get('top_p')
31 | temperature = json_post_list.get('temperature')
32 | response, history = model.chat(tokenizer,
33 | prompt,
34 | history=history,
35 | max_length=max_length if max_length else 2048,
36 | top_p=top_p if top_p else 0.7,
37 | temperature=temperature if temperature else 0.95)
38 | now = datetime.datetime.now()
39 | time = now.strftime("%Y-%m-%d %H:%M:%S")
40 | answer = {
41 | "response": response,
42 | "history": history,
43 | "status": 200,
44 | "time": time
45 | }
46 | log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
47 | print(log)
48 | torch_gc()
49 | return answer
50 |
51 |
52 | if __name__ == '__main__':
53 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
54 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
55 | # 多显卡支持,使用下面三行代替上面两行,将num_gpus改为你实际的显卡数量
56 | # model_path = "THUDM/chatglm2-6b"
57 | # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
58 | # model = load_model_on_gpus(model_path, num_gpus=2)
59 | model.eval()
60 | uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
61 |
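62 | # Example request once the server is running (host and port as configured above):
63 | #
64 | #   curl -X POST "http://127.0.0.1:8000" \
65 | #        -H "Content-Type: application/json" \
66 | #        -d '{"prompt": "你好", "history": []}'
67 | #
68 | # The JSON response contains "response", "history", "status" and "time" fields.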
--------------------------------------------------------------------------------
/application/web_demo2.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoTokenizer
2 | import streamlit as st
3 |
4 |
5 | st.set_page_config(
6 | page_title="ChatGLM2-6b 演示",
7 | page_icon=":robot:",
8 | layout='wide'
9 | )
10 |
11 |
12 | @st.cache_resource
13 | def get_model():
14 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
15 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
16 | # 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量
17 | # from utils import load_model_on_gpus
18 | # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
19 | model = model.eval()
20 | return tokenizer, model
21 |
22 |
23 | tokenizer, model = get_model()
24 |
25 | st.title("ChatGLM2-6B")
26 |
27 | max_length = st.sidebar.slider(
28 | 'max_length', 0, 32768, 8192, step=1
29 | )
30 | top_p = st.sidebar.slider(
31 | 'top_p', 0.0, 1.0, 0.8, step=0.01
32 | )
33 | temperature = st.sidebar.slider(
34 | 'temperature', 0.0, 1.0, 0.8, step=0.01
35 | )
36 |
37 | if 'history' not in st.session_state:
38 | st.session_state.history = []
39 |
40 | if 'past_key_values' not in st.session_state:
41 | st.session_state.past_key_values = None
42 |
43 | for i, (query, response) in enumerate(st.session_state.history):
44 | with st.chat_message(name="user", avatar="user"):
45 | st.markdown(query)
46 | with st.chat_message(name="assistant", avatar="assistant"):
47 | st.markdown(response)
48 | with st.chat_message(name="user", avatar="user"):
49 | input_placeholder = st.empty()
50 | with st.chat_message(name="assistant", avatar="assistant"):
51 | message_placeholder = st.empty()
52 |
53 | prompt_text = st.text_area(label="用户命令输入",
54 | height=100,
55 | placeholder="请在这儿输入您的命令")
56 |
57 | button = st.button("发送", key="predict")
58 |
59 | if button:
60 | input_placeholder.markdown(prompt_text)
61 | history, past_key_values = st.session_state.history, st.session_state.past_key_values
62 | for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history,
63 | past_key_values=past_key_values,
64 | max_length=max_length, top_p=top_p,
65 | temperature=temperature,
66 | return_past_key_values=True):
67 | message_placeholder.markdown(response)
68 |
69 | st.session_state.history = history
70 | st.session_state.past_key_values = past_key_values
71 |
--------------------------------------------------------------------------------
/evaluation/evaluate_ceval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import re
4 | import json
5 | import torch
6 | import torch.utils.data
7 | from transformers import AutoTokenizer, AutoModel
8 | from tqdm import tqdm
9 |
10 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
11 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).bfloat16().cuda()
12 |
13 | choices = ["A", "B", "C", "D"]
14 | choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in choices]
15 |
16 |
17 | def build_prompt(text):
18 | return "[Round {}]\n\n问:{}\n\n答:".format(1, text)
19 |
20 |
21 | extraction_prompt = '综上所述,ABCD中正确的选项是:'
22 |
23 | accuracy_dict, count_dict = {}, {}
24 | with torch.no_grad():
25 | for entry in glob.glob("./CEval/val/**/*.jsonl", recursive=True):
26 | dataset = []
27 | with open(entry, encoding='utf-8') as file:
28 | for line in file:
29 | dataset.append(json.loads(line))
30 | correct = 0
31 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
32 | for batch in tqdm(dataloader):
33 | texts = batch["inputs_pretokenized"]
34 | queries = [build_prompt(query) for query in texts]
35 | inputs = tokenizer(queries, padding=True, return_tensors="pt", truncation=True, max_length=2048).to('cuda')
36 | outputs = model.generate(**inputs, do_sample=False, max_new_tokens=512)
37 | intermediate_outputs = []
38 | for idx in range(len(outputs)):
39 | output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
40 | response = tokenizer.decode(output)
41 | intermediate_outputs.append(response)
42 | answer_texts = [text + intermediate + "\n" + extraction_prompt for text, intermediate in
43 | zip(texts, intermediate_outputs)]
44 | input_tokens = [build_prompt(answer_text) for answer_text in answer_texts]
45 | inputs = tokenizer(input_tokens, padding=True, return_tensors="pt", truncation=True, max_length=2048).to('cuda')
46 | outputs = model(**inputs, return_last_logit=True)
47 | logits = outputs.logits[:, -1]
48 | logits = logits[:, choice_tokens]
49 | preds = logits.argmax(dim=-1)
50 | correct += (preds.cpu() == batch["label"]).sum().item()
51 | accuracy = correct / len(dataset)
52 | print(entry, accuracy)
53 | accuracy_dict[entry] = accuracy
54 | count_dict[entry] = len(dataset)
55 |
56 | acc_total, count_total = 0.0, 0
57 | for key in accuracy_dict:
58 | acc_total += accuracy_dict[key] * count_dict[key]
59 | count_total += count_dict[key]
60 | print(acc_total / count_total)
--------------------------------------------------------------------------------
/MODEL_LICENSE:
--------------------------------------------------------------------------------
1 | The ChatGLM-6B License
2 |
3 | 一、定义
4 |
5 | “许可方”是指分发其软件的 ChatGLM2-6B 模型团队。
6 |
7 | “软件”是指根据本许可提供的 ChatGLM2-6B 模型参数。
8 |
9 | 2. 许可授予
10 |
11 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可,仅用于您的非商业研究目的。
12 |
13 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
14 |
15 | 3.限制
16 |
17 | 您不得出于任何商业、军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
18 |
19 | 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
20 |
21 | 4.免责声明
22 |
23 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
24 |
25 | 5. 责任限制
26 |
27 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
28 |
29 | 6.争议解决
30 |
31 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
32 |
33 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 glm-130b@googlegroups.com 与我们联系。
34 |
35 | 1. Definitions
36 |
37 | “Licensor” means the ChatGLM2-6B Model Team that distributes its Software.
38 |
39 | “Software” means the ChatGLM2-6B model parameters made available under this license.
40 |
41 | 2. License Grant
42 |
43 | Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
44 |
45 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
46 |
47 | 3. Restriction
48 |
49 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
50 |
51 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
52 |
53 | 4. Disclaimer
54 |
55 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
56 |
57 | 5. Limitation of Liability
58 |
59 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
60 |
61 | 6. Dispute Resolution
62 |
63 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
64 |
65 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
66 |
--------------------------------------------------------------------------------
/chatglm_model_v2/trainer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
17 | """
18 | import os
19 | from typing import Optional
20 | from transformers import Trainer
21 |
22 | import torch
23 | from transformers.modeling_utils import PreTrainedModel, unwrap_model
24 | from transformers.utils import logging
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | WEIGHTS_NAME = "pytorch_model.bin"
29 | TRAINING_ARGS_NAME = "training_args.bin"
30 |
31 |
32 | class PrefixTrainer(Trainer):
33 | def __init__(self, *args, save_changed=False, **kwargs):
34 | self.save_changed = save_changed
35 | super().__init__(*args, **kwargs)
36 |
37 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
38 | # If we are executing this function, we are the process zero, so we don't check for that.
39 | output_dir = output_dir if output_dir is not None else self.args.output_dir
40 | os.makedirs(output_dir, exist_ok=True)
41 | logger.info(f"Saving model checkpoint to {output_dir}")
42 | # Save a trained model and configuration using `save_pretrained()`.
43 | # They can then be reloaded using `from_pretrained()`
44 | if not isinstance(self.model, PreTrainedModel):
45 | if isinstance(unwrap_model(self.model), PreTrainedModel):
46 | if state_dict is None:
47 | state_dict = self.model.state_dict()
48 | unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
49 | else:
50 | logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
51 | if state_dict is None:
52 | state_dict = self.model.state_dict()
53 | torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
54 | else:
55 | if self.save_changed:
56 | print("Saving PrefixEncoder")
57 | state_dict = self.model.state_dict()
58 | filtered_state_dict = {}
59 | for k, v in self.model.named_parameters():
60 | if v.requires_grad:
61 | filtered_state_dict[k] = state_dict[k]
62 | self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
63 | else:
64 | print("Saving the whole model")
65 | self.model.save_pretrained(output_dir, state_dict=state_dict)
66 | if self.tokenizer is not None:
67 | self.tokenizer.save_pretrained(output_dir)
68 |
69 | # Good practice: save your training arguments together with the trained model
70 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
71 |
--------------------------------------------------------------------------------
/application/web_demo.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoTokenizer
2 | import gradio as gr
3 | import mdtex2html
4 | from utils import load_model_on_gpus
5 |
6 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
7 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
8 | # Multi-GPU support: replace the line above with the following two lines, and set num_gpus to your actual number of GPUs
9 | # from utils import load_model_on_gpus
10 | # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
11 | model = model.eval()
12 |
13 | """Override Chatbot.postprocess"""
14 |
15 |
16 | def postprocess(self, y):
17 | if y is None:
18 | return []
19 | for i, (message, response) in enumerate(y):
20 | y[i] = (
21 | None if message is None else mdtex2html.convert((message)),
22 | None if response is None else mdtex2html.convert(response),
23 | )
24 | return y
25 |
26 |
27 | gr.Chatbot.postprocess = postprocess
28 |
29 |
30 | def parse_text(text):
31 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
32 | lines = text.split("\n")
33 | lines = [line for line in lines if line != ""]
34 | count = 0
35 | for i, line in enumerate(lines):
36 | if "```" in line:
37 | count += 1
38 | items = line.split('`')
39 | if count % 2 == 1:
40 | lines[i] = f'<pre><code class="language-{items[-1]}">'
41 | else:
42 | lines[i] = f'<br></code></pre>'
43 | else:
44 | if i > 0:
45 | if count % 2 == 1:
46 | line = line.replace("`", "\`")
47 | line = line.replace("<", "&lt;")
48 | line = line.replace(">", "&gt;")
49 | line = line.replace(" ", "&nbsp;")
50 | line = line.replace("*", "&ast;")
51 | line = line.replace("_", "&lowbar;")
52 | line = line.replace("-", "&#45;")
53 | line = line.replace(".", "&#46;")
54 | line = line.replace("!", "&#33;")
55 | line = line.replace("(", "&#40;")
56 | line = line.replace(")", "&#41;")
57 | line = line.replace("$", "&#36;")
58 | lines[i] = "<br>" + line
59 | text = "".join(lines)
60 | return text
61 |
62 |
63 | def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
64 | chatbot.append((parse_text(input), ""))
65 | for response, history, past_key_values in model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
66 | return_past_key_values=True,
67 | max_length=max_length, top_p=top_p,
68 | temperature=temperature):
69 | chatbot[-1] = (parse_text(input), parse_text(response))
70 |
71 | yield chatbot, history, past_key_values
72 |
73 |
74 | def reset_user_input():
75 | return gr.update(value='')
76 |
77 |
78 | def reset_state():
79 | return [], [], None
80 |
81 |
82 | with gr.Blocks() as demo:
83 | gr.HTML("""<h1 align="center">ChatGLM2-6B</h1>""")
84 |
85 | chatbot = gr.Chatbot()
86 | with gr.Row():
87 | with gr.Column(scale=4):
88 | with gr.Column(scale=12):
89 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
90 | container=False)
91 | with gr.Column(min_width=32, scale=1):
92 | submitBtn = gr.Button("Submit", variant="primary")
93 | with gr.Column(scale=1):
94 | emptyBtn = gr.Button("Clear History")
95 | max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
96 | top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
97 | temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
98 |
99 | history = gr.State([])
100 | past_key_values = gr.State(None)
101 |
102 | submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
103 | [chatbot, history, past_key_values], show_progress=True)
104 | submitBtn.click(reset_user_input, [], [user_input])
105 |
106 | emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
107 |
108 | demo.queue().launch(share=False, inbrowser=True)
109 |
--------------------------------------------------------------------------------
/chatglm_model_v2/README.md:
--------------------------------------------------------------------------------
1 | # ChatGLM2-6B-PT
2 | This repository implements fine-tuning of the ChatGLM2-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the number of trainable parameters to 0.1% of the full model; combined with model quantization and gradient checkpointing, it can run with as little as 7 GB of GPU memory.
3 |
4 | The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertisement generation) dataset as an example to illustrate how to use the code.
5 |
6 | ## Dependencies
7 | In addition to the dependencies of ChatGLM2-6B, fine-tuning requires the following packages:
8 | ```
9 | pip install rouge_chinese nltk jieba datasets
10 | ```
11 | ## Usage
12 |
13 | ### Download the dataset
14 | The ADGEN task is to generate a piece of advertising copy (summary) from the input attribute description (content).
15 |
16 | ```json
17 | {
18 | "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
19 | "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
20 | }
21 | ```
22 |
23 | Download the processed ADGEN dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and put the extracted `AdvertiseGen` directory in this directory.
24 |
25 | ### Training
26 |
27 | #### P-Tuning v2
28 |
29 | Run the following command to train:
30 | ```shell
31 | bash train.sh
32 | ```
33 | `PRE_SEQ_LEN` and `LR` in `train.sh` are the soft prompt length and the training learning rate, respectively; tune them for the best results. The P-Tuning v2 method freezes all of the model's parameters; you can adjust `quantization_bit` to change the quantization level at which the original model is loaded, and without this option the model is loaded in FP16 precision.
34 |
35 | Under the default configuration `quantization_bit=4`, `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the INT4 model parameters are frozen, and one training iteration performs 16 accumulated forward and backward passes at a batch size of 1, which is equivalent to a total batch size of 16; in this setting as little as 6.7 GB of GPU memory is required. To improve training efficiency at the same total batch size, you can increase `per_device_train_batch_size` while keeping the product of the two unchanged, at the cost of more GPU memory; adjust it according to your hardware.
36 |
37 | If you want to [load the model from a local path](../README.md#从本地加载模型), change `THUDM/chatglm2-6b` in `train.sh` to your local model path.
38 |
39 | #### Finetune
40 |
41 | For full-parameter fine-tuning, first install [Deepspeed](https://github.com/microsoft/DeepSpeed), then run:
42 |
43 | ```shell
44 | bash ds_train_finetune.sh
45 | ```
46 |
47 | ### Inference
48 |
49 | During P-Tuning v2 training, only the parameters of the PrefixEncoder are saved, so for inference both the original ChatGLM2-6B model and the PrefixEncoder weights must be loaded. Therefore, specify the following arguments in `evaluate.sh`:
50 |
51 | ```shell
52 | --model_name_or_path THUDM/chatglm2-6b
53 | --ptuning_checkpoint $CHECKPOINT_PATH
54 | ```
55 |
56 | If the checkpoint comes from full-parameter fine-tuning instead, simply set `model_name_or_path` as before:
57 |
58 | ```shell
59 | --model_name_or_path $CHECKPOINT_PATH
60 | ```
61 |
62 | The evaluation metrics are Chinese ROUGE score and BLEU-4. The generated results are saved in
63 | `./output/adgen-chatglm2-6b-pt-128-2e-2/generated_predictions.txt`.
64 |
65 | ### Examples
66 | #### Example 1
67 | * Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
68 | * Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
69 | * Output [before fine-tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
70 | * Output [after fine-tuning]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
71 |
72 | #### Example 2
73 |
74 | * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
75 | * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
76 | * Output [before fine-tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
77 | * Output [after fine-tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
78 |
79 |
80 | ## Model Deployment
81 | First, load the tokenizer:
82 |
83 | ```python
84 | import os, torch
85 | from transformers import AutoConfig, AutoModel, AutoTokenizer
86 | # Load the tokenizer (os and torch are used below when loading a P-Tuning checkpoint)
87 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
88 | ```
89 |
90 | 1. To load a P-Tuning checkpoint:
91 |
92 | ```python
93 | config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, pre_seq_len=128)
94 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", config=config, trust_remote_code=True)
95 | prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
96 | new_prefix_state_dict = {}
97 | for k, v in prefix_state_dict.items():
98 | if k.startswith("transformer.prefix_encoder."):
99 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
100 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
101 | ```
102 | Note that you may need to change `pre_seq_len` to the actual value used during your training. If you [load the model from a local path](../README.md#从本地加载模型), change `THUDM/chatglm2-6b` to the local model path (note: not the checkpoint path).
103 |
104 | 2. To load a full-parameter fine-tuned checkpoint, simply load the entire checkpoint:
105 |
106 | ```python
107 | model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
108 | ```
109 |
110 | Afterwards, you can quantize the model as needed, or use it directly:
111 |
112 | ```python
113 | # Comment out the following line if you don't use quantization
114 | model = model.quantize(4)
115 | model = model.cuda()
116 | model = model.eval()
117 |
118 | response, history = model.chat(tokenizer, "你好", history=[])
119 | ```
120 |
121 | You can also directly run the [web demo](./web_demo.py), which supports loading P-Tuning v2 checkpoints:
122 | ```shell
123 | bash web_demo.sh
124 | ```
125 | You may need to modify [web_demo.sh](./web_demo.sh) to match your actual checkpoint setup.
126 |
127 | ## Using Your Own Dataset
128 | Change `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to the paths of your own JSON-format dataset, and set `prompt_column` and `response_column` to the keys of the input and output text in the JSON file. You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths of your own dataset. A minimal sketch of a custom data file and the matching arguments is given in the appendix at the end of this README.
129 |
130 | ## Conversational Dataset
131 |
132 | To fine-tune the model with multi-turn conversation data, you can provide the chat history. For example, the following is training data for a three-turn conversation:
133 |
134 | ```json lines
135 | {"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
136 | {"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
137 | {"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
138 | ```
139 |
140 | During training, set `--history_column` to the key of the chat history in the data (`history` in this example), and the chat history will be concatenated automatically. Note that content exceeding the input length `max_source_length` will be truncated.
141 |
142 | You can refer to the following command:
143 |
144 | ```shell
145 | bash train_chat.sh
146 | ```
147 |
148 | ## Citation
149 |
150 | ```
151 | @inproceedings{liu2022p,
152 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
153 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
154 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
155 | pages={61--68},
156 | year={2022}
157 | }
158 | ```
159 |
160 |
161 |
162 |
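163 | ## Appendix: Custom Dataset Sketch
164 |
165 | A minimal, hypothetical illustration of the changes described in "Using Your Own Dataset". The file paths and the JSON keys `instruction` / `answer` are placeholders, not files shipped with this repository:
166 |
167 | ```json lines
168 | {"instruction": "类型#上衣*颜色#白色*风格#简约", "answer": "简约风格的白色上衣,版型利落,轻松穿出干净清爽的气质。"}
169 | ```
170 |
171 | ```shell
172 | # Same invocation as train.sh; only the data-related arguments differ.
173 | torchrun --standalone --nnodes=1 --nproc-per-node=1 main.py \
174 | --do_train \
175 | --train_file my_data/train.json \
176 | --validation_file my_data/dev.json \
177 | --prompt_column instruction \
178 | --response_column answer \
179 | --model_name_or_path THUDM/chatglm2-6b \
180 | --output_dir output/my-task-chatglm2-6b-pt-128-2e-2 \
181 | --overwrite_output_dir \
182 | --max_source_length 512 \
183 | --max_target_length 512 \
184 | --per_device_train_batch_size 1 \
185 | --per_device_eval_batch_size 1 \
186 | --gradient_accumulation_steps 16 \
187 | --predict_with_generate \
188 | --max_steps 3000 \
189 | --logging_steps 10 \
190 | --save_steps 1000 \
191 | --learning_rate 2e-2 \
192 | --pre_seq_len 128 \
193 | --quantization_bit 4
194 | ```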
--------------------------------------------------------------------------------
/chatglm_model_v1/web_demo.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | import gradio as gr
4 | import mdtex2html
5 |
6 | import torch
7 | import transformers
8 | from transformers import (
9 | AutoConfig,
10 | AutoModel,
11 | AutoTokenizer,
12 | AutoTokenizer,
13 | DataCollatorForSeq2Seq,
14 | HfArgumentParser,
15 | Seq2SeqTrainingArguments,
16 | set_seed,
17 | )
18 |
19 | from arguments import ModelArguments, DataTrainingArguments
20 |
21 |
22 | model = None
23 | tokenizer = None
24 |
25 | """Override Chatbot.postprocess"""
26 |
27 |
28 | def postprocess(self, y):
29 | if y is None:
30 | return []
31 | for i, (message, response) in enumerate(y):
32 | y[i] = (
33 | None if message is None else mdtex2html.convert((message)),
34 | None if response is None else mdtex2html.convert(response),
35 | )
36 | return y
37 |
38 |
39 | gr.Chatbot.postprocess = postprocess
40 |
41 |
42 | def parse_text(text):
43 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
44 | lines = text.split("\n")
45 | lines = [line for line in lines if line != ""]
46 | count = 0
47 | for i, line in enumerate(lines):
48 | if "```" in line:
49 | count += 1
50 | items = line.split('`')
51 | if count % 2 == 1:
52 | lines[i] = f'<pre><code class="language-{items[-1]}">'
53 | else:
54 | lines[i] = f'<br></code></pre>'
55 | else:
56 | if i > 0:
57 | if count % 2 == 1:
58 | line = line.replace("`", "\`")
59 | line = line.replace("<", "&lt;")
60 | line = line.replace(">", "&gt;")
61 | line = line.replace(" ", "&nbsp;")
62 | line = line.replace("*", "&ast;")
63 | line = line.replace("_", "&lowbar;")
64 | line = line.replace("-", "&#45;")
65 | line = line.replace(".", "&#46;")
66 | line = line.replace("!", "&#33;")
67 | line = line.replace("(", "&#40;")
68 | line = line.replace(")", "&#41;")
69 | line = line.replace("$", "&#36;")
70 | lines[i] = "<br>" + line
71 | text = "".join(lines)
72 | return text
73 |
74 |
75 | def predict(input, chatbot, max_length, top_p, temperature, history):
76 | chatbot.append((parse_text(input), ""))
77 | for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
78 | temperature=temperature):
79 | chatbot[-1] = (parse_text(input), parse_text(response))
80 |
81 | yield chatbot, history
82 |
83 |
84 | def reset_user_input():
85 | return gr.update(value='')
86 |
87 |
88 | def reset_state():
89 | return [], []
90 |
91 |
92 | with gr.Blocks() as demo:
93 | gr.HTML("""<h1 align="center">ChatGLM</h1>""")
94 |
95 | chatbot = gr.Chatbot()
96 | with gr.Row():
97 | with gr.Column(scale=4):
98 | with gr.Column(scale=12):
99 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
100 | container=False)
101 | with gr.Column(min_width=32, scale=1):
102 | submitBtn = gr.Button("Submit", variant="primary")
103 | with gr.Column(scale=1):
104 | emptyBtn = gr.Button("Clear History")
105 | max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
106 | top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
107 | temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
108 |
109 | history = gr.State([])
110 |
111 | submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
112 | show_progress=True)
113 | submitBtn.click(reset_user_input, [], [user_input])
114 |
115 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
116 |
117 |
118 |
119 | def main():
120 | global model, tokenizer
121 |
122 | parser = HfArgumentParser((
123 | ModelArguments))
124 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
125 | # If we pass only one argument to the script and it's the path to a json file,
126 | # let's parse it to get our arguments.
127 | model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
128 | else:
129 | model_args = parser.parse_args_into_dataclasses()[0]
130 |
131 | tokenizer = AutoTokenizer.from_pretrained(
132 | model_args.model_name_or_path, trust_remote_code=True)
133 | config = AutoConfig.from_pretrained(
134 | model_args.model_name_or_path, trust_remote_code=True)
135 |
136 | config.pre_seq_len = model_args.pre_seq_len
137 | config.prefix_projection = model_args.prefix_projection
138 |
139 | if model_args.ptuning_checkpoint is not None:
140 | print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}")
141 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
142 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
143 | new_prefix_state_dict = {}
144 | for k, v in prefix_state_dict.items():
145 | if k.startswith("transformer.prefix_encoder."):
146 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
147 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
148 | else:
149 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
150 |
151 | if model_args.quantization_bit is not None:
152 | print(f"Quantized to {model_args.quantization_bit} bit")
153 | model = model.quantize(model_args.quantization_bit)
154 |
155 | if model_args.pre_seq_len is not None:
156 | # P-tuning v2
157 | model = model.half().cuda()
158 | model.transformer.prefix_encoder.float().cuda()
159 |
160 | model = model.eval()
161 | demo.queue().launch(share=False, inbrowser=True)
162 |
163 |
164 |
165 | if __name__ == "__main__":
166 | main()
--------------------------------------------------------------------------------
/application/openai_api.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Implements API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
3 | # Usage: python openai_api.py
4 | # Visit http://localhost:8000/docs for documents.
5 |
6 |
7 | import time
8 | import torch
9 | import uvicorn
10 | from pydantic import BaseModel, Field
11 | from fastapi import FastAPI, HTTPException
12 | from fastapi.middleware.cors import CORSMiddleware
13 | from contextlib import asynccontextmanager
14 | from typing import Any, Dict, List, Literal, Optional, Union
15 | from transformers import AutoTokenizer, AutoModel
16 | from sse_starlette.sse import ServerSentEvent, EventSourceResponse
17 |
18 |
19 | @asynccontextmanager
20 | async def lifespan(app: FastAPI): # collects GPU memory
21 | yield
22 | if torch.cuda.is_available():
23 | torch.cuda.empty_cache()
24 | torch.cuda.ipc_collect()
25 |
26 |
27 | app = FastAPI(lifespan=lifespan)
28 |
29 | app.add_middleware(
30 | CORSMiddleware,
31 | allow_origins=["*"],
32 | allow_credentials=True,
33 | allow_methods=["*"],
34 | allow_headers=["*"],
35 | )
36 |
37 | class ModelCard(BaseModel):
38 | id: str
39 | object: str = "model"
40 | created: int = Field(default_factory=lambda: int(time.time()))
41 | owned_by: str = "owner"
42 | root: Optional[str] = None
43 | parent: Optional[str] = None
44 | permission: Optional[list] = None
45 |
46 |
47 | class ModelList(BaseModel):
48 | object: str = "list"
49 | data: List[ModelCard] = []
50 |
51 |
52 | class ChatMessage(BaseModel):
53 | role: Literal["user", "assistant", "system"]
54 | content: str
55 |
56 |
57 | class DeltaMessage(BaseModel):
58 | role: Optional[Literal["user", "assistant", "system"]] = None
59 | content: Optional[str] = None
60 |
61 |
62 | class ChatCompletionRequest(BaseModel):
63 | model: str
64 | messages: List[ChatMessage]
65 | temperature: Optional[float] = None
66 | top_p: Optional[float] = None
67 | max_length: Optional[int] = None
68 | stream: Optional[bool] = False
69 |
70 |
71 | class ChatCompletionResponseChoice(BaseModel):
72 | index: int
73 | message: ChatMessage
74 | finish_reason: Literal["stop", "length"]
75 |
76 |
77 | class ChatCompletionResponseStreamChoice(BaseModel):
78 | index: int
79 | delta: DeltaMessage
80 | finish_reason: Optional[Literal["stop", "length"]]
81 |
82 |
83 | class ChatCompletionResponse(BaseModel):
84 | model: str
85 | object: Literal["chat.completion", "chat.completion.chunk"]
86 | choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
87 | created: Optional[int] = Field(default_factory=lambda: int(time.time()))
88 |
89 |
90 | @app.get("/v1/models", response_model=ModelList)
91 | async def list_models():
92 | global model_args
93 | model_card = ModelCard(id="gpt-3.5-turbo")
94 | return ModelList(data=[model_card])
95 |
96 |
97 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
98 | async def create_chat_completion(request: ChatCompletionRequest):
99 | global model, tokenizer
100 |
101 | if request.messages[-1].role != "user":
102 | raise HTTPException(status_code=400, detail="Invalid request")
103 | query = request.messages[-1].content
104 |
105 | prev_messages = request.messages[:-1]
106 | if len(prev_messages) > 0 and prev_messages[0].role == "system":
107 | query = prev_messages.pop(0).content + query
108 |
109 | history = []
110 | if len(prev_messages) % 2 == 0:
111 | for i in range(0, len(prev_messages), 2):
112 | if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
113 | history.append([prev_messages[i].content, prev_messages[i+1].content])
114 |
115 | if request.stream:
116 | generate = predict(query, history, request.model)
117 | return EventSourceResponse(generate, media_type="text/event-stream")
118 |
119 | response, _ = model.chat(tokenizer, query, history=history)
120 | choice_data = ChatCompletionResponseChoice(
121 | index=0,
122 | message=ChatMessage(role="assistant", content=response),
123 | finish_reason="stop"
124 | )
125 |
126 | return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
127 |
128 |
129 | async def predict(query: str, history: List[List[str]], model_id: str):
130 | global model, tokenizer
131 |
132 | choice_data = ChatCompletionResponseStreamChoice(
133 | index=0,
134 | delta=DeltaMessage(role="assistant"),
135 | finish_reason=None
136 | )
137 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
138 | yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
139 |
140 | current_length = 0
141 |
142 | for new_response, _ in model.stream_chat(tokenizer, query, history):
143 | if len(new_response) == current_length:
144 | continue
145 |
146 | new_text = new_response[current_length:]
147 | current_length = len(new_response)
148 |
149 | choice_data = ChatCompletionResponseStreamChoice(
150 | index=0,
151 | delta=DeltaMessage(content=new_text),
152 | finish_reason=None
153 | )
154 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
155 | yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
156 |
157 |
158 | choice_data = ChatCompletionResponseStreamChoice(
159 | index=0,
160 | delta=DeltaMessage(),
161 | finish_reason="stop"
162 | )
163 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
164 | yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
165 | yield '[DONE]'
166 |
167 |
168 |
169 | if __name__ == "__main__":
170 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
171 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
172 | # Multi-GPU support: use the following two lines instead of the line above, and change num_gpus to the actual number of GPUs you have
173 | # from utils import load_model_on_gpus
174 | # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
175 | model.eval()
176 |
177 | uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
178 |
--------------------------------------------------------------------------------
/chatglm_model_v2/web_demo.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | import gradio as gr
4 | import mdtex2html
5 |
6 | import torch
7 | import transformers
8 | from transformers import (
9 | AutoConfig,
10 | AutoModel,
11 | AutoTokenizer,
12 | AutoTokenizer,
13 | DataCollatorForSeq2Seq,
14 | HfArgumentParser,
15 | Seq2SeqTrainingArguments,
16 | set_seed,
17 | )
18 |
19 | from arguments import ModelArguments, DataTrainingArguments
20 |
21 |
22 | model = None
23 | tokenizer = None
24 |
25 | """Override Chatbot.postprocess"""
26 |
27 |
28 | def postprocess(self, y):
29 | if y is None:
30 | return []
31 | for i, (message, response) in enumerate(y):
32 | y[i] = (
33 | None if message is None else mdtex2html.convert((message)),
34 | None if response is None else mdtex2html.convert(response),
35 | )
36 | return y
37 |
38 |
39 | gr.Chatbot.postprocess = postprocess
40 |
41 |
42 | def parse_text(text):
43 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
44 | lines = text.split("\n")
45 | lines = [line for line in lines if line != ""]
46 | count = 0
47 | for i, line in enumerate(lines):
48 | if "```" in line:
49 | count += 1
50 | items = line.split('`')
51 | if count % 2 == 1:
52 | lines[i] = f'<pre><code class="language-{items[-1]}">'
53 | else:
54 | lines[i] = f'<br></code></pre>'
55 | else:
56 | if i > 0:
57 | if count % 2 == 1:
58 | line = line.replace("`", "\`")
59 | line = line.replace("<", "&lt;")
60 | line = line.replace(">", "&gt;")
61 | line = line.replace(" ", "&nbsp;")
62 | line = line.replace("*", "&ast;")
63 | line = line.replace("_", "&lowbar;")
64 | line = line.replace("-", "&#45;")
65 | line = line.replace(".", "&#46;")
66 | line = line.replace("!", "&#33;")
67 | line = line.replace("(", "&#40;")
68 | line = line.replace(")", "&#41;")
69 | line = line.replace("$", "&#36;")
70 | lines[i] = "<br>"+line
71 | text = "".join(lines)
72 | return text
73 |
74 |
75 | def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
76 | chatbot.append((parse_text(input), ""))
77 | for response, history, past_key_values in model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
78 | return_past_key_values=True,
79 | max_length=max_length, top_p=top_p,
80 | temperature=temperature):
81 | chatbot[-1] = (parse_text(input), parse_text(response))
82 |
83 | yield chatbot, history, past_key_values
84 |
85 |
86 | def reset_user_input():
87 | return gr.update(value='')
88 |
89 |
90 | def reset_state():
91 | return [], [], None
92 |
93 |
94 | with gr.Blocks() as demo:
95 | gr.HTML("""<h1 align="center">ChatGLM2-6B</h1>""")
96 |
97 | chatbot = gr.Chatbot()
98 | with gr.Row():
99 | with gr.Column(scale=4):
100 | with gr.Column(scale=12):
101 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
102 | container=False)
103 | with gr.Column(min_width=32, scale=1):
104 | submitBtn = gr.Button("Submit", variant="primary")
105 | with gr.Column(scale=1):
106 | emptyBtn = gr.Button("Clear History")
107 | max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
108 | top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
109 | temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
110 |
111 | history = gr.State([])
112 | past_key_values = gr.State(None)
113 |
114 | submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
115 | [chatbot, history, past_key_values], show_progress=True)
116 | submitBtn.click(reset_user_input, [], [user_input])
117 |
118 | emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
119 |
120 |
121 | def main():
122 | global model, tokenizer
123 |
124 | parser = HfArgumentParser((
125 | ModelArguments))
126 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
127 | # If we pass only one argument to the script and it's the path to a json file,
128 | # let's parse it to get our arguments.
129 | model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
130 | else:
131 | model_args = parser.parse_args_into_dataclasses()[0]
132 |
133 | tokenizer = AutoTokenizer.from_pretrained(
134 | model_args.model_name_or_path, trust_remote_code=True)
135 | config = AutoConfig.from_pretrained(
136 | model_args.model_name_or_path, trust_remote_code=True)
137 |
138 | config.pre_seq_len = model_args.pre_seq_len
139 | config.prefix_projection = model_args.prefix_projection
140 |
141 | if model_args.ptuning_checkpoint is not None:
142 | print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}")
143 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
144 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
145 | new_prefix_state_dict = {}
146 | for k, v in prefix_state_dict.items():
147 | if k.startswith("transformer.prefix_encoder."):
148 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
149 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
150 | else:
151 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
152 |
153 | if model_args.quantization_bit is not None:
154 | print(f"Quantized to {model_args.quantization_bit} bit")
155 | model = model.quantize(model_args.quantization_bit)
156 | model = model.cuda()
157 | if model_args.pre_seq_len is not None:
158 | # P-tuning v2
159 | model.transformer.prefix_encoder.float()
160 |
161 | model = model.eval()
162 | demo.queue().launch(share=False, inbrowser=True)
163 |
164 |
165 |
166 | if __name__ == "__main__":
167 | main()
--------------------------------------------------------------------------------
/chatglm_model_v1/README.md:
--------------------------------------------------------------------------------
1 | # ChatGLM-6B-PT
2 | This repository implements fine-tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the number of trainable parameters to 0.1% of the original, and combined with model quantization, gradient checkpointing, and other techniques, it can run with as little as 7GB of GPU memory.
3 |
4 | The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertisement generation) dataset as an example to show how to use the code.
5 |
6 | *Read this in [English](README_en.md).*
7 |
8 | ## Software dependencies
9 | Running the fine-tuning requires `transformers` version 4.27.1. In addition to the dependencies of ChatGLM-6B, install the following:
10 | ```
11 | pip install rouge_chinese nltk jieba datasets
12 | ```
13 | ## Usage
14 |
15 | ### Download the dataset
16 | The ADGEN task is to generate a piece of advertising copy (summary) from the input attributes (content).
17 |
18 | ```json
19 | {
20 | "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
21 | "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
22 | }
23 | ```
24 |
25 | Download the processed ADGEN dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and put the decompressed `AdvertiseGen` directory under this directory.
26 |
27 | ### Training
28 |
29 | #### P-Tuning v2
30 |
31 | Run the following command to start training:
32 | ```shell
33 | bash train.sh
34 | ```
35 | `PRE_SEQ_LEN` and `LR` in `train.sh` are the soft prompt length and the training learning rate, and can be tuned for the best results. P-Tuning v2 freezes all of the original model's parameters; the quantization level of the original model can be changed via `quantization_bit`, and without this option the model is loaded in FP16 precision.
36 |
37 | Under the default configuration `quantization_bit=4`, `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the INT4 model parameters are frozen, and one training iteration performs 16 accumulated forward/backward passes with a batch size of 1, equivalent to a total batch size of 16; this requires as little as 6.7 GB of GPU memory. To improve training efficiency at the same total batch size, you can increase `per_device_train_batch_size` while keeping the product of the two unchanged, at the cost of more GPU memory; adjust it according to your hardware.
38 |
39 | If you want to [load the model locally](../README_en.md#load-the-model-locally), change `THUDM/chatglm-6b` in `train.sh` to your local model path.
40 |
41 | #### Finetune
42 |
43 | For full-parameter fine-tuning, install [Deepspeed](https://github.com/microsoft/DeepSpeed) and then run:
44 |
45 | ```shell
46 | bash ds_train_finetune.sh
47 | ```
48 |
49 | ### Inference
50 |
51 | During P-Tuning v2 training, only the PrefixEncoder parameters are saved, so at inference time both the original ChatGLM-6B model and the PrefixEncoder weights must be loaded. Specify the following arguments in `evaluate.sh`:
52 |
53 | ```shell
54 | --model_name_or_path THUDM/chatglm-6b
55 | --ptuning_checkpoint $CHECKPOINT_PATH
56 | ```
57 |
58 | Old checkpoints saved with full parameters are still supported; just set `model_name_or_path` as before:
59 |
60 | ```shell
61 | --model_name_or_path $CHECKPOINT_PATH
62 | ```
63 |
64 | The evaluation metrics are Chinese ROUGE score and BLEU-4. The generated results are saved in
65 | `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`.
66 |
67 | ### Examples
68 | #### Example 1
69 | * Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
70 | * Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
71 | * Output[before tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
72 | * Output[after tuning]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
73 |
74 | #### Example 2
75 |
76 | * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
77 | * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
78 | * Output[before tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
79 | * Output[after tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
80 |
81 | ### Evaluation results
82 |
83 | | | Finetune | P-tuning v2 | LoRA |
84 | | ------------- | ----------- | ----- | ------------- |
85 | | BLEU-4 | 8.01 | 8.10 | 7.62 |
86 | | Rouge-1 | 31.23 | 31.12 | 30.60 |
87 | | Rouge-2 | 7.36 | 7.11 | 6.96 |
88 | | Rouge-l | 25.08 | 24.97 | 24.80 |
89 | | Training Loss | 3.00 | 3.74 | 3.32 |
90 |
91 |
92 |
93 | #### Experiment settings
94 |
95 | ```
96 | max_source_length=64
97 | max_target_length=64
98 | max_steps=3000
99 | ```
100 |
101 | ##### P-tuning v2
102 |
103 | ```
104 | pre_seq_len=128
105 | learning_rate=2e-2
106 | quantization_bit=4
107 | per_device_train_batch_size=16
108 | gradient_accumulation_steps=1
109 | ```
110 |
111 | ##### Finetune
112 |
113 | ```
114 | learning_rate=1e-4
115 | fp16
116 | num_gpus=4
117 | per_device_train_batch_size=4
118 | gradient_accumulation_steps=1
119 | ```
120 |
121 | ##### LoRA
122 |
123 | The LoRA implementation is based on [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)
124 |
125 | ```
126 | learning_rate=5e-4
127 | per_device_train_batch_size=16
128 | gradient_accumulation_steps=1
129 | ```
130 |
131 | ## Model deployment
132 | First load the tokenizer:
133 |
134 | ```python
135 | from transformers import AutoConfig, AutoModel, AutoTokenizer
136 |
137 | # Load the tokenizer (the snippets below also assume "import os" and "import torch")
138 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
139 | ```
140 |
141 | 1. To load a new checkpoint (containing only the PrefixEncoder parameters):
142 |
143 | ```python
144 | config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
145 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
146 | prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
147 | new_prefix_state_dict = {}
148 | for k, v in prefix_state_dict.items():
149 | if k.startswith("transformer.prefix_encoder."):
150 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
151 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
152 | ```
153 | Note that you may need to change `pre_seq_len` to the value actually used during training. If you [load the model locally](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B), change `THUDM/chatglm-6b` to your local model path (not the checkpoint path).
154 |
155 | 2. To load an old checkpoint (containing both the ChatGLM-6B and PrefixEncoder parameters), or a fully fine-tuned model, load the entire checkpoint directly:
156 |
157 | ```python
158 | model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
159 | ```
160 |
161 | It can then be quantized as needed, or used directly:
162 |
163 | ```python
164 | # Comment out the following line if you don't use quantization
165 | model = model.quantize(4)
166 | model = model.half().cuda()
167 | model.transformer.prefix_encoder.float()
168 | model = model.eval()
169 |
170 | response, history = model.chat(tokenizer, "你好", history=[])
171 | ```
172 |
173 | **[23/04/19]** You can also directly run the [web demo](./web_demo.py), which supports loading a P-Tuning v2 checkpoint:
174 | ```shell
175 | bash web_demo.sh
176 | ```
177 | You may need to modify [web_demo.sh](./web_demo.sh) to match your actual checkpoint setup.
178 |
179 | ## Using your own dataset
180 | Change `train_file`, `validation_file`, and `test_file` in `train.sh` and `evaluate.sh` to the paths of your own JSON-format dataset, and set `prompt_column` and `response_column` to the keys in the JSON file that hold the input and output text. You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths in your dataset.
181 |
182 | ## Dialogue dataset
183 |
184 | To fine-tune the model with multi-turn dialogue data, you can provide the chat history. For example, the following is the training data for a three-turn dialogue:
185 |
186 | ```json lines
187 | {"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
188 | {"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
189 | {"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
190 | ```
191 |
192 | During training, specify `--history_column` as the key of the chat history in the data (`history` in this example); the chat history will be concatenated automatically. Note that content exceeding the input length `max_source_length` will be truncated.
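The concatenation is roughly of the following form (a minimal sketch modeled on the upstream ChatGLM-6B preprocessing; the exact template used by this repo's `run_ptuning.py` may differ):

```python
def build_prompt(query, history):
    # Concatenate earlier turns in the "[Round i]\n问:...\n答:..." format used by ChatGLM-6B,
    # then append the current query; the response of the last turn is the training target.
    prompt = ""
    for i, (old_query, old_response) in enumerate(history):
        prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, old_response)
    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
    return prompt
```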
193 |
194 | You can refer to the following command:
195 |
196 | ```shell
197 | bash train_chat.sh
198 | ```
199 |
200 | ## Citation
201 |
202 | ```
203 | @inproceedings{liu2022p,
204 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
205 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
206 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
207 | pages={61--68},
208 | year={2022}
209 | }
210 | ```
211 |
212 |
213 |
214 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ChatGLM2-Tuning
2 |
3 |
4 | ## 1. Introduction
5 |
6 | [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second-generation version of the open-source bilingual (Chinese-English) dialogue model [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B). It retains the strengths of the first generation, such as fluent dialogue and a low deployment threshold, while further improving the model with stronger performance, longer context input, more efficient deployment, and a more open license. ChatGLM2-6B has accordingly topped the C-Eval leaderboard.
7 |
8 | This project supports fine-tuning of both **ChatGLM-6B** and **ChatGLM2-6B**. It can perform full-parameter fine-tuning, and also supports the following optimizations:
9 | - PEFT parameter-efficient training: P-tuning, Prompt-tuning, Prefix-tuning, LoRA;
10 | - DeepSpeed ZeRO training;
11 | - Quantization-aware training & quantized inference deployment;
12 |
13 | ---
14 |
15 | Development progress:
16 | - Code debugging ✅
17 | - Full-parameter training ✅
18 | - Parameter-efficient training ✅
19 | - Quantization-aware training ✅
20 | - Instruction fine-tuning ✅
21 | - Multi-turn dialogue ✅
22 |
23 | ---
24 |
25 | ## 2. Getting started
26 | ### 2.1 Environment setup
27 | First clone this repository:
28 | ```shell
29 | git clone https://github.com/wjn1996/ChatGLM2-Tuning
30 | cd ChatGLM2-Tuning
31 | ```
32 |
33 | Install the dependencies:
34 | ```
35 | pip install -r requirements.txt
36 | ```
37 |
38 | ### 2.2 Dataset preparation
39 |
40 | ##### (1) Using a custom instruction-tuning dataset
41 |
42 | An instruction-tuning dataset contains a task instruction (instruction), along with the task input (input) and output (output). During training, the model computes the loss only on the output.
43 |
44 | Dataset format example:
45 | ```json
46 | {
47 | "instruction": "请为下面的评论的情感类别进行分类,候选为【积极】和【消极】",
48 | "input": "《消失的她》这部电影很好看,但是我觉得大多数人看完后都会emo",
49 | "output": "消极",
50 | }
51 | ```
52 |
53 |
54 |
55 | ##### (2) Using a custom multi-turn dialogue dataset
56 |
57 | A multi-turn dialogue dataset can be trained in one of two modes, in-the-loop or session:
58 | - **in-the-loop**: a multi-turn dialogue is split by turn into multiple samples. Each sample is treated independently during training: given the dialogue history (history) and the current prompt, the loss is computed on the response. ChatGLM2-6B uses this mode for multi-turn dialogue training by default.
59 |
60 | ```json
61 | {
62 | "prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
63 | "response": "用电脑能读数据流吗?水温多少",
64 | "history": []
65 | }
66 | {
67 | "prompt": "95",
68 | "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?",
69 | "history": [
70 | ["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]
71 | ]
72 | }
73 | {
74 | "prompt": "是的。上下水管都好的",
75 | "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!",
76 | "history": [
77 | ["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"],
78 | ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]
79 | ]
80 | }
81 | ```
82 | > Taking the multi-turn dialogue above as an example, in the in-the-loop setting data preprocessing produces 3 independent samples; each sample is a single sequence containing the dialogue history, the current prompt, and the output response.
83 | - **session**: the whole multi-turn dialogue is treated as one sample, and the loss is computed on all tokens (or on the output of each turn);
84 |
85 | ```json
86 | {
87 | "prompt": [
88 | "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
89 | "95",
90 | "是的。上下水管都好的"
91 | ],
92 | "response": [
93 | "用电脑能读数据流吗?水温多少",
94 | "上下水管温差怎么样啊?空气是不是都排干净了呢?",
95 | "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!"
96 | ]
97 | }
98 | ```
99 | > Taking the multi-turn dialogue above as an example, only one sample is generated: the prompt and response of each turn are concatenated, and all turns are joined into a sequence of the form "Q1 A1 Q2 A2 ..." (see the sketch below).
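The difference between the two modes can be illustrated with a small sketch (illustrative only, not the repo's actual preprocessing code; the field names follow the examples above):

```python
dialogue = {
    "prompt": ["Q1", "Q2", "Q3"],
    "response": ["A1", "A2", "A3"],
}

# in-the-loop: one sample per turn, with all earlier turns as the history
in_the_loop_samples = [
    {
        "prompt": dialogue["prompt"][i],
        "response": dialogue["response"][i],
        "history": list(zip(dialogue["prompt"][:i], dialogue["response"][:i])),
    }
    for i in range(len(dialogue["prompt"]))
]

# session: the whole dialogue becomes a single "Q1 A1 Q2 A2 ..." style sequence
session_sample = " ".join(
    q + " " + a for q, a in zip(dialogue["prompt"], dialogue["response"])
)
```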
100 |
101 |
102 | ##### (3) Using open-source evaluation datasets
103 |
104 | TODO
105 |
106 |
107 | ### 2.3 Model training
108 |
109 | Training uses a causal LM objective: during the forward pass the loss is computed only on designated tokens, and the instruction, dialogue history, input, and padding parts can be excluded from the loss by setting their labels to -100, as sketched below.
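A minimal sketch of this label masking (illustrative only; the actual preprocessing lives in `run_ptuning.py` / `run_peft.py`):

```python
IGNORE_INDEX = -100  # tokens with this label are ignored by the loss

def build_features(prompt_ids, response_ids, max_length, pad_token_id):
    # prompt_ids covers instruction + dialogue history + input; the loss is computed
    # only on the response tokens, everything else is masked with -100.
    input_ids = (prompt_ids + response_ids)[:max_length]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + response_ids)[:max_length]
    pad_len = max_length - len(input_ids)
    input_ids = input_ids + [pad_token_id] * pad_len
    labels = labels + [IGNORE_INDEX] * pad_len
    return {"input_ids": input_ids, "labels": labels}
```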
110 |
111 | ##### (1) P-tuning training
112 | ```bash
113 | TASK_NAME=default_task # task name
114 | PRE_SEQ_LEN=128 # number of prefix tokens
115 | LR=1e-4 # learning rate
116 |
117 | CHAT_TRAIN_DATA=data/train.json
118 | CHAT_VAL_DATA=data/dev.json
119 |
120 | MODEL_NAME_OR_PATH=pre-trained-lm/chatglm-6b
121 |
122 | NUM_GPUS=8
123 |
124 | MASTER_PORT=$(shuf -n 1 -i 10000-65535)
125 |
126 | MODEL_VERSION=v1 # v1: initialize from ChatGLM-6B; v2: initialize from ChatGLM2-6B
127 |
128 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
129 | deepspeed --num_gpus=$NUM_GPUS --master_port $MASTER_PORT chatglm_model_$MODEL_VERSION/run_ptuning.py \
130 | --deepspeed deepspeed/deepspeed.json \
131 | --do_train \
132 | --train_file $CHAT_TRAIN_DATA \
133 | --test_file $CHAT_VAL_DATA \
134 | --prompt_column input \
135 | --response_column output \
136 | --model_name_or_path $MODEL_NAME_OR_PATH \
137 | --output_dir ./output/deepspeed/adgen-chatglm-6b-ft-$LR \
138 | --overwrite_output_dir \
139 | --max_source_length 256 \
140 | --max_target_length 256 \
141 | --per_device_train_batch_size 32 \
142 | --per_device_eval_batch_size 32 \
143 | --gradient_accumulation_steps 1 \
144 | --predict_with_generate \
145 | --max_steps 9000 \
146 | --logging_steps 10 \
147 | --save_steps 1000 \
148 | --learning_rate $LR \
149 | --task_name $TASK_NAME \
150 | --base_cache_dir ./.cache \
151 | --fp16
152 | # --overwrite_cache \
153 | ```
154 |
155 | Reference script: scripts/ds_train_ptuning.sh
156 |
157 | ##### (2) LoRA training
158 | ```bash
159 | TASK_NAME=default_task # task name
160 | # PRE_SEQ_LEN=128
161 | PEFT_TYPE=lora # parameter-efficient tuning method
162 | LORA_DIM=8 # LoRA rank
163 | LR=1e-4 # learning rate
164 |
165 | CHAT_TRAIN_DATA=./data/train.json
166 | CHAT_VAL_DATA=./data/dev.json
167 |
168 | MODEL_NAME_OR_PATH=./pre-trained-lm/chatglm-6b
169 |
170 | NUM_GPUS=8
171 |
172 | MASTER_PORT=$(shuf -n 1 -i 10000-65535)
173 |
174 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
175 | deepspeed --num_gpus=$NUM_GPUS --master_port $MASTER_PORT chatglm_model_v1/run_peft.py \
176 | --deepspeed deepspeed/deepspeed.json \
177 | --do_train \
178 | --train_file $CHAT_TRAIN_DATA \
179 | --test_file $CHAT_VAL_DATA \
180 | --prompt_column input \
181 | --response_column output \
182 | --model_name_or_path $MODEL_NAME_OR_PATH \
183 | --output_dir ./output/deepspeed/chatglm-6b-$TASK_NAME-$PEFT_TYPE-$LORA_DIM-$LR \
184 | --overwrite_output_dir \
185 | --max_source_length 256 \
186 | --max_target_length 1024 \
187 | --per_device_train_batch_size 32 \
188 | --per_device_eval_batch_size 32 \
189 | --gradient_accumulation_steps 1 \
190 | --predict_with_generate \
191 | --max_steps 9000 \
192 | --logging_steps 10 \
193 | --save_steps 1000 \
194 | --learning_rate $LR \
195 | --peft_type $PEFT_TYPE \
196 | --lora_dim $LORA_DIM \
197 | --task_name $TASK_NAME \
198 | --base_cache_dir ./.cache/ \
199 | --fp16
200 | # --overwrite_cache \
201 | ```
202 |
203 | Reference script: scripts/ds_train_peft.sh
204 |
205 | To enable INT4 quantization-aware training, simply add the argument
206 |
207 | > --quantization_bit 4
208 |
209 | to the command.
210 |
211 | ### 2.4 Model inference and deployment
212 |
213 | #### API deployment
214 | Deployment file "api.py":
215 | ```python
216 | from fastapi import FastAPI, Request
217 | from transformers import AutoTokenizer, AutoModel
218 | import uvicorn, json, datetime
219 | import torch
220 |
221 | DEVICE = "cuda"
222 | DEVICE_ID = "0"
223 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
224 |
225 |
226 | def torch_gc():
227 | if torch.cuda.is_available():
228 | with torch.cuda.device(CUDA_DEVICE):
229 | torch.cuda.empty_cache()
230 | torch.cuda.ipc_collect()
231 |
232 |
233 | app = FastAPI()
234 |
235 |
236 | @app.post("/")
237 | async def create_item(request: Request):
238 | global model, tokenizer
239 | json_post_raw = await request.json()
240 | json_post = json.dumps(json_post_raw)
241 | json_post_list = json.loads(json_post)
242 | prompt = json_post_list.get('prompt')
243 | history = json_post_list.get('history')
244 | max_length = json_post_list.get('max_length')
245 | top_p = json_post_list.get('top_p')
246 | temperature = json_post_list.get('temperature')
247 | response, history = model.chat(tokenizer,
248 | prompt,
249 | history=history,
250 | max_length=max_length if max_length else 2048,
251 | top_p=top_p if top_p else 0.7,
252 | temperature=temperature if temperature else 0.95)
253 | now = datetime.datetime.now()
254 | time = now.strftime("%Y-%m-%d %H:%M:%S")
255 | answer = {
256 | "response": response,
257 | "history": history,
258 | "status": 200,
259 | "time": time
260 | }
261 | log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
262 | print(log)
263 | torch_gc()
264 | return answer
265 |
266 |
267 | if __name__ == '__main__':
268 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
269 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() # FP16
270 | model.eval()
271 | uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
272 | ```
273 |
274 | Run:
275 | > python3 api.py
276 |
277 | Calling the API (e.g., a POST request to port 8000):
278 | ```bash
279 | curl -X POST "http://127.0.0.1:8000" \
280 | -H 'Content-Type: application/json' \
281 | -d '{"prompt": "你好", "history": []}'
282 | ```
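The same endpoint can also be called from Python (a minimal sketch using the `requests` library; the fields match what `api.py` above reads and returns):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8000",
    json={"prompt": "你好", "history": [], "max_length": 2048, "top_p": 0.7, "temperature": 0.95},
)
print(resp.json()["response"])
```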
283 | #### Quantization
284 | To load the model with FP16 weights and INT8 quantization:
285 | `model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()`
286 |
287 | For more deployment options, see https://github.com/THUDM/ChatGLM-6B
288 |
--------------------------------------------------------------------------------
/chatglm_model_v1/README_en.md:
--------------------------------------------------------------------------------
1 | # ChatGLM-6B-PT
2 | This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the number of parameters that need to be optimized to 0.1% of full fine-tuning, and then, through model quantization, gradient checkpointing, and other methods, it needs as little as 7GB of GPU memory to run.
3 |
4 | The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code.
5 |
6 | ## Software dependencies
7 | Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required
8 | ```
9 | pip install rouge_chinese nltk jieba datasets
10 | ```
11 | ## Instructions
12 |
13 | ### Download the dataset
14 | The task of the ADGEN dataset is to generate a piece of advertising copy (summary) based on the input (content).
15 |
16 | ```json
17 | {
18 | "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
19 | "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
20 | }
21 | ```
22 |
23 | From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory.
24 |
25 | ### Training
26 |
27 | #### P-Tuning v2
28 |
29 | Run the following commands for training:
30 | ```shell
31 | bash train.sh
32 | ```
33 | `PRE_SEQ_LEN` and `LR` in `train.sh` are the soft prompt length and the training learning rate respectively, and can be adjusted to achieve the best results. The P-Tuning v2 method freezes all model parameters; the quantization level of the original model can be adjusted via `quantization_bit`, and if this option is not added, the model is loaded with FP16 precision.
34 |
35 | Under the default configuration of `quantization_bit=4`, `per_device_train_batch_size=1`, and `gradient_accumulation_steps=16`, the INT4 model parameters are frozen, and one training iteration performs 16 accumulated forward and backward passes with a batch size of 1, which is equivalent to a total batch size of 16; in this case only 6.7 GB of GPU memory is required. If you want to improve training efficiency at the same total batch size, you can increase `per_device_train_batch_size` while keeping the product of the two unchanged, but this will also consume more GPU memory, so adjust it according to your actual situation.
36 |
37 | If you want to [load the model locally](../README_en.md#load-the-model-locally), you can change `THUDM/chatglm-6b` in `train.sh` to your local model path.
38 |
39 | #### Finetune
40 | To finetune the full parameters, you need to install [Deepspeed](https://github.com/microsoft/DeepSpeed), and then run the following command:
41 |
42 | ```shell
43 | bash ds_train_finetune.sh
44 | ```
45 |
46 | ### Inference
47 |
48 | During P-tuning v2 training, the model only saves the parameters of the PrefixEncoder part, so the original ChatGLM-6B model and the weight of the PrefixEncoder need to be loaded at the same time during inference, and the arguments need to be specified in `evaluate.sh`:
49 |
50 | ```shell
51 | --model_name_or_path THUDM/chatglm-6b
52 | --ptuning_checkpoint $CHECKPOINT_PATH
53 | ```
54 |
55 | It is still compatible with the old version of Checkpoint saved with full parameters, just set `model_name_or_path` as before:
56 |
57 | ```shell
58 | --model_name_or_path $CHECKPOINT_PATH
59 | ```
60 |
61 | The evaluation metrics are Chinese ROUGE score and BLEU-4. The generated results are saved in
62 | `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`.
63 |
64 | ### Example
65 | #### Example 1
66 | * Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
67 | * Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
68 | * Output[before tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
69 | * Output[after tuning]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
70 |
71 | #### Example 2
72 |
73 | * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
74 | * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
75 | * Output[before tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
76 | * Output[after tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
77 |
78 | ### Evaluation Results
79 |
80 | | | Finetune | P-tuning v2 | LoRA |
81 | | ------------- | ----------- | ----- | ------------- |
82 | | BLEU-4 | 8.01 | 8.10 | 7.62 |
83 | | Rouge-1 | 31.23 | 31.12 | 30.60 |
84 | | Rouge-2 | 7.36 | 7.11 | 6.96 |
85 | | Rouge-l | 25.08 | 24.97 | 24.80 |
86 | | Training Loss | 3.00 | 3.74 | 3.32 |
87 |
88 | #### Experiment Settings
89 |
90 | ```
91 | max_source_length=64
92 | max_target_length=64
93 | max_steps=3000
94 | ```
95 |
96 | ##### P-tuning v2
97 |
98 | ```
99 | pre_seq_len=128
100 | learning_rate=2e-2
101 | quantization_bit=4
102 | per_device_train_batch_size=16
103 | gradient_accumulation_steps=1
104 | ```
105 |
106 | ##### Finetune
107 |
108 | ```
109 | learning_rate=1e-4
110 | fp16
111 | num_gpus=4
112 | per_device_train_batch_size=4
113 | gradient_accumulation_steps=1
114 | ```
115 |
116 | ##### LoRA
117 |
118 | The implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)
119 |
120 | ```
121 | learning_rate=5e-4
122 | per_device_train_batch_size=16
123 | gradient_accumulation_steps=1
124 | ```
125 |
126 | ## Model Deployment
127 | First load the tokenizer:
128 |
129 | ```python
130 | from transformers import AutoConfig, AutoModel, AutoTokenizer
131 |
132 | # Load the tokenizer (the snippets below also assume "import os" and "import torch")
133 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
134 | ```
135 |
136 | 1. If you need to load a new checkpoint (containing only the PrefixEncoder parameters):
137 |
138 | ```python
139 | config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
140 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
141 | prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
142 | new_prefix_state_dict = {}
143 | for k, v in prefix_state_dict.items():
144 | if k.startswith("transformer.prefix_encoder."):
145 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
146 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
147 | ```
148 | Note that you may need to change `pre_seq_len` to the actual value of your training. If you [load model from local](../README_en.md#load-the-model-locally), you need to change `THUDM/chatglm-6b` to the local model path (not the checkpoint path).
149 |
150 | 2. If you need to load the old checkpoint (including both ChatGLM-6B and PrefixEncoder parameters), or perform full parameter fine-tuning, then directly load the entire checkpoint:
151 |
152 | ```python
153 | model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
154 | ```
155 |
156 | It can then be quantized as needed, or used directly:
157 |
158 | ```python
159 | # Comment out the following line if you don't use quantization
160 | model = model.quantize(4)
161 | model = model.half().cuda()
162 | model.transformer.prefix_encoder.float()
163 | model = model.eval()
164 |
165 | response, history = model.chat(tokenizer, "Hello", history=[])
166 | ```
167 |
168 | **[23/04/19]** You can also directly run the [web demo](./web_demo.py), which supports loading a P-Tuning v2 checkpoint:
169 | ```shell
170 | bash web_demo.sh
171 | ```
172 | You may need to modify [web_demo.sh](./web_demo.sh) to match your actual checkpoint setup.
173 |
174 | ## Use your own dataset
175 | Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to input text and output text.
176 | You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths in your own dataset.
177 |
178 | ## Dialog Dataset
179 |
180 | If you need to fine-tune the model with multi-turn dialogue data, you can provide the chat history. For example, the following is the training data for a three-round dialogue:
181 |
182 | ```json lines
183 | {"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
184 | {"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
185 | {"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
186 | ```
187 |
188 | During training, you need to specify `--history_column` as the key of the chat history in the data (`history` in this example), and the chat history will be stitched automatically. Note that content exceeding the input length `max_source_length` will be truncated.
189 |
190 | You can refer to the following instructions:
191 |
192 | ```shell
193 | bash train_chat.sh
194 | ```
195 |
196 | ## Citation
197 |
198 | ```
199 | @inproceedings{liu2022p,
200 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
201 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
202 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
203 | pages={61--68},
204 | year={2022}
205 | }
206 | ```
--------------------------------------------------------------------------------
/chatglm_model_v1/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 | """
10 |
11 | model_name_or_path: str = field(
12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13 | )
14 | ptuning_checkpoint: str = field(
15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
16 | )
17 | config_name: Optional[str] = field(
18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
19 | )
20 | tokenizer_name: Optional[str] = field(
21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
22 | )
23 | cache_dir: Optional[str] = field(
24 | default=None,
25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
26 | )
27 | use_fast_tokenizer: bool = field(
28 | default=True,
29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
30 | )
31 | model_revision: str = field(
32 | default="main",
33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
34 | )
35 | use_auth_token: bool = field(
36 | default=False,
37 | metadata={
38 | "help": (
39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
40 | "with private models)."
41 | )
42 | },
43 | )
44 | resize_position_embeddings: Optional[bool] = field(
45 | default=None,
46 | metadata={
47 | "help": (
48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
49 | "the model's position embeddings."
50 | )
51 | },
52 | )
53 | quantization_bit: Optional[int] = field(
54 | default=None
55 | )
56 | pre_seq_len: Optional[int] = field(
57 | default=None
58 | )
59 | prefix_projection: bool = field(
60 | default=False
61 | )
62 |
63 | @dataclass
64 | class PeftArguments:
65 | peft_type: Optional[str] = field(
66 | default=None, metadata={"help": "The kind of parameter-efficient learning."}
67 | )
68 | prompt_tuning_initial_text: Optional[str] = field(
69 | default=None, metadata={"help": "The initial token list of the soft prompt-tuning"}
70 | )
71 | lora_dim: Optional[int] = field(
72 | default=8, metadata={"help": "The dimension of LoRA"}
73 | )
74 |
75 |
76 | @dataclass
77 | class DataTrainingArguments:
78 | """
79 | Arguments pertaining to what data we are going to input our model for training and eval.
80 | """
81 |
82 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
83 |
84 | dataset_name: Optional[str] = field(
85 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
86 | )
87 | dataset_config_name: Optional[str] = field(
88 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
89 | )
90 | prompt_column: Optional[str] = field(
91 | default=None,
92 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
93 | )
94 | response_column: Optional[str] = field(
95 | default=None,
96 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
97 | )
98 | history_column: Optional[str] = field(
99 | default=None,
100 | metadata={"help": "The name of the column in the datasets containing the history of chat."},
101 | )
102 | train_file: Optional[str] = field(
103 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
104 | )
105 | validation_file: Optional[str] = field(
106 | default=None,
107 | metadata={
108 | "help": (
109 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
110 | )
111 | },
112 | )
113 | test_file: Optional[str] = field(
114 | default=None,
115 | metadata={
116 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
117 | },
118 | )
119 | task_name: Optional[str] = field(
120 | default="default_task",
121 | metadata={
122 | "help": "The task name."
123 | },
124 | )
125 | overwrite_cache: bool = field(
126 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
127 | )
128 | preprocessing_num_workers: Optional[int] = field(
129 | default=None,
130 | metadata={"help": "The number of processes to use for the preprocessing."},
131 | )
132 | max_source_length: Optional[int] = field(
133 | default=1024,
134 | metadata={
135 | "help": (
136 | "The maximum total input sequence length after tokenization. Sequences longer "
137 | "than this will be truncated, sequences shorter will be padded."
138 | )
139 | },
140 | )
141 | max_target_length: Optional[int] = field(
142 | default=128,
143 | metadata={
144 | "help": (
145 | "The maximum total sequence length for target text after tokenization. Sequences longer "
146 | "than this will be truncated, sequences shorter will be padded."
147 | )
148 | },
149 | )
150 | val_max_target_length: Optional[int] = field(
151 | default=None,
152 | metadata={
153 | "help": (
154 | "The maximum total sequence length for validation target text after tokenization. Sequences longer "
155 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
156 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
157 | "during ``evaluate`` and ``predict``."
158 | )
159 | },
160 | )
161 | pad_to_max_length: bool = field(
162 | default=False,
163 | metadata={
164 | "help": (
165 | "Whether to pad all samples to model maximum sentence length. "
166 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
167 | "efficient on GPU but very bad for TPU."
168 | )
169 | },
170 | )
171 | max_train_samples: Optional[int] = field(
172 | default=None,
173 | metadata={
174 | "help": (
175 | "For debugging purposes or quicker training, truncate the number of training examples to this "
176 | "value if set."
177 | )
178 | },
179 | )
180 | max_eval_samples: Optional[int] = field(
181 | default=None,
182 | metadata={
183 | "help": (
184 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
185 | "value if set."
186 | )
187 | },
188 | )
189 | max_predict_samples: Optional[int] = field(
190 | default=None,
191 | metadata={
192 | "help": (
193 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
194 | "value if set."
195 | )
196 | },
197 | )
198 | num_beams: Optional[int] = field(
199 | default=None,
200 | metadata={
201 | "help": (
202 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
203 | "which is used during ``evaluate`` and ``predict``."
204 | )
205 | },
206 | )
207 | ignore_pad_token_for_loss: bool = field(
208 | default=True,
209 | metadata={
210 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
211 | },
212 | )
213 | source_prefix: Optional[str] = field(
214 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
215 | )
216 |
217 | forced_bos_token: Optional[str] = field(
218 | default=None,
219 | metadata={
220 | "help": (
221 | "The token to force as the first generated token after the decoder_start_token_id."
222 | "Useful for multilingual models like mBART where the first generated token."
223 | "needs to be the target language token (Usually it is the target language token)."
224 | )
225 | },
226 | )
227 |
228 | base_cache_dir: Optional[str] = field(
229 | default=None,
230 | metadata={
231 | "help": (
232 | "The path of cache."
233 | )
234 | },
235 | )
236 |
237 |
238 |
239 | def __post_init__(self):
240 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None:
241 | raise ValueError("Need either a dataset name or a training/validation/test file.")
242 | else:
243 | if self.train_file is not None:
244 | extension = self.train_file.split(".")[-1]
245 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
246 | if self.validation_file is not None:
247 | extension = self.validation_file.split(".")[-1]
248 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
249 | if self.val_max_target_length is None:
250 | self.val_max_target_length = self.max_target_length
251 |
252 |
--------------------------------------------------------------------------------
/chatglm_model_v2/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 | """
10 |
11 | model_name_or_path: str = field(
12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13 | )
14 | ptuning_checkpoint: str = field(
15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
16 | )
17 | config_name: Optional[str] = field(
18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
19 | )
20 | tokenizer_name: Optional[str] = field(
21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
22 | )
23 | cache_dir: Optional[str] = field(
24 | default=None,
25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
26 | )
27 | use_fast_tokenizer: bool = field(
28 | default=True,
29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
30 | )
31 | model_revision: str = field(
32 | default="main",
33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
34 | )
35 | use_auth_token: bool = field(
36 | default=False,
37 | metadata={
38 | "help": (
39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
40 | "with private models)."
41 | )
42 | },
43 | )
44 | resize_position_embeddings: Optional[bool] = field(
45 | default=None,
46 | metadata={
47 | "help": (
48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
49 | "the model's position embeddings."
50 | )
51 | },
52 | )
53 | quantization_bit: Optional[int] = field(
54 | default=None
55 | )
56 | pre_seq_len: Optional[int] = field(
57 | default=None
58 | )
59 | prefix_projection: bool = field(
60 | default=False
61 | )
62 |
63 | @dataclass
64 | class PeftArguments:
65 | peft_type: Optional[str] = field(
66 | default=None, metadata={"help": "The kind of parameter-efficient learning."}
67 | )
68 | prompt_tuning_initial_text: Optional[str] = field(
69 | default=None, metadata={"help": "The initial token list of the soft prompt-tuning"}
70 | )
71 | lora_dim: Optional[int] = field(
72 | default=8, metadata={"help": "The dimension of LoRA"}
73 | )
74 |
75 |
76 | @dataclass
77 | class DataTrainingArguments:
78 | """
79 | Arguments pertaining to what data we are going to input our model for training and eval.
80 | """
81 |
82 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
83 |
84 | dataset_name: Optional[str] = field(
85 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
86 | )
87 | dataset_config_name: Optional[str] = field(
88 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
89 | )
90 | prompt_column: Optional[str] = field(
91 | default=None,
92 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
93 | )
94 | response_column: Optional[str] = field(
95 | default=None,
96 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
97 | )
98 | history_column: Optional[str] = field(
99 | default=None,
100 | metadata={"help": "The name of the column in the datasets containing the history of chat."},
101 | )
102 | train_file: Optional[str] = field(
103 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
104 | )
105 | validation_file: Optional[str] = field(
106 | default=None,
107 | metadata={
108 | "help": (
109 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
110 | )
111 | },
112 | )
113 | test_file: Optional[str] = field(
114 | default=None,
115 | metadata={
116 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
117 | },
118 | )
119 | task_name: Optional[str] = field(
120 | default="default_task",
121 | metadata={
122 | "help": "The task name."
123 | },
124 | )
125 | overwrite_cache: bool = field(
126 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
127 | )
128 | preprocessing_num_workers: Optional[int] = field(
129 | default=None,
130 | metadata={"help": "The number of processes to use for the preprocessing."},
131 | )
132 | max_source_length: Optional[int] = field(
133 | default=1024,
134 | metadata={
135 | "help": (
136 | "The maximum total input sequence length after tokenization. Sequences longer "
137 | "than this will be truncated, sequences shorter will be padded."
138 | )
139 | },
140 | )
141 | max_target_length: Optional[int] = field(
142 | default=128,
143 | metadata={
144 | "help": (
145 | "The maximum total sequence length for target text after tokenization. Sequences longer "
146 | "than this will be truncated, sequences shorter will be padded."
147 | )
148 | },
149 | )
150 | val_max_target_length: Optional[int] = field(
151 | default=None,
152 | metadata={
153 | "help": (
154 | "The maximum total sequence length for validation target text after tokenization. Sequences longer "
155 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
156 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
157 | "during ``evaluate`` and ``predict``."
158 | )
159 | },
160 | )
161 | pad_to_max_length: bool = field(
162 | default=False,
163 | metadata={
164 | "help": (
165 | "Whether to pad all samples to model maximum sentence length. "
166 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
167 | "efficient on GPU but very bad for TPU."
168 | )
169 | },
170 | )
171 | max_train_samples: Optional[int] = field(
172 | default=None,
173 | metadata={
174 | "help": (
175 | "For debugging purposes or quicker training, truncate the number of training examples to this "
176 | "value if set."
177 | )
178 | },
179 | )
180 | max_eval_samples: Optional[int] = field(
181 | default=None,
182 | metadata={
183 | "help": (
184 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
185 | "value if set."
186 | )
187 | },
188 | )
189 | max_predict_samples: Optional[int] = field(
190 | default=None,
191 | metadata={
192 | "help": (
193 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
194 | "value if set."
195 | )
196 | },
197 | )
198 | num_beams: Optional[int] = field(
199 | default=None,
200 | metadata={
201 | "help": (
202 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
203 | "which is used during ``evaluate`` and ``predict``."
204 | )
205 | },
206 | )
207 | ignore_pad_token_for_loss: bool = field(
208 | default=True,
209 | metadata={
210 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
211 | },
212 | )
213 | source_prefix: Optional[str] = field(
214 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
215 | )
216 |
217 | forced_bos_token: Optional[str] = field(
218 | default=None,
219 | metadata={
220 | "help": (
221 | "The token to force as the first generated token after the decoder_start_token_id."
222 | "Useful for multilingual models like mBART where the first generated token."
223 | "needs to be the target language token (Usually it is the target language token)."
224 | )
225 | },
226 | )
227 |
228 | base_cache_dir: Optional[str] = field(
229 | default=None,
230 | metadata={
231 | "help": (
232 | "The path of cache."
233 | )
234 | },
235 | )
236 |
237 |
238 |
239 | def __post_init__(self):
240 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None:
241 | raise ValueError("Need either a dataset name or a training/validation/test file.")
242 | else:
243 | if self.train_file is not None:
244 | extension = self.train_file.split(".")[-1]
245 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
246 | if self.validation_file is not None:
247 | extension = self.validation_file.split(".")[-1]
248 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
249 | if self.val_max_target_length is None:
250 | self.val_max_target_length = self.max_target_length
251 |
252 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/chatglm_model_v1/trainer_seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, List, Optional, Tuple, Union
16 |
17 | import torch
18 | from torch import nn
19 | from torch.utils.data import Dataset
20 |
21 | from transformers.deepspeed import is_deepspeed_zero3_enabled
22 | from trainer import Trainer
23 | from transformers.trainer_utils import PredictionOutput
24 | from transformers.utils import logging
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | class Seq2SeqTrainer(Trainer):
31 | def evaluate(
32 | self,
33 | eval_dataset: Optional[Dataset] = None,
34 | ignore_keys: Optional[List[str]] = None,
35 | metric_key_prefix: str = "eval",
36 | **gen_kwargs
37 | ) -> Dict[str, float]:
38 | """
39 |         Run evaluation and return metrics.
40 |
41 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
42 | (pass it to the init `compute_metrics` argument).
43 |
44 | You can also subclass and override this method to inject custom behavior.
45 |
46 | Args:
47 | eval_dataset (`Dataset`, *optional*):
48 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
49 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
50 | method.
51 | ignore_keys (`List[str]`, *optional*):
52 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
53 | gathering predictions.
54 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
55 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
56 | "eval_bleu" if the prefix is `"eval"` (default)
57 | max_length (`int`, *optional*):
58 | The maximum target length to use when predicting with the generate method.
59 | num_beams (`int`, *optional*):
60 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
61 | beam search.
62 | gen_kwargs:
63 | Additional `generate` specific kwargs.
64 |
65 | Returns:
66 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
67 | dictionary also contains the epoch number which comes from the training state.
68 | """
69 |
70 | gen_kwargs = gen_kwargs.copy()
71 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
72 | gen_kwargs["max_length"] = self.args.generation_max_length
73 | gen_kwargs["num_beams"] = (
74 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
75 | )
76 | self._gen_kwargs = gen_kwargs
77 |
78 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
79 |
80 | def predict(
81 | self,
82 | test_dataset: Dataset,
83 | ignore_keys: Optional[List[str]] = None,
84 | metric_key_prefix: str = "test",
85 | **gen_kwargs
86 | ) -> PredictionOutput:
87 | """
88 |         Run prediction and return predictions and potential metrics.
89 |
90 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
91 | will also return metrics, like in `evaluate()`.
92 |
93 | Args:
94 | test_dataset (`Dataset`):
95 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
96 | `model.forward()` method are automatically removed. Has to implement the method `__len__`
97 | ignore_keys (`List[str]`, *optional*):
98 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
99 | gathering predictions.
100 |             metric_key_prefix (`str`, *optional*, defaults to `"test"`):
101 |                 An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
102 |                 "test_bleu" if the prefix is `"test"` (default)
103 | max_length (`int`, *optional*):
104 | The maximum target length to use when predicting with the generate method.
105 | num_beams (`int`, *optional*):
106 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
107 | beam search.
108 | gen_kwargs:
109 | Additional `generate` specific kwargs.
110 |
111 |
112 |
113 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
114 | padding in a token classification task) the predictions will be padded (on the right) to allow for
115 | concatenation into one array. The padding index is -100.
116 |
117 |
118 |
119 | Returns: *NamedTuple* A namedtuple with the following keys:
120 |
121 | - predictions (`np.ndarray`): The predictions on `test_dataset`.
122 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
123 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
124 | labels).
125 | """
126 |
127 | gen_kwargs = gen_kwargs.copy()
128 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
129 | gen_kwargs["max_length"] = self.args.generation_max_length
130 | gen_kwargs["num_beams"] = (
131 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
132 | )
133 | self._gen_kwargs = gen_kwargs
134 |
135 |
136 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
137 |
138 | def prediction_step(
139 | self,
140 | model: nn.Module,
141 | inputs: Dict[str, Union[torch.Tensor, Any]],
142 | prediction_loss_only: bool,
143 | ignore_keys: Optional[List[str]] = None,
144 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
145 | """
146 | Perform an evaluation step on `model` using `inputs`.
147 |
148 | Subclass and override to inject custom behavior.
149 |
150 | Args:
151 | model (`nn.Module`):
152 | The model to evaluate.
153 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
154 | The inputs and targets of the model.
155 |
156 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
157 | argument `labels`. Check your model's documentation for all accepted arguments.
158 | prediction_loss_only (`bool`):
159 | Whether or not to return the loss only.
160 |
161 | Return:
162 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
163 | labels (each being optional).
164 | """
165 |
166 | if not self.args.predict_with_generate or prediction_loss_only:
167 | return super().prediction_step(
168 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
169 | )
170 |
171 | has_labels = "labels" in inputs
172 | inputs = self._prepare_inputs(inputs)
173 |
174 | # XXX: adapt synced_gpus for fairscale as well
175 | gen_kwargs = self._gen_kwargs.copy()
176 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
177 | gen_kwargs["max_length"] = self.model.config.max_length
178 | gen_kwargs["num_beams"] = (
179 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
180 | )
181 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
182 | gen_kwargs["synced_gpus"] = (
183 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
184 | )
185 |
186 | if "attention_mask" in inputs:
187 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
188 | if "position_ids" in inputs:
189 | gen_kwargs["position_ids"] = inputs.get("position_ids", None)
190 | if "global_attention_mask" in inputs:
191 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
192 |
193 | # prepare generation inputs
194 | # some encoder-decoder models can have varying encoder's and thus
195 | # varying model input names
196 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
197 | generation_inputs = inputs[self.model.encoder.main_input_name]
198 | else:
199 | generation_inputs = inputs[self.model.main_input_name]
200 |
201 | gen_kwargs["input_ids"] = generation_inputs
202 | generated_tokens = self.model.generate(**gen_kwargs)
203 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]
204 |
205 | # in case the batch is shorter than max length, the output should be padded
206 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
207 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
208 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
209 | gen_kwargs["max_new_tokens"] + 1
210 | ):
211 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
212 |
213 | loss = None
214 |
215 | if self.args.prediction_loss_only:
216 | return (loss, None, None)
217 |
218 | if has_labels:
219 | labels = inputs["labels"]
220 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
221 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
222 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
223 | gen_kwargs["max_new_tokens"] + 1
224 | ):
225 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
226 | else:
227 | labels = None
228 |
229 | return (loss, generated_tokens, labels)
230 |
231 | def _pad_tensors_to_max_len(self, tensor, max_length):
232 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
233 | # If PAD token is not defined at least EOS token has to be defined
234 | pad_token_id = (
235 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
236 | )
237 | else:
238 | if self.model.config.pad_token_id is not None:
239 | pad_token_id = self.model.config.pad_token_id
240 | else:
241 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
242 |
243 | padded_tensor = pad_token_id * torch.ones(
244 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
245 | )
246 | padded_tensor[:, : tensor.shape[-1]] = tensor
247 | return padded_tensor
248 |
--------------------------------------------------------------------------------
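In the Seq2SeqTrainer above, any extra generation kwargs passed to `evaluate`/`predict` are stashed in `self._gen_kwargs` and re-read inside `prediction_step`, with `generation_max_length`/`generation_num_beams` from the training arguments as fallbacks. A hedged usage sketch; `trainer`, the datasets, and the generation values are placeholders rather than settings from this repo:

```python
# Sketch only: `trainer` is an instance of the Seq2SeqTrainer defined above.
metrics = trainer.evaluate(
    eval_dataset=eval_dataset,   # optional override of self.eval_dataset
    metric_key_prefix="eval",
    max_length=512,              # kept in self._gen_kwargs; falls back to args.generation_max_length
    num_beams=1,                 # kept in self._gen_kwargs; falls back to args.generation_num_beams
)
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_new_tokens=128)
```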
/chatglm_model_v2/trainer_seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, List, Optional, Tuple, Union
16 |
17 | import torch
18 | from torch import nn
19 | from torch.utils.data import Dataset
20 |
21 | from transformers.deepspeed import is_deepspeed_zero3_enabled
22 | from trainer import PrefixTrainer
23 | from transformers.trainer_utils import PredictionOutput
24 | from transformers.utils import logging
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | class Seq2SeqTrainer(PrefixTrainer):
31 | def evaluate(
32 | self,
33 | eval_dataset: Optional[Dataset] = None,
34 | ignore_keys: Optional[List[str]] = None,
35 | metric_key_prefix: str = "eval",
36 | **gen_kwargs
37 | ) -> Dict[str, float]:
38 | """
39 |         Run evaluation and return metrics.
40 |
41 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
42 | (pass it to the init `compute_metrics` argument).
43 |
44 | You can also subclass and override this method to inject custom behavior.
45 |
46 | Args:
47 | eval_dataset (`Dataset`, *optional*):
48 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
49 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
50 | method.
51 | ignore_keys (`List[str]`, *optional*):
52 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
53 | gathering predictions.
54 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
55 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
56 | "eval_bleu" if the prefix is `"eval"` (default)
57 | max_length (`int`, *optional*):
58 | The maximum target length to use when predicting with the generate method.
59 | num_beams (`int`, *optional*):
60 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
61 | beam search.
62 | gen_kwargs:
63 | Additional `generate` specific kwargs.
64 |
65 | Returns:
66 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
67 | dictionary also contains the epoch number which comes from the training state.
68 | """
69 |
70 | gen_kwargs = gen_kwargs.copy()
71 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
72 | gen_kwargs["max_length"] = self.args.generation_max_length
73 | gen_kwargs["num_beams"] = (
74 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
75 | )
76 | self._gen_kwargs = gen_kwargs
77 |
78 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
79 |
80 | def predict(
81 | self,
82 | test_dataset: Dataset,
83 | ignore_keys: Optional[List[str]] = None,
84 | metric_key_prefix: str = "test",
85 | **gen_kwargs
86 | ) -> PredictionOutput:
87 | """
88 |         Run prediction and return predictions and potential metrics.
89 |
90 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
91 | will also return metrics, like in `evaluate()`.
92 |
93 | Args:
94 | test_dataset (`Dataset`):
95 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
96 | `model.forward()` method are automatically removed. Has to implement the method `__len__`
97 | ignore_keys (`List[str]`, *optional*):
98 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
99 | gathering predictions.
100 |             metric_key_prefix (`str`, *optional*, defaults to `"test"`):
101 |                 An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
102 |                 "test_bleu" if the prefix is `"test"` (default)
103 | max_length (`int`, *optional*):
104 | The maximum target length to use when predicting with the generate method.
105 | num_beams (`int`, *optional*):
106 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
107 | beam search.
108 | gen_kwargs:
109 | Additional `generate` specific kwargs.
110 |
111 |
112 |
113 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
114 | padding in a token classification task) the predictions will be padded (on the right) to allow for
115 | concatenation into one array. The padding index is -100.
116 |
117 |
118 |
119 | Returns: *NamedTuple* A namedtuple with the following keys:
120 |
121 | - predictions (`np.ndarray`): The predictions on `test_dataset`.
122 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
123 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
124 | labels).
125 | """
126 |
127 | gen_kwargs = gen_kwargs.copy()
128 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
129 | gen_kwargs["max_length"] = self.args.generation_max_length
130 | gen_kwargs["num_beams"] = (
131 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
132 | )
133 | self._gen_kwargs = gen_kwargs
134 |
135 |
136 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
137 |
138 | def prediction_step(
139 | self,
140 | model: nn.Module,
141 | inputs: Dict[str, Union[torch.Tensor, Any]],
142 | prediction_loss_only: bool,
143 | ignore_keys: Optional[List[str]] = None,
144 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
145 | """
146 | Perform an evaluation step on `model` using `inputs`.
147 |
148 | Subclass and override to inject custom behavior.
149 |
150 | Args:
151 | model (`nn.Module`):
152 | The model to evaluate.
153 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
154 | The inputs and targets of the model.
155 |
156 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
157 | argument `labels`. Check your model's documentation for all accepted arguments.
158 | prediction_loss_only (`bool`):
159 | Whether or not to return the loss only.
160 |
161 | Return:
162 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
163 | labels (each being optional).
164 | """
165 |
166 | if not self.args.predict_with_generate or prediction_loss_only:
167 | return super().prediction_step(
168 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
169 | )
170 |
171 | has_labels = "labels" in inputs
172 | inputs = self._prepare_inputs(inputs)
173 |
174 | # XXX: adapt synced_gpus for fairscale as well
175 | gen_kwargs = self._gen_kwargs.copy()
176 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
177 | gen_kwargs["max_length"] = self.model.config.max_length
178 | gen_kwargs["num_beams"] = (
179 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
180 | )
181 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
182 | gen_kwargs["synced_gpus"] = (
183 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
184 | )
185 |
186 | if "attention_mask" in inputs:
187 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
188 | if "position_ids" in inputs:
189 | gen_kwargs["position_ids"] = inputs.get("position_ids", None)
190 | if "global_attention_mask" in inputs:
191 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
192 |
193 | # prepare generation inputs
194 | # some encoder-decoder models can have varying encoder's and thus
195 | # varying model input names
196 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
197 | generation_inputs = inputs[self.model.encoder.main_input_name]
198 | else:
199 | generation_inputs = inputs[self.model.main_input_name]
200 |
201 | gen_kwargs["input_ids"] = generation_inputs
202 | generated_tokens = self.model.generate(**gen_kwargs)
203 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]
204 |
205 | # in case the batch is shorter than max length, the output should be padded
206 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
207 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
208 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
209 | gen_kwargs["max_new_tokens"] + 1
210 | ):
211 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
212 |
213 | loss = None
214 |
215 | if self.args.prediction_loss_only:
216 | return (loss, None, None)
217 |
218 | if has_labels:
219 | labels = inputs["labels"]
220 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
221 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
222 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
223 | gen_kwargs["max_new_tokens"] + 1
224 | ):
225 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
226 | else:
227 | labels = None
228 |
229 | return (loss, generated_tokens, labels)
230 |
231 | def _pad_tensors_to_max_len(self, tensor, max_length):
232 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
233 | # If PAD token is not defined at least EOS token has to be defined
234 | pad_token_id = (
235 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
236 | )
237 | else:
238 | if self.model.config.pad_token_id is not None:
239 | pad_token_id = self.model.config.pad_token_id
240 | else:
241 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
242 |
243 | padded_tensor = pad_token_id * torch.ones(
244 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
245 | )
246 | padded_tensor[:, : tensor.shape[-1]] = tensor
247 | return padded_tensor
248 |
--------------------------------------------------------------------------------
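Both trainer variants share `_pad_tensors_to_max_len`, which right-pads a batch of generated ids (or labels) to a fixed length with the pad token so differently sized batches can be gathered into one array. A self-contained toy illustration of that effect, with invented token ids and an assumed `pad_token_id = 0`:

```python
import torch

# Invented ids; pad_token_id = 0 is assumed purely for the example.
pad_token_id = 0
tensor = torch.tensor([[11, 12, 13],
                       [21, 22, 23]])
max_length = 6

# Same arithmetic as _pad_tensors_to_max_len: pre-fill with the pad id, then copy the real ids in.
padded = pad_token_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype)
padded[:, : tensor.shape[-1]] = tensor
# padded -> tensor([[11, 12, 13,  0,  0,  0],
#                   [21, 22, 23,  0,  0,  0]])
```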
/chatglm_model_v2/run_ptuning.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | import json
25 |
26 | import numpy as np
27 | from datasets import load_dataset
28 | import jieba
29 | from rouge_chinese import Rouge
30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
31 | import torch
32 |
33 | import transformers
34 | from transformers import (
35 | AutoConfig,
36 | AutoModel,
37 | AutoTokenizer,
38 | DataCollatorForSeq2Seq,
39 | HfArgumentParser,
40 | Seq2SeqTrainingArguments,
41 | set_seed,
42 | )
43 | from trainer_seq2seq import Seq2SeqTrainer
44 |
45 | from arguments import ModelArguments, DataTrainingArguments
46 |
47 | logger = logging.getLogger(__name__)
48 |
49 | def main():
50 |     # Load the model, data, and training argument configurations
51 |     # (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)
52 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
53 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
54 | # If we pass only one argument to the script and it's the path to a json file,
55 | # let's parse it to get our arguments.
56 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
57 | else:
58 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
59 |
60 | # Setup logging
61 | logging.basicConfig(
62 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
63 | datefmt="%m/%d/%Y %H:%M:%S",
64 | handlers=[logging.StreamHandler(sys.stdout)],
65 | )
66 |
67 | if training_args.should_log:
68 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
69 | transformers.utils.logging.set_verbosity_info()
70 |
71 | log_level = training_args.get_process_log_level()
72 | logger.setLevel(log_level)
73 | # datasets.utils.logging.set_verbosity(log_level)
74 | transformers.utils.logging.set_verbosity(log_level)
75 | transformers.utils.logging.enable_default_handler()
76 | transformers.utils.logging.enable_explicit_format()
77 |
78 | # Log on each process the small summary:
79 | logger.warning(
80 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
81 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
82 | )
83 | logger.info(f"Training/evaluation parameters {training_args}")
84 |
85 | # Set seed before initializing model.
86 | set_seed(training_args.seed)
87 |
88 | # Load dataset
89 | data_files = {}
90 | if data_args.train_file is not None:
91 | data_files["train"] = data_args.train_file
92 | extension = data_args.train_file.split(".")[-1]
93 | if data_args.validation_file is not None:
94 | data_files["validation"] = data_args.validation_file
95 | extension = data_args.validation_file.split(".")[-1]
96 | if data_args.test_file is not None:
97 | data_files["test"] = data_args.test_file
98 | extension = data_args.test_file.split(".")[-1]
99 |
100 |     # Load the data in Hugging Face datasets format
101 | raw_datasets = load_dataset(
102 | extension,
103 | data_files=data_files,
104 | cache_dir=model_args.cache_dir,
105 | use_auth_token=True if model_args.use_auth_token else None,
106 | )
107 |
108 | # Load pretrained model and tokenizer
109 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
110 | config.pre_seq_len = model_args.pre_seq_len
111 | config.prefix_projection = model_args.prefix_projection
112 |
113 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
114 |
115 | if model_args.ptuning_checkpoint is not None:
116 | # Evaluation
117 | # Loading extra state dict of prefix encoder
118 |         # At inference time, only the P-tuning v2 parameters need to be loaded and then merged with the original model weights
119 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
120 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
121 | new_prefix_state_dict = {}
122 |         # Embed the P-tuning v2 parameters into the original model
123 | for k, v in prefix_state_dict.items():
124 | if k.startswith("transformer.prefix_encoder."):
125 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
126 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
127 | else:
128 |         # Load the original model directly
129 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
130 |
131 | if model_args.quantization_bit is not None:
132 | print(f"Quantized to {model_args.quantization_bit} bit")
133 | model = model.quantize(model_args.quantization_bit)
134 | if model_args.pre_seq_len is not None:
135 | # P-tuning v2
136 | model = model.half()
137 | model.transformer.prefix_encoder.float()
138 | else:
139 | # Finetune
140 | model = model.float()
141 |
142 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
143 |
144 | # Preprocessing the datasets.
145 | # We need to tokenize inputs and targets.
146 | if training_args.do_train:
147 | column_names = raw_datasets["train"].column_names
148 | elif training_args.do_eval:
149 | column_names = raw_datasets["validation"].column_names
150 | elif training_args.do_predict:
151 | column_names = raw_datasets["test"].column_names
152 | else:
153 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
154 | return
155 |
156 | # Get the column names for input/target.
157 | prompt_column = data_args.prompt_column
158 | response_column = data_args.response_column
159 | history_column = data_args.history_column
160 |
161 | # Temporarily set max_target_length for training.
162 | max_target_length = data_args.max_target_length
163 |
164 | def preprocess_function_eval(examples):
165 | inputs, targets = [], []
166 | for i in range(len(examples[prompt_column])):
167 | if examples[prompt_column][i] and examples[response_column][i]:
168 | query = examples[prompt_column][i]
169 | history = examples[history_column][i] if history_column is not None else None
170 | prompt = tokenizer.build_prompt(query, history)
171 | inputs.append(prompt)
172 | targets.append(examples[response_column][i])
173 |
174 | inputs = [prefix + inp for inp in inputs]
175 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
176 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
177 |
178 | if data_args.ignore_pad_token_for_loss:
179 | labels["input_ids"] = [
180 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
181 | ]
182 | model_inputs["labels"] = labels["input_ids"]
183 |
184 | return model_inputs
185 |
186 | def preprocess_function_train(examples):
187 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
188 |
189 | model_inputs = {
190 | "input_ids": [],
191 | "labels": [],
192 | }
193 | for i in range(len(examples[prompt_column])):
194 | if examples[prompt_column][i] and examples[response_column][i]:
195 | query, answer = examples[prompt_column][i], examples[response_column][i]
196 |
197 | history = examples[history_column][i] if history_column is not None else None
198 | prompt = tokenizer.build_prompt(query, history)
199 |
200 | prompt = prefix + prompt
201 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
202 | max_length=data_args.max_source_length)
203 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
204 | max_length=data_args.max_target_length)
205 |
206 | context_length = len(a_ids)
207 | input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
208 | labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
209 |
210 | pad_len = max_seq_length - len(input_ids)
211 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
212 | labels = labels + [tokenizer.pad_token_id] * pad_len
213 | if data_args.ignore_pad_token_for_loss:
214 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
215 |
216 | model_inputs["input_ids"].append(input_ids)
217 | model_inputs["labels"].append(labels)
218 |
219 | return model_inputs
220 |
221 | def print_dataset_example(example):
222 | print("input_ids", example["input_ids"])
223 | print("inputs", tokenizer.decode(example["input_ids"]))
224 | print("label_ids", example["labels"])
225 | print("labels", tokenizer.decode(example["labels"]))
226 |
227 |     base_cache_dir = os.path.join(data_args.base_cache_dir, data_args.task_name)
228 | if training_args.local_rank <= 0 and not os.path.exists(base_cache_dir):
229 | os.makedirs(base_cache_dir)
230 |
231 | if training_args.do_train:
232 | if "train" not in raw_datasets:
233 | raise ValueError("--do_train requires a train dataset")
234 | train_dataset = raw_datasets["train"]
235 | if data_args.max_train_samples is not None:
236 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
237 | train_dataset = train_dataset.select(range(max_train_samples))
238 | with training_args.main_process_first(desc="train dataset map pre-processing"):
239 | train_dataset = train_dataset.map(
240 | preprocess_function_train,
241 | batched=True,
242 | num_proc=data_args.preprocessing_num_workers,
243 | remove_columns=column_names,
244 | load_from_cache_file=not data_args.overwrite_cache,
245 | desc="Running tokenizer on train dataset",
246 | cache_file_name=os.path.join(base_cache_dir, "train.arrow")
247 | )
248 | print_dataset_example(train_dataset[0])
249 |
250 | if training_args.do_eval:
251 | max_target_length = data_args.val_max_target_length
252 | if "validation" not in raw_datasets:
253 | raise ValueError("--do_eval requires a validation dataset")
254 | eval_dataset = raw_datasets["validation"]
255 | if data_args.max_eval_samples is not None:
256 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
257 | eval_dataset = eval_dataset.select(range(max_eval_samples))
258 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
259 | eval_dataset = eval_dataset.map(
260 | preprocess_function_eval,
261 | batched=True,
262 | num_proc=data_args.preprocessing_num_workers,
263 | remove_columns=column_names,
264 | load_from_cache_file=not data_args.overwrite_cache,
265 | desc="Running tokenizer on validation dataset",
266 | cache_file_name=os.path.join(base_cache_dir, "eval.arrow")
267 | )
268 | print_dataset_example(eval_dataset[0])
269 |
270 | if training_args.do_predict:
271 | max_target_length = data_args.val_max_target_length
272 | if "test" not in raw_datasets:
273 | raise ValueError("--do_predict requires a test dataset")
274 | predict_dataset = raw_datasets["test"]
275 | if data_args.max_predict_samples is not None:
276 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
277 | predict_dataset = predict_dataset.select(range(max_predict_samples))
278 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
279 | predict_dataset = predict_dataset.map(
280 | preprocess_function_eval,
281 | batched=True,
282 | num_proc=data_args.preprocessing_num_workers,
283 | remove_columns=column_names,
284 | load_from_cache_file=not data_args.overwrite_cache,
285 | desc="Running tokenizer on prediction dataset",
286 | cache_file_name=os.path.join(base_cache_dir, "predict.arrow")
287 | )
288 | print_dataset_example(predict_dataset[0])
289 |
290 | # Data collator
291 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
292 | data_collator = DataCollatorForSeq2Seq(
293 | tokenizer,
294 | model=model,
295 | label_pad_token_id=label_pad_token_id,
296 | pad_to_multiple_of=None,
297 | padding=False
298 | )
299 |
300 | # Metric
301 | def compute_metrics(eval_preds):
302 | preds, labels = eval_preds
303 | if isinstance(preds, tuple):
304 | preds = preds[0]
305 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
306 | if data_args.ignore_pad_token_for_loss:
307 | # Replace -100 in the labels as we can't decode them.
308 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
309 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
310 |
311 | score_dict = {
312 | "rouge-1": [],
313 | "rouge-2": [],
314 | "rouge-l": [],
315 | "bleu-4": []
316 | }
317 | for pred, label in zip(decoded_preds, decoded_labels):
318 | hypothesis = list(jieba.cut(pred))
319 | reference = list(jieba.cut(label))
320 | rouge = Rouge()
321 |             scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
322 | result = scores[0]
323 |
324 | for k, v in result.items():
325 | score_dict[k].append(round(v["f"] * 100, 4))
326 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
327 | score_dict["bleu-4"].append(round(bleu_score * 100, 4))
328 |
329 | for k, v in score_dict.items():
330 | score_dict[k] = float(np.mean(v))
331 | return score_dict
332 |
333 | # Override the decoding parameters of Seq2SeqTrainer
334 | training_args.generation_max_length = (
335 | training_args.generation_max_length
336 | if training_args.generation_max_length is not None
337 | else data_args.val_max_target_length
338 | )
339 | training_args.generation_num_beams = (
340 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
341 | )
342 | # Initialize our Trainer
343 | trainer = Seq2SeqTrainer(
344 | model=model,
345 | args=training_args,
346 | train_dataset=train_dataset if training_args.do_train else None,
347 | eval_dataset=eval_dataset if training_args.do_eval else None,
348 | tokenizer=tokenizer,
349 | data_collator=data_collator,
350 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
351 | save_changed=model_args.pre_seq_len is not None
352 | )
353 |
354 | # Training
355 | if training_args.do_train:
356 | checkpoint = None
357 | if training_args.resume_from_checkpoint is not None:
358 | checkpoint = training_args.resume_from_checkpoint
359 | # elif last_checkpoint is not None:
360 | # checkpoint = last_checkpoint
361 | model.gradient_checkpointing_enable()
362 | model.enable_input_require_grads()
363 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
364 | # trainer.save_model() # Saves the tokenizer too for easy upload
365 |
366 | metrics = train_result.metrics
367 | max_train_samples = (
368 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
369 | )
370 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
371 |
372 | trainer.log_metrics("train", metrics)
373 | trainer.save_metrics("train", metrics)
374 | trainer.save_state()
375 |
376 | # Evaluation
377 | results = {}
378 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
379 | if training_args.do_eval:
380 | logger.info("*** Evaluate ***")
381 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
382 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
383 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
384 |
385 | trainer.log_metrics("eval", metrics)
386 | trainer.save_metrics("eval", metrics)
387 |
388 | if training_args.do_predict:
389 | logger.info("*** Predict ***")
390 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
391 | metrics = predict_results.metrics
392 | max_predict_samples = (
393 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
394 | )
395 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
396 |
397 | trainer.log_metrics("predict", metrics)
398 | trainer.save_metrics("predict", metrics)
399 |
400 | if trainer.is_world_process_zero():
401 | if training_args.predict_with_generate:
402 | predictions = tokenizer.batch_decode(
403 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
404 | )
405 | predictions = [pred.strip() for pred in predictions]
406 | labels = tokenizer.batch_decode(
407 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
408 | )
409 | labels = [label.strip() for label in labels]
410 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
411 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
412 | for p, l in zip(predictions, labels):
413 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
414 | writer.write(f"{res}\n")
415 | return results
416 |
417 |
418 | def _mp_fn(index):
419 | # For xla_spawn (TPUs)
420 | main()
421 |
422 |
423 | if __name__ == "__main__":
424 | main()
425 |
--------------------------------------------------------------------------------
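To make the label layout built by `preprocess_function_train` above concrete, here is a toy walk-through using invented token ids in place of the real ChatGLM2 tokenizer; only the arithmetic mirrors the script (prompt tokens and padding are masked with -100 so they do not contribute to the loss):

```python
# Toy illustration of the ChatGLM2 training example layout; all ids are invented.
max_source_length, max_target_length = 4, 4
max_seq_length = max_source_length + max_target_length + 1
eos_token_id, pad_token_id = 2, 0

a_ids = [101, 102, 103]   # tokenized prompt (prefix + build_prompt output)
b_ids = [201, 202]        # tokenized answer

context_length = len(a_ids)
input_ids = a_ids + b_ids + [eos_token_id]
labels = [pad_token_id] * context_length + b_ids + [eos_token_id]

pad_len = max_seq_length - len(input_ids)
input_ids = input_ids + [pad_token_id] * pad_len
labels = labels + [pad_token_id] * pad_len
labels = [(l if l != pad_token_id else -100) for l in labels]  # ignore_pad_token_for_loss

# input_ids -> [101, 102, 103, 201, 202, 2, 0, 0, 0]
# labels    -> [-100, -100, -100, 201, 202, 2, -100, -100, -100]
```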
/chatglm_model_v1/run_ptuning.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | import json
25 |
26 | import numpy as np
27 | from datasets import load_dataset
28 | import jieba
29 | from rouge_chinese import Rouge
30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
31 | import torch
32 |
33 | import transformers
34 | from transformers import (
35 | AutoConfig,
36 | AutoModel,
37 | AutoTokenizer,
38 | DataCollatorForSeq2Seq,
39 | HfArgumentParser,
40 | Seq2SeqTrainingArguments,
41 | set_seed,
42 | )
43 | from trainer_seq2seq import Seq2SeqTrainer
44 |
45 | from arguments import ModelArguments, DataTrainingArguments
46 |
47 | logger = logging.getLogger(__name__)
48 |
49 | def main():
50 |
51 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
52 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
53 | # If we pass only one argument to the script and it's the path to a json file,
54 | # let's parse it to get our arguments.
55 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
56 | else:
57 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
58 |
59 | # Setup logging
60 | logging.basicConfig(
61 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
62 | datefmt="%m/%d/%Y %H:%M:%S",
63 | handlers=[logging.StreamHandler(sys.stdout)],
64 | )
65 |
66 | if training_args.should_log:
67 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
68 | transformers.utils.logging.set_verbosity_info()
69 |
70 | log_level = training_args.get_process_log_level()
71 | logger.setLevel(log_level)
72 | # datasets.utils.logging.set_verbosity(log_level)
73 | transformers.utils.logging.set_verbosity(log_level)
74 | transformers.utils.logging.enable_default_handler()
75 | transformers.utils.logging.enable_explicit_format()
76 |
77 | # Log on each process the small summary:
78 | logger.warning(
79 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
80 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
81 | )
82 | logger.info(f"Training/evaluation parameters {training_args}")
83 |
84 | # Set seed before initializing model.
85 | set_seed(training_args.seed)
86 |
87 | # Load dataset
88 | data_files = {}
89 | if data_args.train_file is not None:
90 | data_files["train"] = data_args.train_file
91 | extension = data_args.train_file.split(".")[-1]
92 | if data_args.validation_file is not None:
93 | data_files["validation"] = data_args.validation_file
94 | extension = data_args.validation_file.split(".")[-1]
95 | if data_args.test_file is not None:
96 | data_files["test"] = data_args.test_file
97 | extension = data_args.test_file.split(".")[-1]
98 |
99 | raw_datasets = load_dataset(
100 | extension,
101 | data_files=data_files,
102 | cache_dir=model_args.cache_dir,
103 | use_auth_token=True if model_args.use_auth_token else None,
104 | )
105 |
106 | # Load pretrained model and tokenizer
107 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
108 | config.pre_seq_len = model_args.pre_seq_len
109 | config.prefix_projection = model_args.prefix_projection
110 |
111 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
112 |
113 | if model_args.ptuning_checkpoint is not None:
114 | # Evaluation
115 | # Loading extra state dict of prefix encoder
116 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
117 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
118 | new_prefix_state_dict = {}
119 | for k, v in prefix_state_dict.items():
120 | if k.startswith("transformer.prefix_encoder."):
121 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
122 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
123 | else:
124 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
125 |
126 | if model_args.quantization_bit is not None:
127 | print(f"Quantized to {model_args.quantization_bit} bit")
128 | model = model.quantize(model_args.quantization_bit)
129 | if model_args.pre_seq_len is not None:
130 | # P-tuning v2
131 | model = model.half()
132 | model.transformer.prefix_encoder.float()
133 | else:
134 | # Finetune
135 | model = model.float()
136 |
137 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
138 |
139 | # Preprocessing the datasets.
140 | # We need to tokenize inputs and targets.
141 | if training_args.do_train:
142 | column_names = raw_datasets["train"].column_names
143 | elif training_args.do_eval:
144 | column_names = raw_datasets["validation"].column_names
145 | elif training_args.do_predict:
146 | column_names = raw_datasets["test"].column_names
147 | else:
148 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
149 | return
150 |
151 | # Get the column names for input/target.
152 | prompt_column = data_args.prompt_column
153 | response_column = data_args.response_column
154 | history_column = data_args.history_column
155 |
156 | # Temporarily set max_target_length for training.
157 | max_target_length = data_args.max_target_length
158 |
159 | def preprocess_function_eval(examples):
160 | inputs, targets = [], []
161 | for i in range(len(examples[prompt_column])):
162 | if examples[prompt_column][i] and examples[response_column][i]:
163 | query = examples[prompt_column][i]
164 | if history_column is None or len(examples[history_column][i]) == 0:
165 | prompt = query
166 | else:
167 | prompt = ""
168 | history = examples[history_column][i]
169 | for turn_idx, (old_query, response) in enumerate(history):
170 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
171 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
172 | inputs.append(prompt)
173 | targets.append(examples[response_column][i])
174 |
175 | inputs = [prefix + inp for inp in inputs]
176 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
177 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
178 |
179 | if data_args.ignore_pad_token_for_loss:
180 | labels["input_ids"] = [
181 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
182 | ]
183 | model_inputs["labels"] = labels["input_ids"]
184 |
185 | return model_inputs
186 |
187 | def preprocess_function_train(examples):
188 | max_seq_length = data_args.max_source_length + data_args.max_target_length
189 |
190 | model_inputs = {
191 | "input_ids": [],
192 | "labels": [],
193 | }
194 | for i in range(len(examples[prompt_column])):
195 | if examples[prompt_column][i] and examples[response_column][i]:
196 | query, answer = examples[prompt_column][i], examples[response_column][i]
197 |
198 | if history_column is None:
199 | prompt = query
200 | else:
201 | prompt = ""
202 | history = examples[history_column][i]
203 | for turn_idx, (old_query, response) in enumerate(history):
204 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
205 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
206 |
207 | prompt = prefix + prompt
208 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
209 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
210 |
211 | if len(a_ids) > data_args.max_source_length - 1:
212 | a_ids = a_ids[: data_args.max_source_length - 1]
213 |
214 | if len(b_ids) > data_args.max_target_length - 2:
215 | b_ids = b_ids[: data_args.max_target_length - 2]
216 |
217 | input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
218 |
219 | context_length = input_ids.index(tokenizer.bos_token_id)
220 | mask_position = context_length - 1
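    | # Mask the prompt/context tokens with -100 so the cross-entropy loss is computed only on the response tokens.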
221 | labels = [-100] * context_length + input_ids[mask_position+1:]
222 |
223 | pad_len = max_seq_length - len(input_ids)
224 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
225 | labels = labels + [tokenizer.pad_token_id] * pad_len
226 | if data_args.ignore_pad_token_for_loss:
227 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
228 |
229 | model_inputs["input_ids"].append(input_ids)
230 | model_inputs["labels"].append(labels)
231 |
232 | return model_inputs
233 |
234 | def print_dataset_example(example):
235 | print("input_ids",example["input_ids"])
236 | print("inputs", tokenizer.decode(example["input_ids"]))
237 | print("label_ids", example["labels"])
238 | print("labels", tokenizer.decode(example["labels"]))
239 |
240 | base_cache_dir = os.path.join(data_args.base_cache_dir, data_args.task_name)
241 | if training_args.local_rank <= 0 and not os.path.exists(base_cache_dir):
242 | os.makedirs(base_cache_dir)
243 |
244 | if training_args.do_train:
245 | if "train" not in raw_datasets:
246 | raise ValueError("--do_train requires a train dataset")
247 | train_dataset = raw_datasets["train"]
248 | if data_args.max_train_samples is not None:
249 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
250 | train_dataset = train_dataset.select(range(max_train_samples))
251 | with training_args.main_process_first(desc="train dataset map pre-processing"):
252 | train_dataset = train_dataset.map(
253 | preprocess_function_train,
254 | batched=True,
255 | num_proc=data_args.preprocessing_num_workers,
256 | remove_columns=column_names,
257 | load_from_cache_file=not data_args.overwrite_cache,
258 | desc="Running tokenizer on train dataset",
259 | cache_file_name=os.path.join(base_cache_dir, "train.arrow")
260 | )
261 | print_dataset_example(train_dataset[0])
262 |
263 | if training_args.do_eval:
264 | max_target_length = data_args.val_max_target_length
265 | if "validation" not in raw_datasets:
266 | raise ValueError("--do_eval requires a validation dataset")
267 | eval_dataset = raw_datasets["validation"]
268 | if data_args.max_eval_samples is not None:
269 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
270 | eval_dataset = eval_dataset.select(range(max_eval_samples))
271 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
272 | eval_dataset = eval_dataset.map(
273 | preprocess_function_eval,
274 | batched=True,
275 | num_proc=data_args.preprocessing_num_workers,
276 | remove_columns=column_names,
277 | load_from_cache_file=not data_args.overwrite_cache,
278 | desc="Running tokenizer on validation dataset",
279 | cache_file_name=os.path.join(base_cache_dir, "eval.arrow")
280 | )
281 | print_dataset_example(eval_dataset[0])
282 |
283 | if training_args.do_predict:
284 | max_target_length = data_args.val_max_target_length
285 | if "test" not in raw_datasets:
286 | raise ValueError("--do_predict requires a test dataset")
287 | predict_dataset = raw_datasets["test"]
288 | if data_args.max_predict_samples is not None:
289 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
290 | predict_dataset = predict_dataset.select(range(max_predict_samples))
291 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
292 | predict_dataset = predict_dataset.map(
293 | preprocess_function_eval,
294 | batched=True,
295 | num_proc=data_args.preprocessing_num_workers,
296 | remove_columns=column_names,
297 | load_from_cache_file=not data_args.overwrite_cache,
298 | desc="Running tokenizer on prediction dataset",
299 | cache_file_name=os.path.join(base_cache_dir, "predict.arrow")
300 | )
301 | print_dataset_example(predict_dataset[0])
302 |
303 | # Data collator
304 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
305 | data_collator = DataCollatorForSeq2Seq(
306 | tokenizer,
307 | model=model,
308 | label_pad_token_id=label_pad_token_id,
309 | pad_to_multiple_of=None,
310 | padding=False
311 | )
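    | # padding=False: training examples are already padded to a fixed max_seq_length in preprocess_function_train.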
312 |
313 | # Metric
314 | def compute_metrics(eval_preds):
315 | preds, labels = eval_preds
316 | if isinstance(preds, tuple):
317 | preds = preds[0]
318 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
319 | if data_args.ignore_pad_token_for_loss:
320 | # Replace -100 in the labels as we can't decode them.
321 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
322 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
323 |
324 | score_dict = {
325 | "rouge-1": [],
326 | "rouge-2": [],
327 | "rouge-l": [],
328 | "bleu-4": []
329 | }
330 | for pred, label in zip(decoded_preds, decoded_labels):
331 | hypothesis = list(jieba.cut(pred))
332 | reference = list(jieba.cut(label))
333 | rouge = Rouge()
334 | scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
335 | result = scores[0]
336 |
337 | for k, v in result.items():
338 | score_dict[k].append(round(v["f"] * 100, 4))
339 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
340 | score_dict["bleu-4"].append(round(bleu_score * 100, 4))
341 |
342 | for k, v in score_dict.items():
343 | score_dict[k] = float(np.mean(v))
344 | return score_dict
345 |
346 | # Override the decoding parameters of Seq2SeqTrainer
347 | training_args.generation_max_length = (
348 | training_args.generation_max_length
349 | if training_args.generation_max_length is not None
350 | else data_args.val_max_target_length
351 | )
352 | training_args.generation_num_beams = (
353 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
354 | )
355 | # Initialize our Trainer
356 | trainer = Seq2SeqTrainer(
357 | model=model,
358 | args=training_args,
359 | train_dataset=train_dataset if training_args.do_train else None,
360 | eval_dataset=eval_dataset if training_args.do_eval else None,
361 | tokenizer=tokenizer,
362 | data_collator=data_collator,
363 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
364 | save_prefixencoder=model_args.pre_seq_len is not None
365 | )
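    | # save_prefixencoder flags the customised trainer to checkpoint the prefix-encoder weights (rather than the full model)
    | # when p-tuning v2 is enabled, i.e. when pre_seq_len is set.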
366 |
367 | # Training
368 | if training_args.do_train:
369 | checkpoint = None
370 | if training_args.resume_from_checkpoint is not None:
371 | checkpoint = training_args.resume_from_checkpoint
372 | # elif last_checkpoint is not None:
373 | # checkpoint = last_checkpoint
374 | model.gradient_checkpointing_enable()
375 | model.enable_input_require_grads()
376 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
377 | # trainer.save_model() # Saves the tokenizer too for easy upload
378 |
379 | metrics = train_result.metrics
380 | max_train_samples = (
381 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
382 | )
383 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
384 |
385 | trainer.log_metrics("train", metrics)
386 | trainer.save_metrics("train", metrics)
387 | trainer.save_state()
388 |
389 | # Evaluation
390 | results = {}
391 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
392 | if training_args.do_eval:
393 | logger.info("*** Evaluate ***")
394 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
395 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
396 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
397 |
398 | trainer.log_metrics("eval", metrics)
399 | trainer.save_metrics("eval", metrics)
400 |
401 | if training_args.do_predict:
402 | logger.info("*** Predict ***")
403 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
404 | metrics = predict_results.metrics
405 | max_predict_samples = (
406 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
407 | )
408 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
409 |
410 | trainer.log_metrics("predict", metrics)
411 | trainer.save_metrics("predict", metrics)
412 |
413 | if trainer.is_world_process_zero():
414 | if training_args.predict_with_generate:
415 | predictions = tokenizer.batch_decode(
416 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
417 | )
418 | predictions = [pred.strip() for pred in predictions]
419 | labels = tokenizer.batch_decode(
420 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
421 | )
422 | labels = [label.strip() for label in labels]
423 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
424 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
425 | for p, l in zip(predictions, labels):
426 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
427 | writer.write(f"{res}\n")
428 | return results
429 |
430 |
431 | def _mp_fn(index):
432 | # For xla_spawn (TPUs)
433 | main()
434 |
435 |
436 | if __name__ == "__main__":
437 | main()
438 |
--------------------------------------------------------------------------------
/chatglm_model_v2/run_peft.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | import json
25 |
26 | import numpy as np
27 | from datasets import load_dataset
28 | import jieba
29 | from rouge_chinese import Rouge
30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
31 | import torch
32 |
33 | import transformers
34 | from transformers import (
35 | AutoConfig,
36 | AutoModel,
37 | AutoTokenizer,
38 | DataCollatorForSeq2Seq,
39 | HfArgumentParser,
40 | Seq2SeqTrainingArguments,
41 | set_seed,
42 | )
43 |
44 | # PEFT (parameter-efficient fine-tuning) configuration imports
45 | from peft import (
46 | LoraConfig,
47 | PrefixTuningConfig,
48 | PromptEncoderConfig,
49 | PromptEncoderReparameterizationType,
50 | PromptTuningConfig,
51 | PromptTuningInit,
52 | TaskType,
53 | get_peft_model,
54 | )
55 |
56 | from trainer_seq2seq import Seq2SeqTrainer
57 |
58 | from arguments import ModelArguments, DataTrainingArguments, PeftArguments
59 |
60 | logger = logging.getLogger(__name__)
61 |
62 | def main():
63 | # Load the model, training, and data argument configurations
64 | # (parsed either from a single JSON file or from the command line)
65 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, PeftArguments))
66 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
67 | # If we pass only one argument to the script and it's the path to a json file,
68 | # let's parse it to get our arguments.
69 | model_args, data_args, training_args, peft_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
70 | else:
71 | model_args, data_args, training_args, peft_args = parser.parse_args_into_dataclasses()
72 |
73 | # Setup logging
74 | logging.basicConfig(
75 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
76 | datefmt="%m/%d/%Y %H:%M:%S",
77 | handlers=[logging.StreamHandler(sys.stdout)],
78 | )
79 |
80 | if training_args.should_log:
81 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
82 | transformers.utils.logging.set_verbosity_info()
83 |
84 | log_level = training_args.get_process_log_level()
85 | logger.setLevel(log_level)
86 | # datasets.utils.logging.set_verbosity(log_level)
87 | transformers.utils.logging.set_verbosity(log_level)
88 | transformers.utils.logging.enable_default_handler()
89 | transformers.utils.logging.enable_explicit_format()
90 |
91 | # Log on each process the small summary:
92 | logger.warning(
93 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
94 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
95 | )
96 | logger.info(f"Training/evaluation parameters {training_args}")
97 |
98 | # Set seed before initializing model.
99 | set_seed(training_args.seed)
100 |
101 | # Load dataset
102 | data_files = {}
103 | if data_args.train_file is not None:
104 | data_files["train"] = data_args.train_file
105 | extension = data_args.train_file.split(".")[-1]
106 | if data_args.validation_file is not None:
107 | data_files["validation"] = data_args.validation_file
108 | extension = data_args.validation_file.split(".")[-1]
109 | if data_args.test_file is not None:
110 | data_files["test"] = data_args.test_file
111 | extension = data_args.test_file.split(".")[-1]
112 |
113 | # Load the data as a Hugging Face datasets object
114 | raw_datasets = load_dataset(
115 | extension,
116 | data_files=data_files,
117 | cache_dir=model_args.cache_dir,
118 | use_auth_token=True if model_args.use_auth_token else None,
119 | )
120 |
121 | # Load pretrained model and tokenizer
122 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
123 | config.pre_seq_len = model_args.pre_seq_len
124 | config.prefix_projection = model_args.prefix_projection
125 |
126 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
127 |
128 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
129 |
130 | # Apply PEFT (parameter-efficient fine-tuning)
131 | if peft_args.peft_type is not None:
132 | if peft_args.peft_type == "lora":
133 | peft_config = LoraConfig(
134 | task_type=TaskType.CAUSAL_LM,
135 | inference_mode=False,
136 | r=peft_args.lora_dim,
137 | lora_alpha=32,
138 | lora_dropout=0.1
139 | )
140 | elif peft_args.peft_type == "ptuning":
141 | peft_config = PromptEncoderConfig(
142 | task_type=TaskType.CAUSAL_LM,
143 | inference_mode=False,
144 | encoder_reparameterization_type=PromptEncoderReparameterizationType.MLP, # reparameterize the soft prompt with an MLP (the default)
145 | encoder_num_layers=2,
146 | encoder_dropout=0.1,
147 | num_virtual_tokens=8 # number of soft prompt tokens
148 | )
149 | elif peft_args.peft_type == "prefix":
150 | peft_config = PrefixTuningConfig(
151 | task_type=TaskType.CAUSAL_LM,
152 | inference_mode=False,
153 | num_virtual_tokens=4,
154 | num_attention_heads=12,
155 | num_layers=48,
156 | encoder_hidden_size=768,
157 | token_dim=1536,
158 | )
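    | # NOTE: num_layers, num_attention_heads and token_dim must match the base model's architecture,
    | # otherwise the injected prefix past_key_values will not line up with the transformer at runtime.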
159 | elif peft_args.peft_type == "prompt":
160 | peft_config = PromptTuningConfig(
161 | task_type=TaskType.CAUSAL_LM,
162 | inference_mode=False,
163 | prompt_tuning_init=PromptTuningInit.TEXT if peft_args.prompt_tuning_initial_text is not None else PromptTuningInit.RANDOM,
164 | num_virtual_tokens=8)  # required by PEFT; TEXT init additionally expects prompt_tuning_init_text and tokenizer_name_or_path
165 | elif peft_args.peft_type == "adalora":
166 | raise NotImplementedError("Adalora is under developing")
167 | else:
168 | raise NotImplementedError("you must choose one of parameter-efficient learning method")
169 |
170 | logger.info("You have chosen {} as peft type, here is loading model ...".format(peft_args.peft_type))
171 | model = get_peft_model(model, peft_config=peft_config)
172 | logger.info("Reduing trainable parameters: {}".format(model.print_trainable_parameters))
173 |
174 | if model_args.quantization_bit is not None:
175 | print(f"Quantized to {model_args.quantization_bit} bit")
176 | model = model.quantize(model_args.quantization_bit)
177 |
178 |
179 | # if model_args.pre_seq_len is not None:
180 | # # P-tuning v2
181 | # model = model.half() # switch the frozen backbone to half precision
182 | # model.transformer.prefix_encoder.float()
183 | # else:
184 | # # Finetune
185 | # model = model.float()
186 | model = model.float()  # cast everything to fp32 (the commented-out half-precision path above is the p-tuning v2 setup)
187 |
188 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
189 |
190 | # Preprocessing the datasets.
191 | # We need to tokenize inputs and targets.
192 | if training_args.do_train:
193 | column_names = raw_datasets["train"].column_names
194 | elif training_args.do_eval:
195 | column_names = raw_datasets["validation"].column_names
196 | elif training_args.do_predict:
197 | column_names = raw_datasets["test"].column_names
198 | else:
199 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
200 | return
201 |
202 | # Get the column names for input/target.
203 | prompt_column = data_args.prompt_column
204 | response_column = data_args.response_column
205 | history_column = data_args.history_column
206 |
207 | # Temporarily set max_target_length for training.
208 | max_target_length = data_args.max_target_length
209 |
210 | def preprocess_function_eval(examples):
211 | inputs, targets = [], []
212 | for i in range(len(examples[prompt_column])):
213 | if examples[prompt_column][i] and examples[response_column][i]:
214 | query = examples[prompt_column][i]
215 | history = examples[history_column][i] if history_column is not None else None
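    | # ChatGLM2's tokenizer.build_prompt assembles the multi-turn chat template from the current query and the history,
    | # replacing the manual [Round i] loop used in the v1 script.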
216 | prompt = tokenizer.build_prompt(query, history)
217 | inputs.append(prompt)
218 | targets.append(examples[response_column][i])
219 |
220 | inputs = [prefix + inp for inp in inputs]
221 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
222 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
223 |
224 | if data_args.ignore_pad_token_for_loss:
225 | labels["input_ids"] = [
226 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
227 | ]
228 | model_inputs["labels"] = labels["input_ids"]
229 |
230 | return model_inputs
231 |
232 | def preprocess_function_train(examples):
233 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
234 |
235 | model_inputs = {
236 | "input_ids": [],
237 | "labels": [],
238 | }
239 | for i in range(len(examples[prompt_column])):
240 | if examples[prompt_column][i] and examples[response_column][i]:
241 | query, answer = examples[prompt_column][i], examples[response_column][i]
242 |
243 | history = examples[history_column][i] if history_column is not None else None
244 | prompt = tokenizer.build_prompt(query, history)
245 |
246 | prompt = prefix + prompt
247 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
248 | max_length=data_args.max_source_length)
249 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
250 | max_length=data_args.max_target_length)
251 |
252 | context_length = len(a_ids)
253 | input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
254 | labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
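    | # The prompt portion of the labels is padding here and is converted to -100 below, so the loss covers only the answer and the trailing EOS.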
255 |
256 | pad_len = max_seq_length - len(input_ids)
257 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
258 | labels = labels + [tokenizer.pad_token_id] * pad_len
259 | if data_args.ignore_pad_token_for_loss:
260 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
261 |
262 | model_inputs["input_ids"].append(input_ids)
263 | model_inputs["labels"].append(labels)
264 |
265 | return model_inputs
266 |
267 | def print_dataset_example(example):
268 | print("input_ids", example["input_ids"])
269 | print("inputs", tokenizer.decode(example["input_ids"]))
270 | print("label_ids", example["labels"])
271 | print("labels", tokenizer.decode(example["labels"]))
272 |
273 | base_cache_dir = os.path.join(data_args.base_cache_dir, data_args.task_name)
274 | if training_args.local_rank <= 0 and not os.path.exists(base_cache_dir):
275 | os.makedirs(base_cache_dir)
276 |
277 | if training_args.do_train:
278 | if "train" not in raw_datasets:
279 | raise ValueError("--do_train requires a train dataset")
280 | train_dataset = raw_datasets["train"]
281 | if data_args.max_train_samples is not None:
282 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
283 | train_dataset = train_dataset.select(range(max_train_samples))
284 | with training_args.main_process_first(desc="train dataset map pre-processing"):
285 | train_dataset = train_dataset.map(
286 | preprocess_function_train,
287 | batched=True,
288 | num_proc=data_args.preprocessing_num_workers,
289 | remove_columns=column_names,
290 | load_from_cache_file=not data_args.overwrite_cache,
291 | desc="Running tokenizer on train dataset",
292 | cache_file_name=os.path.join(base_cache_dir, "train.arrow")
293 | )
294 | print_dataset_example(train_dataset[0])
295 |
296 | if training_args.do_eval:
297 | max_target_length = data_args.val_max_target_length
298 | if "validation" not in raw_datasets:
299 | raise ValueError("--do_eval requires a validation dataset")
300 | eval_dataset = raw_datasets["validation"]
301 | if data_args.max_eval_samples is not None:
302 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
303 | eval_dataset = eval_dataset.select(range(max_eval_samples))
304 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
305 | eval_dataset = eval_dataset.map(
306 | preprocess_function_eval,
307 | batched=True,
308 | num_proc=data_args.preprocessing_num_workers,
309 | remove_columns=column_names,
310 | load_from_cache_file=not data_args.overwrite_cache,
311 | desc="Running tokenizer on validation dataset",
312 | cache_file_name=os.path.join(base_cache_dir, "eval.arrow")
313 | )
314 | print_dataset_example(eval_dataset[0])
315 |
316 | if training_args.do_predict:
317 | max_target_length = data_args.val_max_target_length
318 | if "test" not in raw_datasets:
319 | raise ValueError("--do_predict requires a test dataset")
320 | predict_dataset = raw_datasets["test"]
321 | if data_args.max_predict_samples is not None:
322 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
323 | predict_dataset = predict_dataset.select(range(max_predict_samples))
324 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
325 | predict_dataset = predict_dataset.map(
326 | preprocess_function_eval,
327 | batched=True,
328 | num_proc=data_args.preprocessing_num_workers,
329 | remove_columns=column_names,
330 | load_from_cache_file=not data_args.overwrite_cache,
331 | desc="Running tokenizer on prediction dataset",
332 | cache_file_name=os.path.join(base_cache_dir, "predict.arrow")
333 | )
334 | print_dataset_example(predict_dataset[0])
335 |
336 | # Data collator
337 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
338 | data_collator = DataCollatorForSeq2Seq(
339 | tokenizer,
340 | model=model,
341 | label_pad_token_id=label_pad_token_id,
342 | pad_to_multiple_of=None,
343 | padding=False
344 | )
345 |
346 | # Metric
347 | def compute_metrics(eval_preds):
348 | preds, labels = eval_preds
349 | if isinstance(preds, tuple):
350 | preds = preds[0]
351 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
352 | if data_args.ignore_pad_token_for_loss:
353 | # Replace -100 in the labels as we can't decode them.
354 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
355 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
356 |
357 | score_dict = {
358 | "rouge-1": [],
359 | "rouge-2": [],
360 | "rouge-l": [],
361 | "bleu-4": []
362 | }
363 | for pred, label in zip(decoded_preds, decoded_labels):
364 | hypothesis = list(jieba.cut(pred))
365 | reference = list(jieba.cut(label))
366 | rouge = Rouge()
367 | scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
368 | result = scores[0]
369 |
370 | for k, v in result.items():
371 | score_dict[k].append(round(v["f"] * 100, 4))
372 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
373 | score_dict["bleu-4"].append(round(bleu_score * 100, 4))
374 |
375 | for k, v in score_dict.items():
376 | score_dict[k] = float(np.mean(v))
377 | return score_dict
378 |
379 | # Override the decoding parameters of Seq2SeqTrainer
380 | training_args.generation_max_length = (
381 | training_args.generation_max_length
382 | if training_args.generation_max_length is not None
383 | else data_args.val_max_target_length
384 | )
385 | training_args.generation_num_beams = (
386 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
387 | )
388 | # Initialize our Trainer
389 | trainer = Seq2SeqTrainer(
390 | model=model,
391 | args=training_args,
392 | train_dataset=train_dataset if training_args.do_train else None,
393 | eval_dataset=eval_dataset if training_args.do_eval else None,
394 | tokenizer=tokenizer,
395 | data_collator=data_collator,
396 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
397 | save_changed=model_args.pre_seq_len is not None
398 | )
399 |
400 | # Training
401 | if training_args.do_train:
402 | checkpoint = None
403 | if training_args.resume_from_checkpoint is not None:
404 | checkpoint = training_args.resume_from_checkpoint
405 | # elif last_checkpoint is not None:
406 | # checkpoint = last_checkpoint
407 | model.gradient_checkpointing_enable()
408 | model.enable_input_require_grads()
409 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
410 | # trainer.save_model() # Saves the tokenizer too for easy upload
411 |
412 | metrics = train_result.metrics
413 | max_train_samples = (
414 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
415 | )
416 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
417 |
418 | trainer.log_metrics("train", metrics)
419 | trainer.save_metrics("train", metrics)
420 | trainer.save_state()
421 |
422 | # Evaluation
423 | results = {}
424 | max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
425 | if training_args.do_eval:
426 | logger.info("*** Evaluate ***")
427 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
428 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
429 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
430 |
431 | trainer.log_metrics("eval", metrics)
432 | trainer.save_metrics("eval", metrics)
433 |
434 | if training_args.do_predict:
435 | logger.info("*** Predict ***")
436 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
437 | metrics = predict_results.metrics
438 | max_predict_samples = (
439 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
440 | )
441 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
442 |
443 | trainer.log_metrics("predict", metrics)
444 | trainer.save_metrics("predict", metrics)
445 |
446 | if trainer.is_world_process_zero():
447 | if training_args.predict_with_generate:
448 | predictions = tokenizer.batch_decode(
449 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
450 | )
451 | predictions = [pred.strip() for pred in predictions]
452 | labels = tokenizer.batch_decode(
453 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
454 | )
455 | labels = [label.strip() for label in labels]
456 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
457 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
458 | for p, l in zip(predictions, labels):
459 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
460 | writer.write(f"{res}\n")
461 | return results
462 |
463 |
464 | def _mp_fn(index):
465 | # For xla_spawn (TPUs)
466 | main()
467 |
468 |
469 | if __name__ == "__main__":
470 | main()
471 |
--------------------------------------------------------------------------------