├── config ├── bitfit_config.json ├── adapter_config.json └── lora_config.json ├── scripts ├── run_glue_finetune.sh ├── run_glue_bitfit.sh ├── run_glue_lora.sh ├── run_glue_adapter.sh ├── run_glue_sora_schedule_dense.sh └── run_glue_sora_no_schedule.sh ├── requirements.txt ├── README.md ├── src ├── util.py ├── glue_tasks.py ├── processor.py ├── sparse_optimizer_multiply_lr.py ├── sparse_optimizer.py └── lora.py ├── glue.py ├── run_glue_adapter.py ├── run_glue_bitfit.py ├── run_glue_finetune.py └── run_glue.py /config/bitfit_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "unfrozen_modules": [ 3 | "deltas", 4 | "layer_norm", 5 | "final_layer_norm" 6 | ] 7 | } -------------------------------------------------------------------------------- /config/adapter_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bottleneck_dim": 12, 3 | "unfrozen_modules": [ 4 | "deltas", 5 | "layer_norm", 6 | "final_layer_norm" 7 | ] 8 | } -------------------------------------------------------------------------------- /config/lora_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "lora_r": 32, 3 | "lora_alpha": 16, 4 | "unfrozen_modules": [ 5 | "deltas", 6 | "layer_norm", 7 | "final_layer_norm" 8 | ] 9 | } -------------------------------------------------------------------------------- /scripts/run_glue_finetune.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | for seed in 100 3 | do 4 | for lr in 5e-5 5 | do 6 | task=mnli-m # can be cola/mrpc/rte/stsb/qnli/sst2/qqp/mnli-m/mnli-mm 7 | bsz=80 8 | epoch=10 9 | echo $lr 10 | echo $seed 11 | echo $bsz 12 | echo $task 13 | CUDA_VISIBLE_DEVICES=0 \ 14 | python -u run_glue_finetune.py \ 15 | --do_eval \ 16 | --do_predict \ 17 | --do_train \ 18 | --task_name $task \ 19 | --eval_steps 1000 \ 20 | --evaluation_strategy steps \ 21 | --greater_is_better true \ 22 | --learning_rate $lr \ 23 | --max_grad_norm 0.1 \ 24 | --load_best_model_at_end \ 25 | --logging_steps 100 \ 26 | --max_steps -1 \ 27 | --model_name_or_path /root/xtlv/data/models/DeBERTaV3_base \ 28 | --num_train_epochs $epoch \ 29 | --output_dir results/${task}/${task}_lr_${lr}_bsz_${bsz}_epoch_${epoch}_seed_${seed} \ 30 | --overwrite_output_dir \ 31 | --per_device_eval_batch_size $bsz \ 32 | --per_device_train_batch_size $bsz \ 33 | --save_steps 1000 \ 34 | --save_strategy steps \ 35 | --save_total_limit 1 \ 36 | --tokenizer_name /root/xtlv/data/models/DeBERTaV3_base \ 37 | --warmup_ratio 0.06 \ 38 | --warmup_steps 0 \ 39 | --weight_decay 0.1 \ 40 | --disable_tqdm true \ 41 | --load_best_model_at_end \ 42 | --ddp_find_unused_parameters false \ 43 | --seed $seed \ 44 | --max_seq_length 256 > results/${task}/${task}_lr_${lr}_bsz_${bsz}_epoch_${epoch}_seed_${seed}.log 2>&1 45 | done 46 | done 47 | 48 | -------------------------------------------------------------------------------- /scripts/run_glue_bitfit.sh: -------------------------------------------------------------------------------- 1 | cd .. 
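# BitFit baseline: sweeps the chosen GLUE task over learning rates and seeds;
# config/bitfit_config.json lists the modules kept trainable (bias "deltas" and layer norms).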
2 | export WANDB_DISABLED=true 3 | for task in mnli-m # can be cola/mrpc/rte/stsb/qnli/sst2/qqp/mnli-m/mnli-mm 4 | do 5 | for lr in 8e-4 6 | do 7 | for seed in 100 8 | do 9 | bsz=100 10 | epoch=10 11 | echo $lr 12 | echo $seed 13 | echo $bsz 14 | echo $task 15 | CUDA_VISIBLE_DEVICES=0 \ 16 | python -u run_glue_bitfit.py \ 17 | --do_eval \ 18 | --do_predict \ 19 | --do_train \ 20 | --task_name $task \ 21 | --eval_steps 1000 \ 22 | --evaluation_strategy steps \ 23 | --greater_is_better true \ 24 | --learning_rate $lr \ 25 | --max_grad_norm 0.1 \ 26 | --load_best_model_at_end \ 27 | --logging_steps 100 \ 28 | --max_steps -1 \ 29 | --model_name_or_path /root/xtlv/data/models/DeBERTaV3_base \ 30 | --num_train_epochs $epoch \ 31 | --output_dir results/${task}_lr_${lr}_bsz_${bsz}_epoch_${epoch}_seed_${seed} \ 32 | --overwrite_output_dir \ 33 | --per_device_eval_batch_size $bsz \ 34 | --per_device_train_batch_size $bsz \ 35 | --save_steps 1000 \ 36 | --save_strategy steps \ 37 | --save_total_limit 1 \ 38 | --tokenizer_name /root/xtlv/data/models/DeBERTaV3_base \ 39 | --warmup_ratio 0.06 \ 40 | --warmup_steps 0 \ 41 | --weight_decay 0.1 \ 42 | --disable_tqdm true \ 43 | --load_best_model_at_end \ 44 | --ddp_find_unused_parameters false \ 45 | --sparse_lambda 0 \ 46 | --seed $seed \ 47 | --max_seq_length 256 > results/${task}_lr_${lr}_bsz_${bsz}_epoch_${epoch}_seed_${seed}.log 2>&1 48 | wait 49 | done 50 | done 51 | done -------------------------------------------------------------------------------- /scripts/run_glue_lora.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | for lora_r in 16 3 | do 4 | for seed in 100 5 | do 6 | task=mnli-mm # can be cola/mrpc/rte/stsb/qnli/sst2/qqp/mnli-m/mnli-mm 7 | lr=3e-4 8 | bsz=8 # rte: 32; the other: 8 9 | epoch=10 10 | echo $lr 11 | echo $seed 12 | echo $bsz 13 | echo $task 14 | CUDA_VISIBLE_DEVICES=0 \ 15 | python -u run_glue.py \ 16 | --do_eval \ 17 | --do_predict \ 18 | --do_train \ 19 | --task_name $task \ 20 | --eval_steps 1000 \ 21 | --evaluation_strategy steps \ 22 | --greater_is_better true \ 23 | --learning_rate $lr \ 24 | --max_grad_norm 0.1 \ 25 | --load_best_model_at_end \ 26 | --logging_steps 100 \ 27 | --max_steps -1 \ 28 | --model_name_or_path /root/xtlv/data/models/DeBERTaV3_base \ 29 | --num_train_epochs $epoch \ 30 | --output_dir results/${task}_lora_r_${lora_r}_lr_${lr}_bsz_${bsz}_epoch_${epoch}_seed_${seed} \ 31 | --overwrite_output_dir \ 32 | --per_device_eval_batch_size $bsz \ 33 | --per_device_train_batch_size $bsz \ 34 | --save_steps 1000 \ 35 | --save_strategy steps \ 36 | --save_total_limit 1 \ 37 | --tokenizer_name /root/xtlv/data/models/DeBERTaV3_base \ 38 | --warmup_ratio 0.06 \ 39 | --warmup_steps 0 \ 40 | --weight_decay 0.1 \ 41 | --disable_tqdm true \ 42 | --load_best_model_at_end \ 43 | --ddp_find_unused_parameters false \ 44 | --sparse_lambda 0 \ 45 | --seed $seed \ 46 | --lora_r $lora_r \ 47 | --max_seq_length 256 > results/${task}_lora_r_${lora_r}_lr_${lr}_bsz_${bsz}_epoch_${epoch}_seed_${seed}.log 2>&1 48 | done 49 | done 50 | -------------------------------------------------------------------------------- /scripts/run_glue_adapter.sh: -------------------------------------------------------------------------------- 1 | cd .. 
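# Adapter baseline: inserts bottleneck adapters of dimension $bottleneck_dim and sweeps tasks and seeds
# (cf. config/adapter_config.json).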
2 | export WANDB_DISABLED=true 3 | for task in cola # can be cola/mrpc/rte/stsb/qnli/sst2/qqp/mnli-m/mnli-mm 4 | do 5 | for seed in 0 21 42 81 100 6 | do 7 | bottleneck_dim=16 8 | lr=8e-4 9 | bsz=8 10 | epoch=20 11 | echo $bottleneck_dim 12 | echo $lr 13 | echo $seed 14 | echo $bsz 15 | echo $task 16 | CUDA_VISIBLE_DEVICES=0 \ 17 | python -u run_glue_adapter.py \ 18 | --do_eval \ 19 | --do_predict \ 20 | --do_train \ 21 | --task_name $task \ 22 | --eval_steps 1000 \ 23 | --evaluation_strategy steps \ 24 | --greater_is_better true \ 25 | --learning_rate $lr \ 26 | --max_grad_norm 0.1 \ 27 | --load_best_model_at_end \ 28 | --logging_steps 100 \ 29 | --max_steps -1 \ 30 | --model_name_or_path /root/xtlv/data/models/DeBERTaV3_base \ 31 | --num_train_epochs $epoch \ 32 | --output_dir results/${task}_bottleneckdim_${bottleneck_dim}_lr_${lr}_bsz_${bsz}_epoch_${epoch}_seed_${seed} \ 33 | --overwrite_output_dir \ 34 | --per_device_eval_batch_size $bsz \ 35 | --per_device_train_batch_size $bsz \ 36 | --save_steps 1000 \ 37 | --save_strategy steps \ 38 | --save_total_limit 1 \ 39 | --tokenizer_name /root/xtlv/data/models/DeBERTaV3_base \ 40 | --warmup_ratio 0.06 \ 41 | --warmup_steps 0 \ 42 | --weight_decay 0.1 \ 43 | --disable_tqdm true \ 44 | --load_best_model_at_end \ 45 | --ddp_find_unused_parameters false \ 46 | --sparse_lambda 0 \ 47 | --seed $seed \ 48 | --bottleneck_dim $bottleneck_dim \ 49 | --max_seq_length 128 > results/${task}_bottleneckdim_${bottleneck_dim}_lr_${lr}_bsz_${bsz}_epoch_${epoch}_seed_${seed}.log 2>&1 50 | wait 51 | done 52 | done -------------------------------------------------------------------------------- /scripts/run_glue_sora_schedule_dense.sh: -------------------------------------------------------------------------------- 1 | cd .. 
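# SoRA with the sparsifying scheduler (Algorithm 1 in the paper): --lambda_schedule linear moves the
# indicator threshold from --sparse_lambda toward --max_lambda over --lambda_num stages.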
2 | export WANDB_DISABLED=false 3 | for task in rte # can be cola/mrpc/rte/stsb/qnli/sst2/qqp/mnli-m/mnli-mm 4 | do 5 | lambda=0.001 6 | lambda2=0 7 | max_lambda=7e-4 8 | lambda_num=7 9 | lr=1.2e-3 10 | r=8 11 | epoch=50 12 | seed=48 13 | bsz=32 14 | epoch2=15 15 | echo $task 16 | echo "lambda=" $lambda 17 | echo $seed 18 | CUDA_VISIBLE_DEVICES=0 \ 19 | python -u run_glue.py \ 20 | --do_eval \ 21 | --do_predict \ 22 | --do_train \ 23 | --task_name $task \ 24 | --eval_steps 1000 \ 25 | --evaluation_strategy steps \ 26 | --greater_is_better true \ 27 | --learning_rate $lr \ 28 | --max_grad_norm 0.1 \ 29 | --load_best_model_at_end \ 30 | --logging_steps 100 \ 31 | --max_steps -1 \ 32 | --model_name_or_path /root/xtlv/data/models/DeBERTaV3_base \ 33 | --num_train_epochs $epoch \ 34 | --output_dir results/$task-lambda2_${lambda2}_${max_lambda}_lambda_${lambda}_epoch_${epoch}_seed_${seed}_${epoch2} \ 35 | --overwrite_output_dir \ 36 | --per_device_eval_batch_size $bsz \ 37 | --per_device_train_batch_size $bsz \ 38 | --save_steps 1000 \ 39 | --save_strategy steps \ 40 | --save_total_limit 1 \ 41 | --tokenizer_name /root/xtlv/data/models/DeBERTaV3_base \ 42 | --warmup_ratio 0.06 \ 43 | --warmup_steps 0 \ 44 | --weight_decay 0.1 \ 45 | --disable_tqdm true \ 46 | --load_best_model_at_end \ 47 | --ddp_find_unused_parameters false \ 48 | --sparse_lambda $lambda \ 49 | --sparse_lambda_2 $lambda2 \ 50 | --seed $seed \ 51 | --lora_r $r \ 52 | --max_seq_length 320 \ 53 | --max_lambda $max_lambda \ 54 | --lambda_schedule linear \ 55 | --lambda_num $lambda_num \ 56 | --train_sparse > results/$task-lambda2_${lambda2}_${max_lambda}_lambda_${lambda}_epoch_${epoch}_seed_${seed}_${epoch2}.log 2>&1 57 | wait 58 | done -------------------------------------------------------------------------------- /scripts/run_glue_sora_no_schedule.sh: -------------------------------------------------------------------------------- 1 | cd .. 
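# SoRA without the sparsifying scheduler: a fixed --sparse_lambda / --sparse_lambda_2 pair,
# grid-searched here over lora_r, lambda, lambda2, seed, and learning rate.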
2 | export WANDB_DISABLED=false 3 | for task in stsb # can be cola/mrpc/rte/stsb/qnli/sst2/qqp/mnli-m/mnli-mm 4 | do 5 | for lora_r in 8 6 | do 7 | for lambda in 10 8 | do 9 | for lambda2 in 3e-4 10 | do 11 | for seed in 0 21 42 81 100 12 | do 13 | for lr in 8e-4 14 | do 15 | epoch=20 16 | bsz=8 17 | echo $task 18 | echo "lambda=" $lambda 19 | echo "lambda2=" $lambda2 20 | echo "lora_r=" $lora_r 21 | echo "seed=" $seed 22 | CUDA_VISIBLE_DEVICES=0 \ 23 | python -u run_glue.py \ 24 | --do_eval \ 25 | --do_predict \ 26 | --do_train \ 27 | --task_name $task \ 28 | --eval_steps 1000 \ 29 | --evaluation_strategy steps \ 30 | --greater_is_better true \ 31 | --learning_rate $lr \ 32 | --max_grad_norm 0.1 \ 33 | --load_best_model_at_end \ 34 | --logging_steps 100 \ 35 | --max_steps -1 \ 36 | --model_name_or_path /root/xtlv/data/models/DeBERTaV3_base \ 37 | --num_train_epochs $epoch \ 38 | --output_dir results/${task}_lora_r_${lora_r}_lambda_${lambda}_lambda2_${lambda2}_lr_${lr}_epoch_${epoch}_bsz_${bsz}_seed_${seed} \ 39 | --overwrite_output_dir \ 40 | --per_device_eval_batch_size $bsz \ 41 | --per_device_train_batch_size $bsz \ 42 | --save_steps 1000 \ 43 | --save_strategy steps \ 44 | --save_total_limit 1 \ 45 | --tokenizer_name /root/xtlv/data/models/DeBERTaV3_base \ 46 | --warmup_ratio 0.06 \ 47 | --warmup_steps 0 \ 48 | --weight_decay 0.1 \ 49 | --disable_tqdm true \ 50 | --load_best_model_at_end \ 51 | --ddp_find_unused_parameters false \ 52 | --sparse_lambda $lambda \ 53 | --sparse_lambda_2 $lambda2 \ 54 | --seed $seed \ 55 | --lora_r $lora_r \ 56 | --max_seq_length 128 \ 57 | --train_sparse > results/${task}_lora_r_${lora_r}_lambda_${lambda}_lambda2_${lambda2}_lr_${lr}_epoch_${epoch}_bsz_${bsz}_seed_${seed}.log 2>&1 58 | wait 59 | done 60 | done 61 | done 62 | done 63 | done 64 | done -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.3.0 2 | aiohttp==3.8.3 3 | aiosignal==1.3.1 4 | aliyun-python-sdk-core==2.13.36 5 | aliyun-python-sdk-kms==2.16.0 6 | appdirs==1.4.4 7 | asttokens==2.0.5 8 | async-timeout==4.0.2 9 | attrs==22.1.0 10 | backcall==0.2.0 11 | bigmodelvis==0.0.1 12 | blessings==1.7 13 | Bottleneck==1.3.5 14 | brotlipy==0.7.0 15 | certifi==2023.5.7 16 | cffi==1.15.1 17 | charset-normalizer==2.0.4 18 | cheroot==9.0.0 19 | click==8.0.4 20 | commonmark==0.9.1 21 | contourpy==1.0.5 22 | crcmod==1.7 23 | cryptography==38.0.1 24 | cycler==0.11.0 25 | datasets==1.17.0 26 | debugpy==1.5.1 27 | decorator==5.1.1 28 | delta-center-client==0.0.4 29 | dill==0.3.6 30 | docker-pycreds==0.4.0 31 | entrypoints==0.4 32 | executing==0.8.3 33 | filelock==3.6.0 34 | fonttools==4.25.0 35 | frozenlist==1.3.3 36 | fsspec==2022.11.0 37 | gitdb==4.0.9 38 | GitPython==3.1.31 39 | gpustat==0.6.0 40 | huggingface-hub==0.10.1 41 | idna==3.4 42 | importlib-resources==5.2.0 43 | ipykernel==6.15.2 44 | ipython==8.6.0 45 | jaraco.functools==3.5.2 46 | jedi==0.18.1 47 | jmespath==0.10.0 48 | joblib==1.1.1 49 | jupyter_client==7.4.8 50 | jupyter_core==4.11.2 51 | kiwisolver==1.4.4 52 | loralib==0.1.0 53 | matplotlib==3.7.1 54 | matplotlib-inline==0.1.6 55 | mkl-fft==1.3.1 56 | mkl-random==1.2.2 57 | mkl-service==2.4.0 58 | more-itertools==9.0.0 59 | multidict==6.0.2 60 | multiprocess==0.70.14 61 | munkres==1.1.4 62 | nest-asyncio==1.5.5 63 | nltk==3.7 64 | numexpr==2.8.4 65 | numpy==1.23.4 66 | nvidia-ml-py3==7.352.0 67 | opendelta==0.3.2 68 | oss2==2.15.0 69 | 
packaging==21.3 70 | pandas==1.5.3 71 | parso==0.8.3 72 | pathtools==0.1.2 73 | pexpect==4.8.0 74 | pickleshare==0.7.5 75 | Pillow==9.2.0 76 | pip==22.2.2 77 | ply==3.11 78 | prompt-toolkit==3.0.20 79 | protobuf==4.22.3 80 | psutil==5.9.0 81 | ptyprocess==0.7.0 82 | pure-eval==0.2.2 83 | pyarrow==10.0.0 84 | pycparser==2.21 85 | pycryptodome==3.15.0 86 | Pygments==2.13.0 87 | pyOpenSSL==22.0.0 88 | pyparsing==3.0.9 89 | PyQt5-sip==12.11.0 90 | PySocks==1.7.1 91 | python-dateutil==2.8.2 92 | pytz==2022.7 93 | PyYAML==6.0 94 | pyzmq==23.2.0 95 | regex==2022.7.9 96 | requests==2.28.1 97 | rich==12.6.0 98 | rouge-score==0.1.2 99 | sacremoses==0.0.43 100 | scikit-learn==1.1.3 101 | scipy==1.9.3 102 | seaborn==0.12.2 103 | sentencepiece==0.1.97 104 | sentry-sdk==1.19.1 105 | setproctitle==1.3.2 106 | setuptools==65.5.0 107 | sip==6.6.2 108 | six==1.16.0 109 | sklearn==0.0.post1 110 | smmap==5.0.0 111 | stack-data==0.2.0 112 | tensorboardX==2.6.1 113 | threadpoolctl==2.2.0 114 | tokenizers==0.10.3 115 | toml==0.10.2 116 | torch==1.11.0 117 | torchaudio==0.11.0 118 | torchvision==0.12.0 119 | tornado==6.2 120 | tqdm==4.64.1 121 | traitlets==5.1.1 122 | transformers==4.14.1 123 | typing_extensions==4.3.0 124 | urllib3==1.26.12 125 | wandb==0.14.2 126 | wcwidth==0.2.5 127 | web.py==0.62 128 | wheel==0.37.1 129 | xxhash==3.1.0 130 | yacs==0.1.8 131 | yarl==1.8.1 132 | zipp==3.11.0 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | # Sparse Low-rank Adaptation of Pre-trained Language Models
4 | 
5 | 
6 | 7 | 🎉 This is the implementation of the EMNLP 2023 paper: [Sparse Low-rank Adaptation of Pre-trained Language Models](https://arxiv.org/abs/2311.11696) 8 | 9 | 10 | ## Requirements 11 | 12 | To run our code, please install all the dependency packages with the following command: 13 | 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Preparation 19 | 20 | ### Prepare the Data and Modify the Data Path 21 | 22 | In the paper/code, we use the GLUE datasets. You can download the data from Hugging Face or from our [Google Drive](https://drive.google.com/drive/folders/1sNoQIp1x-5aXH4r9dOoSdsm5F1kihg_W?usp=sharing). 23 | 24 | After downloading the data, please replace the following data path definitions with your own data path: 25 | 26 | - `main_dir` in Line 27 of `SoRA/src/glue_tasks.py` 27 | - `main_dir` in Line 9 of `SoRA/src/processor.py` 28 | - `data_path` in Line 88 of `SoRA/run_glue.py`, `SoRA/run_glue_adapter.py`, `SoRA/run_glue_bitfit.py` and `SoRA/run_glue_finetune.py` 29 | 30 | ### Prepare the Model 31 | 32 | You can download the base model and the corresponding tokenizer from Hugging Face. After that, do not forget to modify `model_name_or_path` and `tokenizer_name` in the script files (.sh). 33 | 34 | 35 | ## Baseline 36 | 37 | We provide implementations of LoRA, Adapter, BitFit and full-parameter fine-tuning. You can run these baselines with the following commands: 38 | 39 | ```bash 40 | cd scripts 41 | # LoRA 42 | bash run_glue_lora.sh 43 | # Adapter 44 | bash run_glue_adapter.sh 45 | # BitFit 46 | bash run_glue_bitfit.sh 47 | # Full-parameter Fine-Tune 48 | bash run_glue_finetune.sh 49 | ``` 50 | 51 | ## SoRA 52 | 53 | You can apply SoRA by running the following commands: 54 | 55 | ```bash 56 | cd scripts 57 | # without the sparsifying scheduler 58 | bash run_glue_sora_no_schedule.sh 59 | # with the sparsifying scheduler (Algorithm 1) 60 | bash run_glue_sora_schedule_dense.sh 61 | ``` 62 | 63 | We explain some of the arguments as follows: 64 | 65 | - `sparse_lambda`: The hyperparameter $\eta_t$ in the paper. 66 | - `sparse_lambda_2`: The hyperparameter $\xi$ in the paper. 67 | - `lora_r`: The hyperparameter $r_{max}$ in the paper. 68 | - `train_sparse`: The flag that decides whether or not to apply SoRA. 69 | - `lambda_schedule`: The strategy of the sparsifying scheduler. Possible values are `linear`, `log_linear` and `exp_linear`. 70 | - `max_lambda`: The maximum $\xi$ when applying the sparsifying scheduler. 71 | - `lambda_num`: The number of $\xi$ values used by the sparsifying scheduler (see the sketch below). 72 | 73 | 74 | ## Bugs or questions? 75 | 76 | If you have any questions related to the code or the paper, please contact Ning Ding (`dn97@mail.tsinghua.edu.cn`), Xingtai Lv (`lvxt20@mails.tsinghua.edu.cn`), or open an issue.
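For reference, here is a minimal sketch (not part of the released scripts) of what the scheduler arguments control, distilled from `src/sparse_optimizer.py`: the scheduler produces `lambda_num` indicator values between `sparse_lambda` and `max_lambda`, and after each optimizer step the gate vector of every SoRA module is soft-thresholded, so gate entries pushed to exactly zero correspond to pruned ranks. The helper names below are illustrative only.

```python
import numpy as np
import torch

def build_lambda_schedule(sparse_lambda, max_lambda, lambda_num, schedule="linear"):
    # Sequence of indicator thresholds used by the scheduler (cf. SparseAdamW._build_lambda_list).
    if schedule == "linear":
        return np.linspace(sparse_lambda, max_lambda, lambda_num)
    if schedule == "log_linear":
        return np.log(np.linspace(np.exp(sparse_lambda), np.exp(max_lambda), lambda_num))
    if schedule == "exp_linear":
        return np.exp(np.linspace(np.log(sparse_lambda), np.log(max_lambda), lambda_num))
    raise NotImplementedError(schedule)

def soft_threshold_(gate, sparse_lambda):
    # Proximal update applied to the gate vector after each optimizer step (cf. SparseAdamW.step).
    gate[gate > sparse_lambda] -= sparse_lambda
    gate[gate < -sparse_lambda] += sparse_lambda
    gate[gate.abs() < sparse_lambda] = 0.0
    return gate

print(build_lambda_schedule(1e-3, 7e-4, 7))     # the schedule swept in run_glue_sora_schedule_dense.sh
print(soft_threshold_(torch.randn(1, 8), 0.5))  # zero entries correspond to pruned ranks
```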
77 | 78 | ## Citation 79 | 80 | If you find our work useful, please use the following citation: 81 | 82 | ```bibtex 83 | @article{ding2023sparse, 84 | title={Sparse Low-rank Adaptation of Pre-trained Language Models}, 85 | author={Ding, Ning and Lv, Xingtai and Wang, Qiaosen and Chen, Yulin and Zhou, Bowen and Liu, Zhiyuan and Sun, Maosong}, 86 | journal={arXiv preprint arXiv:2311.11696}, 87 | year={2023} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parallel import DistributedDataParallel 3 | from transformers.trainer_pt_utils import get_parameter_names 4 | import torch.nn as nn 5 | from transformers import AdamW, get_linear_schedule_with_warmup 6 | GATE_PARAM_NAME= "lora.gate" 7 | 8 | def compute_trainable_sparse_param(model): 9 | if isinstance(model, DistributedDataParallel): 10 | model = model.module 11 | 12 | total_trainable_param = 0 13 | deduct = 0 14 | for n, p in model.named_parameters(): 15 | if p.requires_grad: 16 | if GATE_PARAM_NAME in n: 17 | deduct += (torch.numel(p) - torch.count_nonzero(p)) * model.config.hidden_size * 2 # zero_number * 768 * 2 18 | else: 19 | total_trainable_param += torch.numel(p) 20 | sparse_trainable_param = total_trainable_param - deduct 21 | return sparse_trainable_param, total_trainable_param 22 | 23 | def create_optimizer_and_scheduler(args, model, num_training_steps: int): 24 | """ 25 | Setup the optimizer and the learning rate scheduler. 26 | 27 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 28 | Trainer's init through :obj:`optimizers`, or subclass and override this method (or :obj:`create_optimizer` 29 | and/or :obj:`create_scheduler`) in a subclass. 30 | """ 31 | optimizer = create_optimizer(args, model) 32 | scheduler = create_scheduler(args, num_training_steps=num_training_steps, optimizer=optimizer) 33 | return optimizer, scheduler 34 | 35 | def create_optimizer(args, model): 36 | """ 37 | Setup the optimizer. 38 | 39 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 40 | Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass. 41 | """ 42 | 43 | decay_parameters = get_parameter_names(model, [nn.LayerNorm]) 44 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 45 | print(f"removing {GATE_PARAM_NAME} from standard optimizer") 46 | optimizer_grouped_parameters = [ 47 | { 48 | "params": [p for n, p in model.named_parameters() if n in decay_parameters and GATE_PARAM_NAME not in n and p.requires_grad], 49 | "weight_decay": args.weight_decay, 50 | }, 51 | { 52 | "params": [p for n, p in model.named_parameters() if n not in decay_parameters and GATE_PARAM_NAME not in n and p.requires_grad], 53 | "weight_decay": 0.0, 54 | }, 55 | ] 56 | optimizer_kwargs = { 57 | "betas": (args.adam_beta1, args.adam_beta2), 58 | "eps": args.adam_epsilon, 59 | } 60 | optimizer_kwargs["lr"] = args.learning_rate 61 | optimizer = AdamW(optimizer_grouped_parameters, **optimizer_kwargs) 62 | 63 | return optimizer 64 | 65 | def create_scheduler(args, num_training_steps: int, optimizer: torch.optim.Optimizer = None): 66 | """ 67 | Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or 68 | passed as an argument. 
69 | 70 | Args: 71 | num_training_steps (int): The number of training steps to do. 72 | """ 73 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.get_warmup_steps(num_training_steps), 74 | num_training_steps=num_training_steps) 75 | return lr_scheduler 76 | -------------------------------------------------------------------------------- /src/glue_tasks.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import collections 3 | import abc 4 | import functools 5 | from selectors import EpollSelector 6 | from typing import Callable, List, Mapping 7 | import datasets 8 | import logging 9 | import numpy as np 10 | import torch 11 | import re 12 | import itertools 13 | import os 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | from transformers.models.auto.tokenization_auto import tokenizer_class_from_name 19 | 20 | from typing import List, Dict 21 | from collections import defaultdict 22 | import warnings 23 | 24 | 25 | from .processor import AbstractTask 26 | 27 | main_dir = "/root/xtlv/data/sora_datasets/glue_datasets_from_dn" 28 | 29 | ##GLUE 30 | class COLA(AbstractTask): 31 | name = "cola" 32 | split_to_data_split = {"train": "train", 33 | "validation": "validation", 34 | "test": "validation"} 35 | 36 | def load_dataset(self, split): 37 | return datasets.load_from_disk(f"{main_dir}/cola")[split] 38 | 39 | 40 | class SST2(AbstractTask): 41 | name = "sst2" 42 | split_to_data_split = {"train": "train", 43 | "validation": "validation", 44 | "test": "validation"} 45 | 46 | 47 | 48 | 49 | class MRPC(AbstractTask): 50 | name = "mrpc" 51 | split_to_data_split = {"train": "train", 52 | "validation": "validation", 53 | "test": "validation"} 54 | 55 | 56 | class QQP(AbstractTask): 57 | name = "qqp" 58 | split_to_data_split = {"train": "train", 59 | "validation": "validation", 60 | "test": "validation"} 61 | 62 | class STSB(AbstractTask): 63 | name = "stsb" 64 | split_to_data_split = {"train": "train", 65 | "validation": "validation", 66 | "test": "validation"} 67 | 68 | 69 | 70 | class MNLI(AbstractTask): 71 | name = "mnli" 72 | split_to_data_split = {"train": "train", 73 | "validation": "validation_matched", 74 | "test": "validation_matched"} 75 | 76 | class MNLI_M(AbstractTask): 77 | name = "mnli" 78 | split_to_data_split = {"train": "train", 79 | "validation": "validation_matched", 80 | "test": "validation_matched"} 81 | 82 | class MNLI_MM(AbstractTask): 83 | name = "mnli" 84 | split_to_data_split = {"train": "train", 85 | "validation": "validation_mismatched", 86 | "test": "validation_mismatched"} 87 | 88 | 89 | class QNLI(AbstractTask): 90 | name = "qnli" 91 | split_to_data_split = {"train": "train", 92 | "validation": "validation", 93 | "test": "validation"} 94 | 95 | 96 | #Tested 97 | class RTE(AbstractTask): 98 | name = "rte" 99 | split_to_data_split = {"train": "train", 100 | "validation": "validation", 101 | "test": "validation"} 102 | 103 | class WNLI(AbstractTask): 104 | name = "wnli" 105 | split_to_data_split = {"train": "train", 106 | "validation": "validation", 107 | "test": "validation"} 108 | 109 | 110 | TASK_MAPPING = OrderedDict( 111 | [ 112 | ('mrpc', MRPC), 113 | ('cola', COLA), 114 | ('sst2', SST2), 115 | ('qnli', QNLI), 116 | ('rte', RTE), 117 | ('wnli', WNLI), 118 | ('mnli', MNLI), 119 | ('mnli-m', MNLI_M), 120 | ('mnli-mm', MNLI_MM), 121 | ('qqp', QQP), 122 | ('stsb', STSB), 123 | ] 124 | ) 125 | 126 | class AutoTask: 127 | @classmethod 128 | def get(self, 
task, config, data_args, seed=42): 129 | if task in TASK_MAPPING: 130 | return TASK_MAPPING[task](config, data_args, seed) 131 | raise ValueError( 132 | "Unrecognized task {} for AutoTask Model: {}.\n" 133 | "Task name should be one of {}.".format( 134 | ", ".join(c for c in TASK_MAPPING.keys()) 135 | ) 136 | ) 137 | 138 | if __name__ == "__main__": 139 | for name in TASK_MAPPING: 140 | print(name) 141 | task = AutoTask().get(name, None, None) 142 | print(task.split_train_to_make_test) 143 | print(task.split_valid_to_make_test) 144 | train_set = task.get("train", split_validation_test=True) 145 | print(train_set[0]) 146 | -------------------------------------------------------------------------------- /src/processor.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Callable, List, Mapping, Dict 3 | import datasets 4 | import logging 5 | import numpy as np 6 | import torch 7 | logger = logging.getLogger(__name__) 8 | 9 | main_dir = "/root/xtlv/data/sora_datasets/glue_datasets_from_dn" 10 | 11 | class AbstractTask(abc.ABC): 12 | name = NotImplemented 13 | config = NotImplemented 14 | prefix = NotImplemented 15 | split_map = None 16 | split_to_data_split: Mapping[str, str] = \ 17 | {"train": "train", "validation": "validation", "test": "test"} 18 | small_datasets_without_all_splits = ["cola", "wnli", "rte", "superglue-cb", "superglue-copa", "superglue-multirc", 19 | "superglue-wic", "superglue-wsc.fixed", "superglue-rte", "mrpc", "stsb", 20 | "superglue-boolq", "mnli"] 21 | large_data_without_all_splits = ["qqp", "qnli", "superglue-record", "sst2"] 22 | 23 | split_valid_to_make_test = True 24 | split_train_to_make_test = False 25 | keep_fields_after_preprocess = ["label"] # The fields that should be kept even after preprocessiing 26 | 27 | def __init__(self, config, data_args, seed=42, default_max_length=1): 28 | self.config = config 29 | self.seed = seed 30 | self.data_args = data_args 31 | 32 | self.default_max_length = default_max_length 33 | self.__post_init__() 34 | 35 | def __post_init__(self): 36 | self.split_valid_to_make_test = self.name in self.small_datasets_without_all_splits 37 | self.split_train_to_make_test = self.name in self.large_data_without_all_splits 38 | 39 | def load_dataset(self, split): 40 | tmp = datasets.load_from_disk(f"{main_dir}/{self.name}") 41 | 42 | return tmp[split] 43 | 44 | def check_n_obs(self, n_obs, total_size): 45 | if n_obs is not None and n_obs > total_size: 46 | n_obs = total_size 47 | logger.warning("n_obs is set to %s", n_obs) 48 | return n_obs 49 | 50 | def shuffled_indices(self, dataset): 51 | num_samples = len(dataset) 52 | generator = torch.Generator() 53 | generator.manual_seed(self.seed) 54 | return torch.randperm(num_samples, generator=generator).tolist() 55 | 56 | def subsample(self, dataset, n_obs=None, indices=None): 57 | """ 58 | Given a dataset returns the subsampled dataset. 59 | :param n_obs: the number of samples of the subsampled dataset. 60 | :param indices: indices to select the samples from, if not given, indices are computed 61 | from by shuffling the given dataset. 62 | :return: subsampled dataset. 
63 | """ 64 | num_samples = len(dataset) 65 | n_obs = self.check_n_obs(n_obs, num_samples) 66 | if indices is None: 67 | indices = self.shuffled_indices(dataset) 68 | indices = indices[:n_obs] 69 | return dataset.select(indices) 70 | 71 | 72 | def get_split_indices(self, split, dataset, validation_size): 73 | indices = self.shuffled_indices(dataset) 74 | if split == "validation": 75 | return indices[:validation_size] 76 | else: 77 | return indices[validation_size:] 78 | 79 | def preprocessor(self, example): 80 | return example 81 | 82 | def get(self, split, n_obs=None, split_validation_test=False): 83 | # For small datasets (n_samples < 10K) without test set, we divide validation set to 84 | # half, use one half as test set and one half as validation set. 85 | if split in ["eval", "dev", "valid"]: 86 | split = "validation" 87 | if split_validation_test and self.split_valid_to_make_test \ 88 | and split != "train": 89 | mapped_split = self.split_to_data_split["validation"] 90 | dataset = self.load_dataset(split=mapped_split) 91 | indices = self.get_split_indices(split, dataset, validation_size=len(dataset)//2) 92 | dataset = self.subsample(dataset, n_obs, indices) 93 | # For larger datasets (n_samples > 10K), we divide training set into 1K as 94 | # validation and the rest as training set, keeping the original validation 95 | # set as the test set. 96 | elif split_validation_test and self.split_train_to_make_test \ 97 | and split != "test": 98 | dataset = self.load_dataset(split="train") 99 | indices = self.get_split_indices(split, dataset, validation_size=1000) 100 | dataset = self.subsample(dataset, n_obs, indices) 101 | else: 102 | mapped_split = self.split_to_data_split[split] 103 | dataset = self.load_dataset(split=mapped_split) 104 | # shuffles the data and samples it. 
105 | if n_obs is not None: 106 | dataset = self.subsample(dataset, n_obs) 107 | 108 | this_method = getattr(self.__class__, 'preprocessor') 109 | base_method = getattr(AbstractTask, 'preprocessor') 110 | if this_method is not base_method: 111 | return dataset.map(self.preprocessor) 112 | else: 113 | return dataset -------------------------------------------------------------------------------- /src/sparse_optimizer_multiply_lr.py: -------------------------------------------------------------------------------- 1 | from transformers import AdamW 2 | from torch.optim import Optimizer 3 | import torch 4 | import math 5 | import numpy as np 6 | 7 | class SparseAdamW(AdamW): 8 | def __init__(self, 9 | sparse_lambda = 0.1, 10 | lambda_schedule = None, 11 | max_lambda = None, 12 | lambda_num = None, 13 | **kwargs 14 | ): 15 | super().__init__(**kwargs) 16 | self.sparse_lambda = sparse_lambda 17 | print(f"lambda in optimizer={self.sparse_lambda}") 18 | self.lambda_idx = 0 19 | self.lambda_schedule = lambda_schedule 20 | self._build_lambda_list(max_lambda, lambda_num) 21 | 22 | def _build_lambda_list(self, max_lambda, lambda_num): 23 | if self.lambda_schedule is None: 24 | self._lambdas = None 25 | return 26 | if isinstance(self.lambda_schedule, list): 27 | self._lambdas = self.lambda_schedule 28 | if self.lambda_schedule == "linear": 29 | assert max_lambda is not None and lambda_num is not None, print(f"when using linear schedule, max_lambda and lambda_num must be provided, but got ({max_lambda} and {lambda_num})") 30 | self._lambdas = np.linspace(self.sparse_lambda, max_lambda, lambda_num) 31 | elif self.lambda_schedule == "log_linear": 32 | assert max_lambda is not None and lambda_num is not None, print(f"when using log_linear schedule, max_lambda and lambda_num must be provided, but got ({max_lambda} and {lambda_num})") 33 | self._lambdas = np.log(np.linspace(np.exp(self.sparse_lambda), np.exp(max_lambda), lambda_num)) 34 | else: 35 | raise NotImplementedError 36 | 37 | def step_lambda(self): 38 | if self._lambdas is None: 39 | print("no lambda schedule is specified, do nothing") 40 | return 41 | else: 42 | if self.lambda_idx < len(self._lambdas) - 1: 43 | self.lambda_idx += 1 44 | self.sparse_lambda = self._lambdas[self.lambda_idx] 45 | print(f"use lambda={self.sparse_lambda}") 46 | else: 47 | print(f"reach end of self._lambdas, keep using lambda={self.sparse_lambda}") 48 | 49 | 50 | def step(self, closure = None): 51 | """ 52 | Performs a single optimization step. 53 | Arguments: 54 | closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. 
55 | """ 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | for p in group["params"]: 62 | if p.grad is None: 63 | continue 64 | grad = p.grad.data 65 | if grad.is_sparse: 66 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 67 | 68 | state = self.state[p] 69 | 70 | # State initialization 71 | if len(state) == 0: 72 | state["step"] = 0 73 | # Exponential moving average of gradient values 74 | state["exp_avg"] = torch.zeros_like(p.data) 75 | # Exponential moving average of squared gradient values 76 | state["exp_avg_sq"] = torch.zeros_like(p.data) 77 | 78 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 79 | beta1, beta2 = group["betas"] 80 | 81 | state["step"] += 1 82 | 83 | # Decay the first and second moment running average coefficient 84 | # In-place operations to update the averages at the same time 85 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) 86 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 87 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 88 | 89 | step_size = group["lr"] 90 | if group["correct_bias"]: # No bias correction for Bert 91 | bias_correction1 = 1.0 - beta1 ** state["step"] 92 | bias_correction2 = 1.0 - beta2 ** state["step"] 93 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 94 | 95 | # Just adding the square of the weights to the loss function is *not* 96 | # the correct way of using L2 regularization/weight decay with Adam, 97 | # since that will interact with the m and v parameters in strange ways. 98 | # 99 | # Instead we want to decay the weights in a manner that doesn't interact 100 | # with the m/v parameters. This is equivalent to adding the square 101 | # of the weights to the loss with plain (non-momentum) SGD. 102 | # Add weight decay at the end (fixed version) 103 | 104 | # params with sparsity regularization do not need weight decay 105 | # still hard to decide: which quantity stands for $\eta$ in Adam? group['lr] or stepsize? 
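                # The update below combines the standard AdamW step with decoupled weight decay in
                # `to_add`, then (in this *_multiply_lr variant) applies a proximal soft-thresholding
                # step whose threshold is sparse_lambda scaled by the current learning rate,
                # zeroing out small gate entries.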
106 | to_add = torch.div(exp_avg, denom) * (-step_size) 107 | if group["weight_decay"] > 0.0: 108 | # p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"])) 109 | to_add = to_add + (-group["lr"] * group["weight_decay"]) * p.data 110 | p.data.add_(to_add) 111 | 112 | if self.sparse_lambda > 0: 113 | p.data[p.data > self.sparse_lambda * group["lr"]] -= self.sparse_lambda * group["lr"] 114 | p.data[p.data < -self.sparse_lambda * group["lr"]] += self.sparse_lambda * group["lr"] 115 | p.data[abs(p.data) < self.sparse_lambda * group["lr"]] = 0.0 116 | print("in sparse optimizer lr=", group["lr"]) 117 | 118 | return loss 119 | -------------------------------------------------------------------------------- /src/sparse_optimizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AdamW 2 | from torch.optim import Optimizer 3 | import torch 4 | import math 5 | import numpy as np 6 | 7 | class SparseAdamW(AdamW): 8 | def __init__(self, 9 | sparse_lambda = 0.1, 10 | lambda_schedule = None, 11 | max_lambda = None, 12 | lambda_num = None, 13 | **kwargs 14 | ): 15 | super().__init__(**kwargs) 16 | self.sparse_lambda = sparse_lambda 17 | print(f"lambda in optimizer={self.sparse_lambda}") 18 | self.lambda_idx = 0 19 | self.lambda_schedule = lambda_schedule 20 | self._build_lambda_list(max_lambda, lambda_num) 21 | 22 | def _build_lambda_list(self, max_lambda, lambda_num): 23 | if self.lambda_schedule is None: 24 | self._lambdas = None 25 | return 26 | if isinstance(self.lambda_schedule, list): 27 | self._lambdas = self.lambda_schedule 28 | if self.lambda_schedule == "linear": 29 | assert max_lambda is not None and lambda_num is not None, print(f"when using linear schedule, max_lambda and lambda_num must be provided, but got ({max_lambda} and {lambda_num})") 30 | self._lambdas = np.linspace(self.sparse_lambda, max_lambda, lambda_num) 31 | elif self.lambda_schedule == "log_linear": 32 | assert max_lambda is not None and lambda_num is not None, print(f"when using log_linear schedule, max_lambda and lambda_num must be provided, but got ({max_lambda} and {lambda_num})") 33 | self._lambdas = np.log(np.linspace(np.exp(self.sparse_lambda), np.exp(max_lambda), lambda_num)) 34 | elif self.lambda_schedule == "exp_linear": 35 | assert max_lambda is not None and lambda_num is not None, print(f"when using exp_linear schedule, max_lambda and lambda_num must be provided, but got ({max_lambda} and {lambda_num})") 36 | self._lambdas = np.exp(np.linspace(np.log(self.sparse_lambda), np.log(max_lambda), lambda_num)) 37 | else: 38 | raise NotImplementedError 39 | 40 | def step_lambda(self): 41 | if self._lambdas is None: 42 | print("no lambda schedule is specified, do nothing") 43 | return 44 | else: 45 | if self.lambda_idx < len(self._lambdas) - 1: 46 | self.lambda_idx += 1 47 | self.sparse_lambda = self._lambdas[self.lambda_idx] 48 | print(f"use lambda={self.sparse_lambda}") 49 | else: 50 | print(f"reach end of self._lambdas, keep using lambda={self.sparse_lambda}") 51 | 52 | 53 | def step(self, closure = None): 54 | """ 55 | Performs a single optimization step. 56 | Arguments: 57 | closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. 
58 | """ 59 | loss = None 60 | if closure is not None: 61 | loss = closure() 62 | 63 | for group in self.param_groups: 64 | for p in group["params"]: 65 | if p.grad is None: 66 | continue 67 | grad = p.grad.data 68 | if grad.is_sparse: 69 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 70 | 71 | state = self.state[p] 72 | 73 | # State initialization 74 | if len(state) == 0: 75 | state["step"] = 0 76 | # Exponential moving average of gradient values 77 | state["exp_avg"] = torch.zeros_like(p.data) 78 | # Exponential moving average of squared gradient values 79 | state["exp_avg_sq"] = torch.zeros_like(p.data) 80 | 81 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 82 | beta1, beta2 = group["betas"] 83 | 84 | state["step"] += 1 85 | 86 | # Decay the first and second moment running average coefficient 87 | # In-place operations to update the averages at the same time 88 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) 89 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 90 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 91 | 92 | step_size = group["lr"] 93 | if group["correct_bias"]: # No bias correction for Bert 94 | bias_correction1 = 1.0 - beta1 ** state["step"] 95 | bias_correction2 = 1.0 - beta2 ** state["step"] 96 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 97 | 98 | # Just adding the square of the weights to the loss function is *not* 99 | # the correct way of using L2 regularization/weight decay with Adam, 100 | # since that will interact with the m and v parameters in strange ways. 101 | # 102 | # Instead we want to decay the weights in a manner that doesn't interact 103 | # with the m/v parameters. This is equivalent to adding the square 104 | # of the weights to the loss with plain (non-momentum) SGD. 105 | # Add weight decay at the end (fixed version) 106 | 107 | # params with sparsity regularization do not need weight decay 108 | # still hard to decide: which quantity stands for $\eta$ in Adam? group['lr] or stepsize? 109 | to_add = torch.div(exp_avg, denom) * (-step_size) 110 | if group["weight_decay"] > 0.0: 111 | # p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"])) 112 | to_add = to_add + (-group["lr"] * group["weight_decay"]) * p.data 113 | p.data.add_(to_add) 114 | 115 | 116 | if self.sparse_lambda > 0: 117 | p.data[p.data > self.sparse_lambda] -= self.sparse_lambda 118 | p.data[p.data < -self.sparse_lambda] += self.sparse_lambda 119 | p.data[abs(p.data) < self.sparse_lambda] = 0.0 120 | 121 | return loss 122 | -------------------------------------------------------------------------------- /glue.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Datasets Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ GLUE benchmark metric. 
""" 15 | 16 | from scipy.stats import pearsonr, spearmanr 17 | from sklearn.metrics import f1_score, matthews_corrcoef 18 | 19 | import datasets 20 | 21 | 22 | _CITATION = """\ 23 | @inproceedings{wang2019glue, 24 | title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding}, 25 | author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.}, 26 | note={In the Proceedings of ICLR.}, 27 | year={2019} 28 | } 29 | """ 30 | 31 | _DESCRIPTION = """\ 32 | GLUE, the General Language Understanding Evaluation benchmark 33 | (https://gluebenchmark.com/) is a collection of resources for training, 34 | evaluating, and analyzing natural language understanding systems. 35 | """ 36 | 37 | _KWARGS_DESCRIPTION = """ 38 | Compute GLUE evaluation metric associated to each GLUE dataset. 39 | Args: 40 | predictions: list of predictions to score. 41 | Each translation should be tokenized into a list of tokens. 42 | references: list of lists of references for each translation. 43 | Each reference should be tokenized into a list of tokens. 44 | Returns: depending on the GLUE subset, one or several of: 45 | "accuracy": Accuracy 46 | "f1": F1 score 47 | "pearson": Pearson Correlation 48 | "spearmanr": Spearman Correlation 49 | "matthews_correlation": Matthew Correlation 50 | Examples: 51 | 52 | >>> glue_metric = datasets.load_metric('glue', 'sst2') # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"] 53 | >>> references = [0, 1] 54 | >>> predictions = [0, 1] 55 | >>> results = glue_metric.compute(predictions=predictions, references=references) 56 | >>> print(results) 57 | {'accuracy': 1.0} 58 | 59 | >>> glue_metric = datasets.load_metric('glue', 'mrpc') # 'mrpc' or 'qqp' 60 | >>> references = [0, 1] 61 | >>> predictions = [0, 1] 62 | >>> results = glue_metric.compute(predictions=predictions, references=references) 63 | >>> print(results) 64 | {'accuracy': 1.0, 'f1': 1.0} 65 | 66 | >>> glue_metric = datasets.load_metric('glue', 'stsb') 67 | >>> references = [0., 1., 2., 3., 4., 5.] 68 | >>> predictions = [0., 1., 2., 3., 4., 5.] 
69 | >>> results = glue_metric.compute(predictions=predictions, references=references) 70 | >>> print({"pearson": round(results["pearson"], 2), "spearmanr": round(results["spearmanr"], 2)}) 71 | {'pearson': 1.0, 'spearmanr': 1.0} 72 | 73 | >>> glue_metric = datasets.load_metric('glue', 'cola') 74 | >>> references = [0, 1] 75 | >>> predictions = [0, 1] 76 | >>> results = glue_metric.compute(predictions=predictions, references=references) 77 | >>> print(results) 78 | {'matthews_correlation': 1.0} 79 | """ 80 | 81 | 82 | def simple_accuracy(preds, labels): 83 | return float((preds == labels).mean()) 84 | 85 | 86 | def acc_and_f1(preds, labels): 87 | acc = simple_accuracy(preds, labels) 88 | f1 = float(f1_score(y_true=labels, y_pred=preds)) 89 | return { 90 | "accuracy": acc, 91 | "f1": f1, 92 | } 93 | 94 | 95 | def pearson_and_spearman(preds, labels): 96 | pearson_corr = float(pearsonr(preds, labels)[0]) 97 | spearman_corr = float(spearmanr(preds, labels)[0]) 98 | return { 99 | "pearson": pearson_corr, 100 | "spearmanr": spearman_corr, 101 | } 102 | 103 | 104 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 105 | class Glue(datasets.Metric): 106 | def _info(self): 107 | if self.config_name not in [ 108 | "sst2", 109 | "mnli", 110 | "mnli_mismatched", 111 | "mnli_matched", 112 | "cola", 113 | "stsb", 114 | "mrpc", 115 | "qqp", 116 | "qnli", 117 | "rte", 118 | "wnli", 119 | "hans", 120 | ]: 121 | raise KeyError( 122 | "You should supply a configuration name selected in " 123 | '["sst2", "mnli", "mnli_mismatched", "mnli_matched", ' 124 | '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]' 125 | ) 126 | return datasets.MetricInfo( 127 | description=_DESCRIPTION, 128 | citation=_CITATION, 129 | inputs_description=_KWARGS_DESCRIPTION, 130 | features=datasets.Features( 131 | { 132 | "predictions": datasets.Value("int64" if self.config_name != "stsb" else "float32"), 133 | "references": datasets.Value("int64" if self.config_name != "stsb" else "float32"), 134 | } 135 | ), 136 | codebase_urls=[], 137 | reference_urls=[], 138 | format="numpy", 139 | ) 140 | 141 | def _compute(self, predictions, references): 142 | if self.config_name == "cola": 143 | return {"matthews_correlation": matthews_corrcoef(references, predictions)} 144 | elif self.config_name == "stsb": 145 | return pearson_and_spearman(predictions, references) 146 | elif self.config_name in ["mrpc", "qqp"]: 147 | return acc_and_f1(predictions, references) 148 | elif self.config_name in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]: 149 | return {"accuracy": simple_accuracy(predictions, references)} 150 | else: 151 | raise KeyError( 152 | "You should supply a configuration name selected in " 153 | '["sst2", "mnli", "mnli_mismatched", "mnli_matched", ' 154 | '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]' 155 | ) -------------------------------------------------------------------------------- /src/lora.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func 4 | from opendelta.utils.name_based_addressing import * 5 | from opendelta.basemodel import DeltaBase 6 | import torch.nn as nn 7 | from opendelta import BaseDeltaConfig 8 | import math 9 | from dataclasses import dataclass, field 10 | import torch 11 | 12 | """ 13 | implementation of sparse lora 14 | """ 15 | 16 | class 
LowRankLinear(nn.Module): 17 | # ------------------------------------------------------------------------------------------ 18 | # Copyright (c) Microsoft Corporation. All rights reserved. 19 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 20 | # ------------------------------------------------------------------------------------------ 21 | # copy from loralib and do some refactor 22 | def __init__(self, 23 | in_features, 24 | out_features, 25 | weight, 26 | r=8, 27 | lora_alpha=16, 28 | lora_dropout=0.0, 29 | ): 30 | super().__init__() 31 | self.r = r 32 | self.lora_alpha = lora_alpha 33 | self.lora_dropout = lora_dropout 34 | if lora_dropout > 0.: 35 | self.lora_dropout = nn.Dropout(p=lora_dropout) 36 | else: 37 | self.lora_dropout = lambda x: x 38 | if r > 0: 39 | self.lora_A = nn.Parameter(weight.new_zeros((r, in_features))) 40 | self.lora_B = nn.Parameter(weight.new_zeros((out_features, r))) 41 | self.gate = nn.Parameter(torch.randn(1, r)) 42 | self.scaling = self.lora_alpha / self.r 43 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) 44 | nn.init.zeros_(self.lora_B) 45 | 46 | def forward(self, x): 47 | return ((self.lora_dropout(x) @ self.lora_A.T).mul(self.gate) @ self.lora_B.T) * self.scaling 48 | 49 | @dataclass 50 | class LoraArguments: 51 | r: int = 8 52 | lora_alpha: int = 16 53 | lora_dropout: float = 0.0 54 | 55 | class LoraConfig(BaseDeltaConfig): 56 | r""" 57 | This is the configuration class to store the configuration of a :py:class:`~LoraModel` 58 | """ 59 | def __init__( 60 | self, 61 | lora_r=8, 62 | lora_alpha=16, 63 | lora_dropout=0.0, 64 | **kwargs 65 | ): 66 | super().__init__(**kwargs) 67 | arg_names = get_arg_names_inside_func(self.__init__) 68 | for arg_name in arg_names: 69 | if not hasattr(self, arg_name): # the arg has not been registered in parent config 70 | setattr(self, arg_name, locals()[arg_name]) 71 | 72 | 73 | class LoraModel(DeltaBase): 74 | r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models `_ . 75 | Thanks for their `loralib `_. 76 | 77 | .. note:: 78 | In our implementation, we did not use loralib.linear to replace the linear layer of the backbone model. 79 | Instead, we insert a parallel module into the backbone. 80 | In other words, we treat :math:`(W + A^TB) X` as :math:`WX+ A^TBX`, and insert the :math:`A^TBX` as a parallel insertion module. 81 | If you want to use the original implementation, please refer to `lora_old.py` 82 | 83 | class attributes: 84 | - default_modified_modules = ['attn.q', 'attn.v'] According to the paper, they modify q and v matrix in the 85 | attention layer. However, other linears can also be modified, and may lead to better performance. 86 | 87 | .. note:: 88 | modified_modules should point to linear layer. We currently don't support broadcast to all linears in 89 | a module's child modules. 90 | 91 | - delta_type = "lora" 92 | 93 | 94 | Args: 95 | backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified. 96 | lora_r (:obj:`int`, *optional*): the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has. 97 | lora_alpha (:obj:`int`, *optional*): A hyper-parameter to control the init scale of loralib.linear . 98 | lora_dropout (:obj:`float`, *optional*): The dropout rate in lora.linear. 
99 | modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only 100 | the implemented ones) 101 | unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen 102 | together with the prefix parameters. 103 | common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping. 104 | 105 | """ 106 | 107 | config_class = LoraConfig 108 | delta_type = "lora" 109 | default_modified_modules = ['attn@.q@', 'attn@.v@', 'attn@.k@', 'attn@.proj@', 'ff@.w1@', 'ff@.w2@'] 110 | _supported_backends = ['hf', 'bmt'] 111 | _need_pseudo_data = False 112 | def __init__(self, 113 | backbone_model: nn.Module, 114 | lora_r=8, 115 | lora_alpha=16, 116 | lora_dropout=0.0, 117 | modified_modules: Optional[List[str]] = None, 118 | unfrozen_modules: Optional[List[str]] = None, 119 | exclude_modules: Optional[List[str]] = None, 120 | common_structure: Optional[bool] = None, 121 | interactive_modify: Optional[Union[bool, int]] = False, 122 | ): 123 | DeltaBase.__init__(self, 124 | backbone_model, 125 | modified_modules=modified_modules, 126 | unfrozen_modules=unfrozen_modules, 127 | common_structure=common_structure, 128 | interactive_modify=interactive_modify, 129 | ) 130 | arg_names = get_arg_names_inside_func(self.__init__) 131 | for arg_name in arg_names: 132 | if not hasattr(self, arg_name): # not registered in parent class 133 | setattr(self, arg_name, locals()[arg_name]) 134 | 135 | self.delta_modules = nn.ModuleList() 136 | 137 | self.add_all_delta_to_backbone(self.backbone_model, 138 | self.modified_modules, 139 | ) 140 | 141 | 142 | def update_module(self, module: nn.Module, key: str): 143 | print("calling update module") 144 | parent_ref, child_name, child_ref = self.find_module(module, key) 145 | print("child ref:", child_ref) 146 | 147 | parallel_module = self.new_module_like(child_module=child_ref) 148 | print("parallel module:", parallel_module) 149 | self.insert_parallel_module(child_ref, delta_module=parallel_module, delta_name="lora") 150 | 151 | def _pseudo_data_to_instantiate(self, module): 152 | # no need to pass pseudo input, so overwrite it 153 | pass 154 | 155 | def new_module_like(self, child_module): 156 | if isinstance(child_module, nn.Linear): 157 | in_features, out_features = child_module.in_features, child_module.out_features 158 | new_module = LowRankLinear(in_features = in_features, 159 | out_features = out_features, 160 | weight = child_module.weight, 161 | r=self.lora_r, 162 | lora_alpha=self.lora_alpha, 163 | lora_dropout=self.lora_dropout) 164 | self.delta_modules.append(new_module) 165 | else: 166 | raise NotImplementedError 167 | return new_module 168 | -------------------------------------------------------------------------------- /run_glue_adapter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 18 | 19 | import logging 20 | import os 21 | import wandb 22 | os.environ['WANDB_MODE'] = 'offline' 23 | import random 24 | import sys 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | 28 | import datasets 29 | import numpy as np 30 | from datasets import load_dataset 31 | 32 | import transformers 33 | from transformers import ( 34 | AutoConfig, 35 | AutoModelForSequenceClassification, 36 | AutoTokenizer, 37 | DataCollatorWithPadding, 38 | EvalPrediction, 39 | HfArgumentParser, 40 | PretrainedConfig, 41 | Trainer, 42 | TrainingArguments, 43 | default_data_collator, 44 | set_seed, 45 | ) 46 | from transformers.trainer_utils import get_last_checkpoint 47 | from transformers.utils import check_min_version 48 | from transformers.utils.versions import require_version 49 | sys.path.append('../') 50 | from src.trainer import SparseTrainer 51 | from src.util import compute_trainable_sparse_param, create_optimizer_and_scheduler 52 | from src.sparse_optimizer import SparseAdamW 53 | from transformers import get_linear_schedule_with_warmup 54 | 55 | 56 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 57 | # check_min_version("4.24.0") 58 | 59 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 60 | 61 | task_to_keys = { 62 | "cola": ("sentence", None), 63 | "mnli": ("premise", "hypothesis"), 64 | "mnli-m": ("premise", "hypothesis"), 65 | "mnli-mm": ("premise", "hypothesis"), 66 | "mrpc": ("sentence1", "sentence2"), 67 | "qnli": ("question", "sentence"), 68 | "qqp": ("question1", "question2"), 69 | "rte": ("sentence1", "sentence2"), 70 | "sst2": ("sentence", None), 71 | "stsb": ("sentence1", "sentence2"), 72 | "wnli": ("sentence1", "sentence2"), 73 | } 74 | 75 | task_to_best_metric = { 76 | "rte": "eval_accuracy", 77 | "mrpc": "eval_f1", 78 | "cola": "eval_matthews_correlation", 79 | "stsb": "eval_pearson", 80 | "sst2": "eval_accuracy", 81 | "qnli": "eval_accuracy", 82 | "mnli": "eval_accuracy", 83 | "mnli-m": "eval_accuracy", 84 | "mnli-mm": "eval_accuracy", 85 | "qqp": "eval_accuracy", 86 | } 87 | 88 | data_path = '/root/xtlv/data/sora_datasets/glue_datasets_from_dn/' 89 | 90 | logger = logging.getLogger(__name__) 91 | 92 | 93 | @dataclass 94 | class DataTrainingArguments: 95 | """ 96 | Arguments pertaining to what data we are going to input our model for training and eval. 97 | Using `HfArgumentParser` we can turn this class 98 | into argparse arguments to be able to specify them on 99 | the command line. 100 | """ 101 | 102 | task_name: Optional[str] = field( 103 | default=None, 104 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 105 | ) 106 | dataset_name: Optional[str] = field( 107 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 108 | ) 109 | dataset_config_name: Optional[str] = field( 110 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 111 | ) 112 | max_seq_length: int = field( 113 | default=128, 114 | metadata={ 115 | "help": ( 116 | "The maximum total input sequence length after tokenization. 
Sequences longer " 117 | "than this will be truncated, sequences shorter will be padded." 118 | ) 119 | }, 120 | ) 121 | overwrite_cache: bool = field( 122 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 123 | ) 124 | pad_to_max_length: bool = field( 125 | default=True, 126 | metadata={ 127 | "help": ( 128 | "Whether to pad all samples to `max_seq_length`. " 129 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 130 | ) 131 | }, 132 | ) 133 | max_train_samples: Optional[int] = field( 134 | default=None, 135 | metadata={ 136 | "help": ( 137 | "For debugging purposes or quicker training, truncate the number of training examples to this " 138 | "value if set." 139 | ) 140 | }, 141 | ) 142 | max_eval_samples: Optional[int] = field( 143 | default=None, 144 | metadata={ 145 | "help": ( 146 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 147 | "value if set." 148 | ) 149 | }, 150 | ) 151 | max_predict_samples: Optional[int] = field( 152 | default=None, 153 | metadata={ 154 | "help": ( 155 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 156 | "value if set." 157 | ) 158 | }, 159 | ) 160 | train_file: Optional[str] = field( 161 | default=None, metadata={"help": "A csv or a json file containing the training data."} 162 | ) 163 | validation_file: Optional[str] = field( 164 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 165 | ) 166 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 167 | 168 | def __post_init__(self): 169 | if self.task_name is not None: 170 | self.task_name = self.task_name.lower() 171 | if self.task_name not in task_to_keys.keys(): 172 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) 173 | elif self.dataset_name is not None: 174 | pass 175 | elif self.train_file is None or self.validation_file is None: 176 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 177 | else: 178 | train_extension = self.train_file.split(".")[-1] 179 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 180 | validation_extension = self.validation_file.split(".")[-1] 181 | assert ( 182 | validation_extension == train_extension 183 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 184 | 185 | 186 | @dataclass 187 | class ModelArguments: 188 | """ 189 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
190 | """ 191 | 192 | model_name_or_path: str = field( 193 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 194 | ) 195 | config_name: Optional[str] = field( 196 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 197 | ) 198 | tokenizer_name: Optional[str] = field( 199 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 200 | ) 201 | cache_dir: Optional[str] = field( 202 | default=None, 203 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 204 | ) 205 | use_fast_tokenizer: bool = field( 206 | default=True, 207 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 208 | ) 209 | model_revision: str = field( 210 | default="main", 211 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 212 | ) 213 | use_auth_token: bool = field( 214 | default=False, 215 | metadata={ 216 | "help": ( 217 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 218 | "with private models)." 219 | ) 220 | }, 221 | ) 222 | ignore_mismatched_sizes: bool = field( 223 | default=False, 224 | metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, 225 | ) 226 | 227 | @dataclass 228 | class SparseArguments: 229 | sparse_lambda: Optional[float] = field( 230 | default=1e-3, metadata={"help": "loss penalty term for gate param"} 231 | ) 232 | sparse_lambda_2: Optional[float] = field( 233 | default=1e-3, metadata={"help": "clipping scale for gate param"} 234 | ) 235 | sparse_lr: Optional[float] = field( 236 | default=None, metadata={"help": "lr for gate parameter in sparse lora, default to same as learning rate for other parameters"} 237 | ) 238 | lora_r: Optional[int] = field( 239 | default=16, metadata={"help": "matrix rank in lora"} 240 | ) 241 | lambda_schedule: Optional[str] = field( 242 | default=None, metadata={"help": "scheduling of lambda_2, {linear, log_linear}"} 243 | ) 244 | max_lambda: Optional[float] = field( 245 | default=10, metadata={"help": "maximum value of lambda_2 in scheduling"} 246 | ) 247 | lambda_num: Optional[int] = field( 248 | default=10, metadata={"help": "total number of lambdas in scheduling"} 249 | ) 250 | bottleneck_dim: Optional[int] = field( 251 | default=12, metadata={"help": "matrix rank in lora"} 252 | ) 253 | 254 | @dataclass 255 | class SparseTrainingArguments(TrainingArguments): 256 | train_sparse: Optional[bool] = field( 257 | default=False, metadata={"help": "whether use sparse lora"} 258 | ) 259 | debug_mode: Optional[bool] = field( 260 | default=False, metadata={"help": "debug mode"} 261 | ) 262 | 263 | 264 | def main(): 265 | # See all possible arguments in src/transformers/training_args.py 266 | # or by passing the --help flag to this script. 267 | # We now keep distinct sets of args, for a cleaner separation of concerns. 268 | 269 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SparseTrainingArguments, SparseArguments)) 270 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 271 | # If we pass only one argument to the script and it's the path to a json file, 272 | # let's parse it to get our arguments. 
273 | model_args, data_args, training_args, sparse_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 274 | else: 275 | model_args, data_args, training_args, sparse_args = parser.parse_args_into_dataclasses() 276 | 277 | 278 | task_name_for_get = data_args.task_name 279 | if "mnli" in data_args.task_name: 280 | data_args.task_name = "mnli" 281 | 282 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 283 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 284 | # send_example_telemetry("run_glue", model_args, data_args) 285 | training_args.metric_for_best_model = task_to_best_metric[data_args.task_name] 286 | 287 | if os.getenv("LOCAL_RANK"): 288 | training_args.local_rank = int(os.environ["LOCAL_RANK"]) 289 | else: 290 | training_args.local_rank = -1 291 | 292 | if training_args.train_sparse: 293 | if sparse_args.sparse_lr is None: 294 | sparse_args.sparse_lr = training_args.learning_rate 295 | if training_args.debug_mode: 296 | training_args.output_dir += "-debug" 297 | print(f"save model to {training_args.output_dir}") 298 | 299 | # Setup logging 300 | logging.basicConfig( 301 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 302 | datefmt="%m/%d/%Y %H:%M:%S", 303 | handlers=[logging.StreamHandler(sys.stdout)], 304 | ) 305 | 306 | log_level = training_args.get_process_log_level() 307 | logger.setLevel(log_level) 308 | datasets.utils.logging.set_verbosity(log_level) 309 | transformers.utils.logging.set_verbosity(log_level) 310 | transformers.utils.logging.enable_default_handler() 311 | transformers.utils.logging.enable_explicit_format() 312 | 313 | # Log on each process the small summary: 314 | logger.warning( 315 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 316 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 317 | ) 318 | logger.info(f"Training/evaluation parameters {training_args}") 319 | 320 | # Detecting last checkpoint. 321 | last_checkpoint = None 322 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 323 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 324 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 325 | raise ValueError( 326 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 327 | "Use --overwrite_output_dir to overcome." 328 | ) 329 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 330 | logger.info( 331 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 332 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 333 | ) 334 | 335 | # Set seed before initializing model. 336 | set_seed(training_args.seed) 337 | 338 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 339 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 340 | # 341 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 342 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 343 | # label if at least two columns are provided. 
344 | # 345 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 346 | # single column. You can easily tweak this behavior (see below) 347 | # 348 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 349 | # download the dataset. 350 | if data_args.task_name is not None: 351 | # Downloading and loading a dataset from the hub. 352 | from datasets import load_from_disk 353 | from src.glue_tasks import AutoTask 354 | raw_datasets = load_from_disk(data_path + data_args.task_name) 355 | 356 | task = AutoTask().get(data_args.task_name, None, None) 357 | raw_datasets = { 358 | "train": task.get("train", split_validation_test=True), 359 | "validation": task.get("validation", split_validation_test=True), 360 | "test": task.get("test", split_validation_test=True) 361 | } 362 | from datasets import DatasetDict 363 | raw_datasets = DatasetDict(raw_datasets) 364 | 365 | elif data_args.dataset_name is not None: 366 | raise NotImplementedError 367 | 368 | else: 369 | # Loading a dataset from your local files. 370 | # CSV/JSON training and evaluation files are needed. 371 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 372 | 373 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 374 | # when you use `do_predict` without specifying a GLUE benchmark task. 375 | if training_args.do_predict: 376 | if data_args.test_file is not None: 377 | train_extension = data_args.train_file.split(".")[-1] 378 | test_extension = data_args.test_file.split(".")[-1] 379 | assert ( 380 | test_extension == train_extension 381 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 382 | data_files["test"] = data_args.test_file 383 | else: 384 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 385 | 386 | for key in data_files.keys(): 387 | logger.info(f"load a local file for {key}: {data_files[key]}") 388 | 389 | if data_args.train_file.endswith(".csv"): 390 | # Loading a dataset from local csv files 391 | raw_datasets = load_dataset( 392 | "csv", 393 | data_files=data_files, 394 | cache_dir=model_args.cache_dir, 395 | use_auth_token=True if model_args.use_auth_token else None, 396 | ) 397 | else: 398 | # Loading a dataset from local json files 399 | raw_datasets = load_dataset( 400 | "json", 401 | data_files=data_files, 402 | cache_dir=model_args.cache_dir, 403 | use_auth_token=True if model_args.use_auth_token else None, 404 | ) 405 | # See more about loading any type of standard or custom dataset at 406 | # https://huggingface.co/docs/datasets/loading_datasets.html. 407 | 408 | # Labels 409 | if data_args.task_name is not None: 410 | is_regression = data_args.task_name == "stsb" 411 | if not is_regression: 412 | label_list = raw_datasets["train"].features["label"].names 413 | num_labels = len(label_list) 414 | else: 415 | num_labels = 1 416 | else: 417 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 
418 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 419 | if is_regression: 420 | num_labels = 1 421 | else: 422 | # A useful fast method: 423 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 424 | label_list = raw_datasets["train"].unique("label") 425 | label_list.sort() # Let's sort it for determinism 426 | num_labels = len(label_list) 427 | 428 | # Load pretrained model and tokenizer 429 | # 430 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 431 | # download model & vocab. 432 | config = AutoConfig.from_pretrained( 433 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 434 | num_labels=num_labels, 435 | finetuning_task=data_args.task_name, 436 | cache_dir=model_args.cache_dir, 437 | revision=model_args.model_revision, 438 | use_auth_token=True if model_args.use_auth_token else None, 439 | ) 440 | tokenizer = AutoTokenizer.from_pretrained( 441 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 442 | cache_dir=model_args.cache_dir, 443 | use_fast=model_args.use_fast_tokenizer, 444 | revision=model_args.model_revision, 445 | use_auth_token=True if model_args.use_auth_token else None, 446 | ) 447 | model = AutoModelForSequenceClassification.from_pretrained( 448 | model_args.model_name_or_path, 449 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 450 | config=config, 451 | cache_dir=model_args.cache_dir, 452 | revision=model_args.model_revision, 453 | use_auth_token=True if model_args.use_auth_token else None, 454 | ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, 455 | ) 456 | 457 | from opendelta.delta_models.adapter import AdapterModel, AdapterConfig 458 | import json 459 | adapter_config = json.load(open("./adapter_config.json")) 460 | adapter_config["bottleneck_dim"] = sparse_args.bottleneck_dim 461 | adapter_config = AdapterConfig.from_dict(adapter_config) 462 | delta_model = AdapterModel.from_config(adapter_config, backbone_model=model) 463 | delta_model.freeze_module(set_state_dict = True) 464 | delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=False) 465 | 466 | 467 | 468 | # Preprocessing the raw_datasets 469 | if data_args.task_name is not None: 470 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 471 | else: 472 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 473 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 474 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 475 | sentence1_key, sentence2_key = "sentence1", "sentence2" 476 | else: 477 | if len(non_label_column_names) >= 2: 478 | sentence1_key, sentence2_key = non_label_column_names[:2] 479 | else: 480 | sentence1_key, sentence2_key = non_label_column_names[0], None 481 | 482 | # Padding strategy 483 | if data_args.pad_to_max_length: 484 | padding = "max_length" 485 | else: 486 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 487 | padding = False 488 | 489 | # Some models have set the order of the labels to use, so let's make sure we do use it. 
490 |     label_to_id = None
491 |     if (
492 |         model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
493 |         and data_args.task_name is not None
494 |         and not is_regression
495 |     ):
496 |         # Some have all caps in their config, some don't.
497 |         label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
498 |         if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
499 |             label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
500 |         else:
501 |             logger.warning(
502 |                 "Your model seems to have been trained with labels, but they don't match the dataset: "
503 |                 f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
504 |                 "\nIgnoring the model labels as a result.",
505 |             )
506 |     elif data_args.task_name is None and not is_regression:
507 |         label_to_id = {v: i for i, v in enumerate(label_list)}
508 | 
509 |     if label_to_id is not None:
510 |         model.config.label2id = label_to_id
511 |         model.config.id2label = {id: label for label, id in config.label2id.items()}
512 |     elif data_args.task_name is not None and not is_regression:
513 |         model.config.label2id = {l: i for i, l in enumerate(label_list)}
514 |         model.config.id2label = {id: label for label, id in config.label2id.items()}
515 | 
516 |     if data_args.max_seq_length > tokenizer.model_max_length:
517 |         logger.warning(
518 |             f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
519 |             f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
520 |         )
521 |     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
522 | 
523 |     def preprocess_function(examples):
524 |         # Tokenize the texts
525 |         args = (
526 |             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
527 |         )
528 |         result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
529 | 
530 |         # Map labels to IDs (not necessary for GLUE tasks)
531 |         if label_to_id is not None and "label" in examples:
532 |             result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
533 |         return result
534 | 
535 |     with training_args.main_process_first(desc="dataset map pre-processing"):
536 |         raw_datasets = raw_datasets.map(
537 |             preprocess_function,
538 |             batched=True,
539 |             load_from_cache_file=not data_args.overwrite_cache,
540 |             desc="Running tokenizer on dataset",
541 |         )
542 |     if training_args.do_train:
543 |         if "train" not in raw_datasets:
544 |             raise ValueError("--do_train requires a train dataset")
545 |         train_dataset = raw_datasets["train"]
546 |         if data_args.max_train_samples is not None:
547 |             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
548 |             train_dataset = train_dataset.select(range(max_train_samples))
549 | 
550 |     if training_args.do_eval:
551 |         if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
552 |             raise ValueError("--do_eval requires a validation dataset")
553 |         eval_dataset = raw_datasets["validation"]
554 |         if data_args.max_eval_samples is not None:
555 |             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
556 |             eval_dataset = eval_dataset.select(range(max_eval_samples))
557 | 
558 |     if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
559 |         if "test" not in raw_datasets and "test_matched" not in raw_datasets:
560 |             raise ValueError("--do_predict requires a test
dataset") 561 | predict_dataset = raw_datasets["test"] 562 | if data_args.max_predict_samples is not None: 563 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 564 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 565 | 566 | # Log a few random samples from the training set: 567 | if training_args.do_train: 568 | for index in random.sample(range(len(train_dataset)), 3): 569 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 570 | 571 | # Get the metric function 572 | from datasets import load_metric 573 | if data_args.task_name is not None: 574 | metric = load_metric("./glue.py", data_args.task_name) 575 | else: 576 | metric = load_metric("accuracy") 577 | 578 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 579 | # predictions and label_ids field) and has to return a dictionary string to float. 580 | def compute_metrics(mode, p: EvalPrediction): 581 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 582 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 583 | if data_args.task_name is not None: 584 | result = metric.compute(predictions=preds, references=p.label_ids) 585 | if len(result) > 1: 586 | result["combined_score"] = np.mean(list(result.values())).item() 587 | return result 588 | elif is_regression: 589 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 590 | else: 591 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 592 | 593 | # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if 594 | # we already did the padding. 595 | if data_args.pad_to_max_length: 596 | data_collator = default_data_collator 597 | elif training_args.fp16: 598 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 599 | else: 600 | data_collator = None 601 | 602 | 603 | # Initialize our Trainer 604 | optimizer, lr_scheduler = create_optimizer_and_scheduler(training_args, model, num_training_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size))) 605 | sparse_optimizer = None 606 | sparse_scheduler = None 607 | if training_args.train_sparse: 608 | print("building sparse optimizer and scheduler") 609 | from src.trainer import GATE_PARAM_NAME 610 | valid_param_name = [] 611 | for n, p in model.named_parameters(): 612 | print(n) 613 | if GATE_PARAM_NAME in n: 614 | valid_param_name.append(n) 615 | print("valid param name:", valid_param_name) 616 | sparse_optimizer = SparseAdamW(sparse_lambda=sparse_args.sparse_lambda_2, lambda_schedule=sparse_args.lambda_schedule, max_lambda=sparse_args.max_lambda, lambda_num=sparse_args.lambda_num, params=[p for n, p in model.named_parameters() if GATE_PARAM_NAME in n and p.requires_grad], lr=sparse_args.sparse_lr) 617 | sparse_scheduler = get_linear_schedule_with_warmup(sparse_optimizer, 618 | num_warmup_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size)*training_args.warmup_ratio), 619 | num_training_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size))) 620 | 621 | if training_args.debug_mode: 622 | train_dataset = eval_dataset 623 | 624 | # Initialize our Trainer 625 | trainer = SparseTrainer( 626 | model=model, 627 | args=training_args, 628 | train_dataset=train_dataset if training_args.do_train else None, 629 | 
eval_dataset=eval_dataset if training_args.do_eval else None, 630 | compute_metrics=compute_metrics, 631 | tokenizer=tokenizer, 632 | data_collator=data_collator, 633 | optimizers = (optimizer, lr_scheduler), 634 | sparse_lambda = sparse_args.sparse_lambda, 635 | sparse_optimizer = (sparse_optimizer, sparse_scheduler) 636 | ) 637 | 638 | # Training 639 | if training_args.do_train: 640 | checkpoint = None 641 | if training_args.resume_from_checkpoint is not None: 642 | checkpoint = training_args.resume_from_checkpoint 643 | elif last_checkpoint is not None: 644 | checkpoint = last_checkpoint 645 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 646 | metrics = train_result.metrics 647 | max_train_samples = ( 648 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 649 | ) 650 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 651 | 652 | trainer.save_model() # Saves the tokenizer too for easy upload 653 | 654 | trainer.log_metrics("train", metrics) 655 | trainer.save_metrics("train", metrics) 656 | trainer.save_state() 657 | 658 | sparse_param, total_param = compute_trainable_sparse_param(model) 659 | 660 | 661 | # eval on 1000 samples train set 662 | train_dataset_for_eval = train_dataset.shuffle(seed=42).select(range(1000)) 663 | logger.info("*** Evaluate on training subset ***") 664 | metrics = trainer.evaluate(eval_dataset=train_dataset_for_eval, metric_key_prefix = "eval_train") 665 | trainer.log_metrics("eval_train", metrics) 666 | trainer.save_metrics("eval_train", metrics) 667 | BEST_TRAIN_METRIC = metrics["eval_train_" + "_".join(task_to_best_metric[data_args.task_name].split("_")[1:])] 668 | 669 | 670 | # Evaluation 671 | if training_args.do_eval: 672 | logger.info("*** Evaluate ***") 673 | 674 | # Loop to handle MNLI double evaluation (matched, mis-matched) 675 | tasks = [data_args.task_name] 676 | eval_datasets = [eval_dataset] 677 | 678 | 679 | for eval_dataset, task in zip(eval_datasets, tasks): 680 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 681 | 682 | max_eval_samples = ( 683 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 684 | ) 685 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 686 | 687 | 688 | trainer.log_metrics("eval", metrics) 689 | trainer.save_metrics("eval", metrics) 690 | 691 | BEST_EVAL_METRIC = metrics[task_to_best_metric[data_args.task_name]] 692 | 693 | if training_args.do_predict: 694 | logger.info("*** Predict ***") 695 | 696 | # Loop to handle MNLI double evaluation (matched, mis-matched) 697 | tasks = [data_args.task_name] 698 | predict_datasets = [predict_dataset] 699 | 700 | 701 | for predict_dataset, task in zip(predict_datasets, tasks): 702 | metrics = trainer.evaluate(eval_dataset=predict_dataset) 703 | 704 | max_eval_samples = ( 705 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 706 | ) 707 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 708 | 709 | trainer.log_metrics("test", metrics) 710 | trainer.save_metrics("test", metrics) 711 | 712 | logger.info("***** Final Model ******\nAdapter bottleneck_dim: %d\nNumber of trainable full param: %d\nNumber of trainable sparse param: %d, Ratio: %.4f%%\n**********" % (adapter_config.bottleneck_dim, total_param, sparse_param, sparse_param / total_param * 100)) 713 | 714 | if __name__ == "__main__": 715 | main() 
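
Note: a minimal illustrative invocation of run_glue_adapter.py is sketched below. The task, model path, output directory, and hyperparameter values are placeholders rather than values taken from this repository; the script additionally expects adapter_config.json in the current working directory and the GLUE splits under the hard-coded data_path.

    python run_glue_adapter.py \
        --do_train --do_eval \
        --task_name rte \
        --model_name_or_path /path/to/DeBERTaV3_base \
        --tokenizer_name /path/to/DeBERTaV3_base \
        --bottleneck_dim 12 \
        --learning_rate 8e-4 \
        --num_train_epochs 10 \
        --per_device_train_batch_size 32 \
        --per_device_eval_batch_size 32 \
        --max_seq_length 256 \
        --output_dir results/rte_adapter_example \
        --overwrite_output_dir
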
-------------------------------------------------------------------------------- /run_glue_bitfit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 18 | 19 | import logging 20 | import os 21 | import wandb 22 | os.environ['WANDB_MODE'] = 'offline' 23 | import random 24 | import sys 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | 28 | import datasets 29 | import numpy as np 30 | from datasets import load_dataset 31 | 32 | import transformers 33 | from transformers import ( 34 | AutoConfig, 35 | AutoModelForSequenceClassification, 36 | AutoTokenizer, 37 | DataCollatorWithPadding, 38 | EvalPrediction, 39 | HfArgumentParser, 40 | PretrainedConfig, 41 | Trainer, 42 | TrainingArguments, 43 | default_data_collator, 44 | set_seed, 45 | ) 46 | from transformers.trainer_utils import get_last_checkpoint 47 | from transformers.utils import check_min_version 48 | from transformers.utils.versions import require_version 49 | sys.path.append('../') 50 | from src.trainer import SparseTrainer 51 | from src.util import compute_trainable_sparse_param, create_optimizer_and_scheduler 52 | from src.sparse_optimizer import SparseAdamW 53 | from transformers import get_linear_schedule_with_warmup 54 | 55 | 56 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 57 | # check_min_version("4.24.0") 58 | 59 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 60 | 61 | task_to_keys = { 62 | "cola": ("sentence", None), 63 | "mnli": ("premise", "hypothesis"), 64 | "mnli-m": ("premise", "hypothesis"), 65 | "mnli-mm": ("premise", "hypothesis"), 66 | "mrpc": ("sentence1", "sentence2"), 67 | "qnli": ("question", "sentence"), 68 | "qqp": ("question1", "question2"), 69 | "rte": ("sentence1", "sentence2"), 70 | "sst2": ("sentence", None), 71 | "stsb": ("sentence1", "sentence2"), 72 | "wnli": ("sentence1", "sentence2"), 73 | } 74 | 75 | task_to_best_metric = { 76 | "rte": "eval_accuracy", 77 | "mrpc": "eval_f1", 78 | "cola": "eval_matthews_correlation", 79 | "stsb": "eval_pearson", 80 | "sst2": "eval_accuracy", 81 | "qnli": "eval_accuracy", 82 | "mnli": "eval_accuracy", 83 | "mnli-m": "eval_accuracy", 84 | "mnli-mm": "eval_accuracy", 85 | "qqp": "eval_accuracy", 86 | } 87 | 88 | data_path = '/root/xtlv/data/sora_datasets/glue_datasets_from_dn/' 89 | 90 | logger = logging.getLogger(__name__) 91 | 92 | 93 | @dataclass 94 | class DataTrainingArguments: 95 | """ 96 | Arguments pertaining to what data we are going to input our model for training and eval. 
97 | Using `HfArgumentParser` we can turn this class 98 | into argparse arguments to be able to specify them on 99 | the command line. 100 | """ 101 | 102 | task_name: Optional[str] = field( 103 | default=None, 104 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 105 | ) 106 | dataset_name: Optional[str] = field( 107 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 108 | ) 109 | dataset_config_name: Optional[str] = field( 110 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 111 | ) 112 | max_seq_length: int = field( 113 | default=128, 114 | metadata={ 115 | "help": ( 116 | "The maximum total input sequence length after tokenization. Sequences longer " 117 | "than this will be truncated, sequences shorter will be padded." 118 | ) 119 | }, 120 | ) 121 | overwrite_cache: bool = field( 122 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 123 | ) 124 | pad_to_max_length: bool = field( 125 | default=True, 126 | metadata={ 127 | "help": ( 128 | "Whether to pad all samples to `max_seq_length`. " 129 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 130 | ) 131 | }, 132 | ) 133 | max_train_samples: Optional[int] = field( 134 | default=None, 135 | metadata={ 136 | "help": ( 137 | "For debugging purposes or quicker training, truncate the number of training examples to this " 138 | "value if set." 139 | ) 140 | }, 141 | ) 142 | max_eval_samples: Optional[int] = field( 143 | default=None, 144 | metadata={ 145 | "help": ( 146 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 147 | "value if set." 148 | ) 149 | }, 150 | ) 151 | max_predict_samples: Optional[int] = field( 152 | default=None, 153 | metadata={ 154 | "help": ( 155 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 156 | "value if set." 157 | ) 158 | }, 159 | ) 160 | train_file: Optional[str] = field( 161 | default=None, metadata={"help": "A csv or a json file containing the training data."} 162 | ) 163 | validation_file: Optional[str] = field( 164 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 165 | ) 166 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 167 | 168 | def __post_init__(self): 169 | if self.task_name is not None: 170 | self.task_name = self.task_name.lower() 171 | if self.task_name not in task_to_keys.keys(): 172 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) 173 | elif self.dataset_name is not None: 174 | pass 175 | elif self.train_file is None or self.validation_file is None: 176 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 177 | else: 178 | train_extension = self.train_file.split(".")[-1] 179 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 180 | validation_extension = self.validation_file.split(".")[-1] 181 | assert ( 182 | validation_extension == train_extension 183 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 184 | 185 | 186 | @dataclass 187 | class ModelArguments: 188 | """ 189 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
190 | """ 191 | 192 | model_name_or_path: str = field( 193 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 194 | ) 195 | config_name: Optional[str] = field( 196 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 197 | ) 198 | tokenizer_name: Optional[str] = field( 199 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 200 | ) 201 | cache_dir: Optional[str] = field( 202 | default=None, 203 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 204 | ) 205 | use_fast_tokenizer: bool = field( 206 | default=True, 207 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 208 | ) 209 | model_revision: str = field( 210 | default="main", 211 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 212 | ) 213 | use_auth_token: bool = field( 214 | default=False, 215 | metadata={ 216 | "help": ( 217 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 218 | "with private models)." 219 | ) 220 | }, 221 | ) 222 | ignore_mismatched_sizes: bool = field( 223 | default=False, 224 | metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, 225 | ) 226 | 227 | @dataclass 228 | class SparseArguments: 229 | sparse_lambda: Optional[float] = field( 230 | default=1e-3, metadata={"help": "loss penalty term for gate param"} 231 | ) 232 | sparse_lambda_2: Optional[float] = field( 233 | default=1e-3, metadata={"help": "clipping scale for gate param"} 234 | ) 235 | sparse_lr: Optional[float] = field( 236 | default=None, metadata={"help": "lr for gate parameter in sparse lora, default to same as learning rate for other parameters"} 237 | ) 238 | lora_r: Optional[int] = field( 239 | default=16, metadata={"help": "matrix rank in lora"} 240 | ) 241 | lambda_schedule: Optional[str] = field( 242 | default=None, metadata={"help": "scheduling of lambda_2, {linear, log_linear}"} 243 | ) 244 | max_lambda: Optional[float] = field( 245 | default=10, metadata={"help": "maximum value of lambda_2 in scheduling"} 246 | ) 247 | lambda_num: Optional[int] = field( 248 | default=10, metadata={"help": "total number of lambdas in scheduling"} 249 | ) 250 | 251 | bottleneck_dim: Optional[int] = field( 252 | default=12, metadata={"help": "matrix rank in lora"} 253 | ) 254 | 255 | @dataclass 256 | class SparseTrainingArguments(TrainingArguments): 257 | train_sparse: Optional[bool] = field( 258 | default=False, metadata={"help": "whether use sparse lora"} 259 | ) 260 | debug_mode: Optional[bool] = field( 261 | default=False, metadata={"help": "debug mode"} 262 | ) 263 | 264 | 265 | def main(): 266 | # See all possible arguments in src/transformers/training_args.py 267 | # or by passing the --help flag to this script. 268 | # We now keep distinct sets of args, for a cleaner separation of concerns. 269 | 270 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SparseTrainingArguments, SparseArguments)) 271 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 272 | # If we pass only one argument to the script and it's the path to a json file, 273 | # let's parse it to get our arguments. 
274 | model_args, data_args, training_args, sparse_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 275 | else: 276 | model_args, data_args, training_args, sparse_args = parser.parse_args_into_dataclasses() 277 | 278 | 279 | task_name_for_get = data_args.task_name 280 | if "mnli" in data_args.task_name: 281 | data_args.task_name = "mnli" 282 | 283 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 284 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 285 | # send_example_telemetry("run_glue", model_args, data_args) 286 | training_args.metric_for_best_model = task_to_best_metric[data_args.task_name] 287 | 288 | if os.getenv("LOCAL_RANK"): 289 | training_args.local_rank = int(os.environ["LOCAL_RANK"]) 290 | else: 291 | training_args.local_rank = -1 292 | 293 | if training_args.train_sparse: 294 | if sparse_args.sparse_lr is None: 295 | sparse_args.sparse_lr = training_args.learning_rate 296 | if training_args.debug_mode: 297 | training_args.output_dir += "-debug" 298 | print(f"save model to {training_args.output_dir}") 299 | 300 | # Setup logging 301 | logging.basicConfig( 302 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 303 | datefmt="%m/%d/%Y %H:%M:%S", 304 | handlers=[logging.StreamHandler(sys.stdout)], 305 | ) 306 | 307 | log_level = training_args.get_process_log_level() 308 | logger.setLevel(log_level) 309 | datasets.utils.logging.set_verbosity(log_level) 310 | transformers.utils.logging.set_verbosity(log_level) 311 | transformers.utils.logging.enable_default_handler() 312 | transformers.utils.logging.enable_explicit_format() 313 | 314 | # Log on each process the small summary: 315 | logger.warning( 316 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 317 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 318 | ) 319 | logger.info(f"Training/evaluation parameters {training_args}") 320 | 321 | # Detecting last checkpoint. 322 | last_checkpoint = None 323 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 324 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 325 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 326 | raise ValueError( 327 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 328 | "Use --overwrite_output_dir to overcome." 329 | ) 330 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 331 | logger.info( 332 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 333 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 334 | ) 335 | 336 | # Set seed before initializing model. 337 | set_seed(training_args.seed) 338 | 339 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 340 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 341 | # 342 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 343 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 344 | # label if at least two columns are provided. 
345 | # 346 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 347 | # single column. You can easily tweak this behavior (see below) 348 | # 349 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 350 | # download the dataset. 351 | if data_args.task_name is not None: 352 | # Downloading and loading a dataset from the hub. 353 | from datasets import load_from_disk 354 | from src.glue_tasks import AutoTask 355 | raw_datasets = load_from_disk(data_path + data_args.task_name) 356 | 357 | task = AutoTask().get(data_args.task_name, None, None) 358 | raw_datasets = { 359 | "train": task.get("train", split_validation_test=True), 360 | "validation": task.get("validation", split_validation_test=True), 361 | "test": task.get("test", split_validation_test=True) 362 | } 363 | from datasets import DatasetDict 364 | raw_datasets = DatasetDict(raw_datasets) 365 | 366 | elif data_args.dataset_name is not None: 367 | raise NotImplementedError 368 | 369 | else: 370 | # Loading a dataset from your local files. 371 | # CSV/JSON training and evaluation files are needed. 372 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 373 | 374 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 375 | # when you use `do_predict` without specifying a GLUE benchmark task. 376 | if training_args.do_predict: 377 | if data_args.test_file is not None: 378 | train_extension = data_args.train_file.split(".")[-1] 379 | test_extension = data_args.test_file.split(".")[-1] 380 | assert ( 381 | test_extension == train_extension 382 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 383 | data_files["test"] = data_args.test_file 384 | else: 385 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 386 | 387 | for key in data_files.keys(): 388 | logger.info(f"load a local file for {key}: {data_files[key]}") 389 | 390 | if data_args.train_file.endswith(".csv"): 391 | # Loading a dataset from local csv files 392 | raw_datasets = load_dataset( 393 | "csv", 394 | data_files=data_files, 395 | cache_dir=model_args.cache_dir, 396 | use_auth_token=True if model_args.use_auth_token else None, 397 | ) 398 | else: 399 | # Loading a dataset from local json files 400 | raw_datasets = load_dataset( 401 | "json", 402 | data_files=data_files, 403 | cache_dir=model_args.cache_dir, 404 | use_auth_token=True if model_args.use_auth_token else None, 405 | ) 406 | # See more about loading any type of standard or custom dataset at 407 | # https://huggingface.co/docs/datasets/loading_datasets.html. 408 | 409 | # Labels 410 | if data_args.task_name is not None: 411 | is_regression = data_args.task_name == "stsb" 412 | if not is_regression: 413 | label_list = raw_datasets["train"].features["label"].names 414 | num_labels = len(label_list) 415 | else: 416 | num_labels = 1 417 | else: 418 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 
419 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 420 | if is_regression: 421 | num_labels = 1 422 | else: 423 | # A useful fast method: 424 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 425 | label_list = raw_datasets["train"].unique("label") 426 | label_list.sort() # Let's sort it for determinism 427 | num_labels = len(label_list) 428 | 429 | # Load pretrained model and tokenizer 430 | # 431 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 432 | # download model & vocab. 433 | config = AutoConfig.from_pretrained( 434 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 435 | num_labels=num_labels, 436 | finetuning_task=data_args.task_name, 437 | cache_dir=model_args.cache_dir, 438 | revision=model_args.model_revision, 439 | use_auth_token=True if model_args.use_auth_token else None, 440 | ) 441 | tokenizer = AutoTokenizer.from_pretrained( 442 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 443 | cache_dir=model_args.cache_dir, 444 | use_fast=model_args.use_fast_tokenizer, 445 | revision=model_args.model_revision, 446 | use_auth_token=True if model_args.use_auth_token else None, 447 | ) 448 | model = AutoModelForSequenceClassification.from_pretrained( 449 | model_args.model_name_or_path, 450 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 451 | config=config, 452 | cache_dir=model_args.cache_dir, 453 | revision=model_args.model_revision, 454 | use_auth_token=True if model_args.use_auth_token else None, 455 | ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, 456 | ) 457 | 458 | from opendelta.delta_models.bitfit import BitFitConfig, BitFitModel 459 | import json 460 | bitfit_config = json.load(open("./bitfit_config.json")) 461 | bitfit_config = BitFitConfig.from_dict(bitfit_config) 462 | delta_model = BitFitModel.from_config(bitfit_config, backbone_model=model) 463 | delta_model.freeze_module(set_state_dict = True) 464 | delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=False) 465 | 466 | 467 | 468 | # Preprocessing the raw_datasets 469 | if data_args.task_name is not None: 470 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 471 | else: 472 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 473 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 474 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 475 | sentence1_key, sentence2_key = "sentence1", "sentence2" 476 | else: 477 | if len(non_label_column_names) >= 2: 478 | sentence1_key, sentence2_key = non_label_column_names[:2] 479 | else: 480 | sentence1_key, sentence2_key = non_label_column_names[0], None 481 | 482 | # Padding strategy 483 | if data_args.pad_to_max_length: 484 | padding = "max_length" 485 | else: 486 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 487 | padding = False 488 | 489 | # Some models have set the order of the labels to use, so let's make sure we do use it. 490 | label_to_id = None 491 | if ( 492 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 493 | and data_args.task_name is not None 494 | and not is_regression 495 | ): 496 | # Some have all caps in their config, some don't. 
497 |         label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
498 |         if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
499 |             label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
500 |         else:
501 |             logger.warning(
502 |                 "Your model seems to have been trained with labels, but they don't match the dataset: "
503 |                 f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
504 |                 "\nIgnoring the model labels as a result.",
505 |             )
506 |     elif data_args.task_name is None and not is_regression:
507 |         label_to_id = {v: i for i, v in enumerate(label_list)}
508 | 
509 |     if label_to_id is not None:
510 |         model.config.label2id = label_to_id
511 |         model.config.id2label = {id: label for label, id in config.label2id.items()}
512 |     elif data_args.task_name is not None and not is_regression:
513 |         model.config.label2id = {l: i for i, l in enumerate(label_list)}
514 |         model.config.id2label = {id: label for label, id in config.label2id.items()}
515 | 
516 |     if data_args.max_seq_length > tokenizer.model_max_length:
517 |         logger.warning(
518 |             f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
519 |             f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
520 |         )
521 |     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
522 | 
523 |     def preprocess_function(examples):
524 |         # Tokenize the texts
525 |         args = (
526 |             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
527 |         )
528 |         result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
529 | 
530 |         # Map labels to IDs (not necessary for GLUE tasks)
531 |         if label_to_id is not None and "label" in examples:
532 |             result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
533 |         return result
534 | 
535 |     with training_args.main_process_first(desc="dataset map pre-processing"):
536 |         raw_datasets = raw_datasets.map(
537 |             preprocess_function,
538 |             batched=True,
539 |             load_from_cache_file=not data_args.overwrite_cache,
540 |             desc="Running tokenizer on dataset",
541 |         )
542 |     if training_args.do_train:
543 |         if "train" not in raw_datasets:
544 |             raise ValueError("--do_train requires a train dataset")
545 |         train_dataset = raw_datasets["train"]
546 |         if data_args.max_train_samples is not None:
547 |             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
548 |             train_dataset = train_dataset.select(range(max_train_samples))
549 | 
550 |     if training_args.do_eval:
551 |         if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
552 |             raise ValueError("--do_eval requires a validation dataset")
553 |         eval_dataset = raw_datasets["validation"]
554 |         if data_args.max_eval_samples is not None:
555 |             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
556 |             eval_dataset = eval_dataset.select(range(max_eval_samples))
557 | 
558 |     if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
559 |         if "test" not in raw_datasets and "test_matched" not in raw_datasets:
560 |             raise ValueError("--do_predict requires a test dataset")
561 |         predict_dataset = raw_datasets["test"]
562 |         if data_args.max_predict_samples is not None:
563 |             max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
564 |             predict_dataset =
predict_dataset.select(range(max_predict_samples)) 565 | 566 | # Log a few random samples from the training set: 567 | if training_args.do_train: 568 | for index in random.sample(range(len(train_dataset)), 3): 569 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 570 | 571 | # Get the metric function 572 | from datasets import load_metric 573 | if data_args.task_name is not None: 574 | metric = load_metric("./glue.py", data_args.task_name) 575 | else: 576 | metric = load_metric("accuracy") 577 | 578 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 579 | # predictions and label_ids field) and has to return a dictionary string to float. 580 | def compute_metrics(mode, p: EvalPrediction): 581 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 582 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 583 | if data_args.task_name is not None: 584 | result = metric.compute(predictions=preds, references=p.label_ids) 585 | if len(result) > 1: 586 | result["combined_score"] = np.mean(list(result.values())).item() 587 | return result 588 | elif is_regression: 589 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 590 | else: 591 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 592 | 593 | # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if 594 | # we already did the padding. 595 | if data_args.pad_to_max_length: 596 | data_collator = default_data_collator 597 | elif training_args.fp16: 598 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 599 | else: 600 | data_collator = None 601 | 602 | # Initialize our Trainer 603 | optimizer, lr_scheduler = create_optimizer_and_scheduler(training_args, model, num_training_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size))) 604 | sparse_optimizer = None 605 | sparse_scheduler = None 606 | if training_args.train_sparse: 607 | print("building sparse optimizer and scheduler") 608 | from src.trainer import GATE_PARAM_NAME 609 | valid_param_name = [] 610 | for n, p in model.named_parameters(): 611 | print(n) 612 | if GATE_PARAM_NAME in n: 613 | valid_param_name.append(n) 614 | print("valid param name:", valid_param_name) 615 | sparse_optimizer = SparseAdamW(sparse_lambda=sparse_args.sparse_lambda_2, lambda_schedule=sparse_args.lambda_schedule, max_lambda=sparse_args.max_lambda, lambda_num=sparse_args.lambda_num, params=[p for n, p in model.named_parameters() if GATE_PARAM_NAME in n and p.requires_grad], lr=sparse_args.sparse_lr) 616 | sparse_scheduler = get_linear_schedule_with_warmup(sparse_optimizer, 617 | num_warmup_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size)*training_args.warmup_ratio), 618 | num_training_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size))) 619 | 620 | if training_args.debug_mode: 621 | train_dataset = eval_dataset 622 | 623 | # Initialize our Trainer 624 | trainer = SparseTrainer( 625 | model=model, 626 | args=training_args, 627 | train_dataset=train_dataset if training_args.do_train else None, 628 | eval_dataset=eval_dataset if training_args.do_eval else None, 629 | compute_metrics=compute_metrics, 630 | tokenizer=tokenizer, 631 | data_collator=data_collator, 632 | optimizers = (optimizer, lr_scheduler), 633 | sparse_lambda = 
sparse_args.sparse_lambda, 634 | sparse_optimizer = (sparse_optimizer, sparse_scheduler) 635 | ) 636 | 637 | # Training 638 | if training_args.do_train: 639 | checkpoint = None 640 | if training_args.resume_from_checkpoint is not None: 641 | checkpoint = training_args.resume_from_checkpoint 642 | elif last_checkpoint is not None: 643 | checkpoint = last_checkpoint 644 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 645 | metrics = train_result.metrics 646 | max_train_samples = ( 647 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 648 | ) 649 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 650 | 651 | trainer.save_model() # Saves the tokenizer too for easy upload 652 | 653 | trainer.log_metrics("train", metrics) 654 | trainer.save_metrics("train", metrics) 655 | trainer.save_state() 656 | 657 | sparse_param, total_param = compute_trainable_sparse_param(model) 658 | 659 | 660 | # eval on 1000 samples train set 661 | train_dataset_for_eval = train_dataset.shuffle(seed=42).select(range(1000)) 662 | logger.info("*** Evaluate on training subset ***") 663 | metrics = trainer.evaluate(eval_dataset=train_dataset_for_eval, metric_key_prefix = "eval_train") 664 | trainer.log_metrics("eval_train", metrics) 665 | trainer.save_metrics("eval_train", metrics) 666 | BEST_TRAIN_METRIC = metrics["eval_train_" + "_".join(task_to_best_metric[data_args.task_name].split("_")[1:])] 667 | 668 | 669 | # Evaluation 670 | if training_args.do_eval: 671 | logger.info("*** Evaluate ***") 672 | 673 | # Loop to handle MNLI double evaluation (matched, mis-matched) 674 | tasks = [data_args.task_name] 675 | eval_datasets = [eval_dataset] 676 | 677 | 678 | for eval_dataset, task in zip(eval_datasets, tasks): 679 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 680 | 681 | max_eval_samples = ( 682 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 683 | ) 684 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 685 | 686 | 687 | trainer.log_metrics("eval", metrics) 688 | # trainer.save_metrics("eval", combined if task is not None and "mnli" in task else metrics) 689 | trainer.save_metrics("eval", metrics) 690 | 691 | BEST_EVAL_METRIC = metrics[task_to_best_metric[data_args.task_name]] 692 | 693 | if training_args.do_predict: 694 | logger.info("*** Predict ***") 695 | 696 | # Loop to handle MNLI double evaluation (matched, mis-matched) 697 | tasks = [data_args.task_name] 698 | predict_datasets = [predict_dataset] 699 | 700 | 701 | for predict_dataset, task in zip(predict_datasets, tasks): 702 | metrics = trainer.evaluate(eval_dataset=predict_dataset) 703 | 704 | max_eval_samples = ( 705 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 706 | ) 707 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 708 | 709 | 710 | trainer.log_metrics("test", metrics) 711 | 712 | trainer.save_metrics("test", metrics) 713 | 714 | logger.info("***** Final Model ******\nNumber of trainable full param: %d\nNumber of trainable sparse param: %d, Ratio: %.4f%%\n**********" % (total_param, sparse_param, sparse_param / total_param * 100)) 715 | 716 | if __name__ == "__main__": 717 | main() -------------------------------------------------------------------------------- /run_glue_finetune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 
The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 18 | 19 | import logging 20 | import os 21 | import wandb 22 | os.environ['WANDB_MODE'] = 'offline' 23 | import random 24 | import sys 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | 28 | import datasets 29 | import numpy as np 30 | from datasets import load_dataset 31 | 32 | import transformers 33 | from transformers import ( 34 | AutoConfig, 35 | AutoModelForSequenceClassification, 36 | AutoTokenizer, 37 | DataCollatorWithPadding, 38 | EvalPrediction, 39 | HfArgumentParser, 40 | PretrainedConfig, 41 | Trainer, 42 | TrainingArguments, 43 | default_data_collator, 44 | set_seed, 45 | ) 46 | from transformers.trainer_utils import get_last_checkpoint 47 | from transformers.utils import check_min_version 48 | from transformers.utils.versions import require_version 49 | sys.path.append('../') 50 | from src.trainer import SparseTrainer 51 | from src.util import compute_trainable_sparse_param, create_optimizer_and_scheduler 52 | from src.sparse_optimizer import SparseAdamW 53 | from transformers import get_linear_schedule_with_warmup 54 | 55 | 56 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 57 | # check_min_version("4.24.0") 58 | 59 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 60 | 61 | task_to_keys = { 62 | "cola": ("sentence", None), 63 | "mnli": ("premise", "hypothesis"), 64 | "mnli-m": ("premise", "hypothesis"), 65 | "mnli-mm": ("premise", "hypothesis"), 66 | "mrpc": ("sentence1", "sentence2"), 67 | "qnli": ("question", "sentence"), 68 | "qqp": ("question1", "question2"), 69 | "rte": ("sentence1", "sentence2"), 70 | "sst2": ("sentence", None), 71 | "stsb": ("sentence1", "sentence2"), 72 | "wnli": ("sentence1", "sentence2"), 73 | } 74 | 75 | task_to_best_metric = { 76 | "rte": "eval_accuracy", 77 | "mrpc": "eval_f1", 78 | "cola": "eval_matthews_correlation", 79 | "stsb": "eval_pearson", 80 | "sst2": "eval_accuracy", 81 | "qnli": "eval_accuracy", 82 | "mnli": "eval_accuracy", 83 | "mnli-m": "eval_accuracy", 84 | "mnli-mm": "eval_accuracy", 85 | "qqp": "eval_accuracy", 86 | } 87 | 88 | data_path = '/root/xtlv/data/sora_datasets/glue_datasets_from_dn/' 89 | 90 | logger = logging.getLogger(__name__) 91 | 92 | 93 | @dataclass 94 | class DataTrainingArguments: 95 | """ 96 | Arguments pertaining to what data we are going to input our model for training and eval. 97 | Using `HfArgumentParser` we can turn this class 98 | into argparse arguments to be able to specify them on 99 | the command line. 
100 | """ 101 | 102 | task_name: Optional[str] = field( 103 | default=None, 104 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 105 | ) 106 | dataset_name: Optional[str] = field( 107 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 108 | ) 109 | dataset_config_name: Optional[str] = field( 110 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 111 | ) 112 | max_seq_length: int = field( 113 | default=128, 114 | metadata={ 115 | "help": ( 116 | "The maximum total input sequence length after tokenization. Sequences longer " 117 | "than this will be truncated, sequences shorter will be padded." 118 | ) 119 | }, 120 | ) 121 | overwrite_cache: bool = field( 122 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 123 | ) 124 | pad_to_max_length: bool = field( 125 | default=True, 126 | metadata={ 127 | "help": ( 128 | "Whether to pad all samples to `max_seq_length`. " 129 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 130 | ) 131 | }, 132 | ) 133 | max_train_samples: Optional[int] = field( 134 | default=None, 135 | metadata={ 136 | "help": ( 137 | "For debugging purposes or quicker training, truncate the number of training examples to this " 138 | "value if set." 139 | ) 140 | }, 141 | ) 142 | max_eval_samples: Optional[int] = field( 143 | default=None, 144 | metadata={ 145 | "help": ( 146 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 147 | "value if set." 148 | ) 149 | }, 150 | ) 151 | max_predict_samples: Optional[int] = field( 152 | default=None, 153 | metadata={ 154 | "help": ( 155 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 156 | "value if set." 157 | ) 158 | }, 159 | ) 160 | train_file: Optional[str] = field( 161 | default=None, metadata={"help": "A csv or a json file containing the training data."} 162 | ) 163 | validation_file: Optional[str] = field( 164 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 165 | ) 166 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 167 | 168 | def __post_init__(self): 169 | if self.task_name is not None: 170 | self.task_name = self.task_name.lower() 171 | if self.task_name not in task_to_keys.keys(): 172 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) 173 | elif self.dataset_name is not None: 174 | pass 175 | elif self.train_file is None or self.validation_file is None: 176 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 177 | else: 178 | train_extension = self.train_file.split(".")[-1] 179 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 180 | validation_extension = self.validation_file.split(".")[-1] 181 | assert ( 182 | validation_extension == train_extension 183 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 184 | 185 | 186 | @dataclass 187 | class ModelArguments: 188 | """ 189 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
190 | """ 191 | 192 | model_name_or_path: str = field( 193 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 194 | ) 195 | config_name: Optional[str] = field( 196 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 197 | ) 198 | tokenizer_name: Optional[str] = field( 199 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 200 | ) 201 | cache_dir: Optional[str] = field( 202 | default=None, 203 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 204 | ) 205 | use_fast_tokenizer: bool = field( 206 | default=True, 207 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 208 | ) 209 | model_revision: str = field( 210 | default="main", 211 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 212 | ) 213 | use_auth_token: bool = field( 214 | default=False, 215 | metadata={ 216 | "help": ( 217 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 218 | "with private models)." 219 | ) 220 | }, 221 | ) 222 | ignore_mismatched_sizes: bool = field( 223 | default=False, 224 | metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, 225 | ) 226 | 227 | @dataclass 228 | class SparseArguments: 229 | sparse_lambda: Optional[float] = field( 230 | default=1e-3, metadata={"help": "loss penalty term for gate param"} 231 | ) 232 | sparse_lambda_2: Optional[float] = field( 233 | default=1e-3, metadata={"help": "clipping scale for gate param"} 234 | ) 235 | sparse_lr: Optional[float] = field( 236 | default=None, metadata={"help": "lr for gate parameter in sparse lora, default to same as learning rate for other parameters"} 237 | ) 238 | lora_r: Optional[int] = field( 239 | default=16, metadata={"help": "matrix rank in lora"} 240 | ) 241 | lambda_schedule: Optional[str] = field( 242 | default=None, metadata={"help": "scheduling of lambda_2, {linear, log_linear}"} 243 | ) 244 | max_lambda: Optional[float] = field( 245 | default=10, metadata={"help": "maximum value of lambda_2 in scheduling"} 246 | ) 247 | lambda_num: Optional[int] = field( 248 | default=10, metadata={"help": "total number of lambdas in scheduling"} 249 | ) 250 | 251 | 252 | @dataclass 253 | class SparseTrainingArguments(TrainingArguments): 254 | train_sparse: Optional[bool] = field( 255 | default=False, metadata={"help": "whether use sparse lora"} 256 | ) 257 | debug_mode: Optional[bool] = field( 258 | default=False, metadata={"help": "debug mode"} 259 | ) 260 | 261 | 262 | def main(): 263 | # See all possible arguments in src/transformers/training_args.py 264 | # or by passing the --help flag to this script. 265 | # We now keep distinct sets of args, for a cleaner separation of concerns. 266 | 267 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SparseTrainingArguments, SparseArguments)) 268 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 269 | # If we pass only one argument to the script and it's the path to a json file, 270 | # let's parse it to get our arguments. 271 | model_args, data_args, training_args, sparse_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 272 | else: 273 | model_args, data_args, training_args, sparse_args = parser.parse_args_into_dataclasses() 274 | 275 | # Sending telemetry. 
Tracking the example usage helps us better allocate resources to maintain them. The 276 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 277 | # send_example_telemetry("run_glue", model_args, data_args) 278 | 279 | task_name_for_get = data_args.task_name 280 | if "mnli" in data_args.task_name: 281 | data_args.task_name = "mnli" 282 | 283 | training_args.metric_for_best_model = task_to_best_metric[data_args.task_name] 284 | 285 | if os.getenv("LOCAL_RANK"): 286 | training_args.local_rank = int(os.environ["LOCAL_RANK"]) 287 | else: 288 | training_args.local_rank = -1 289 | 290 | if training_args.train_sparse: 291 | if sparse_args.sparse_lr is None: 292 | sparse_args.sparse_lr = training_args.learning_rate 293 | 294 | if training_args.debug_mode: 295 | training_args.output_dir += "-debug" 296 | print(f"save model to {training_args.output_dir}") 297 | 298 | # Setup logging 299 | logging.basicConfig( 300 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 301 | datefmt="%m/%d/%Y %H:%M:%S", 302 | handlers=[logging.StreamHandler(sys.stdout)], 303 | ) 304 | 305 | log_level = training_args.get_process_log_level() 306 | logger.setLevel(log_level) 307 | datasets.utils.logging.set_verbosity(log_level) 308 | transformers.utils.logging.set_verbosity(log_level) 309 | transformers.utils.logging.enable_default_handler() 310 | transformers.utils.logging.enable_explicit_format() 311 | 312 | # Log on each process the small summary: 313 | logger.warning( 314 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 315 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 316 | ) 317 | logger.info(f"Training/evaluation parameters {training_args}") 318 | 319 | # Detecting last checkpoint. 320 | last_checkpoint = None 321 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 322 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 323 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 324 | raise ValueError( 325 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 326 | "Use --overwrite_output_dir to overcome." 327 | ) 328 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 329 | logger.info( 330 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 331 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 332 | ) 333 | 334 | # Set seed before initializing model. 335 | set_seed(training_args.seed) 336 | 337 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 338 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 339 | # 340 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 341 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 342 | # label if at least two columns are provided. 343 | # 344 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 345 | # single column. 
You can easily tweak this behavior (see below) 346 | # 347 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 348 | # download the dataset. 349 | if data_args.task_name is not None: 350 | # Downloading and loading a dataset from the hub. 351 | from datasets import load_from_disk 352 | from src.glue_tasks import AutoTask 353 | raw_datasets = load_from_disk(data_path + data_args.task_name) 354 | 355 | task = AutoTask().get(task_name_for_get, None, None) 356 | raw_datasets = { 357 | "train": task.get("train", split_validation_test=True), 358 | "validation": task.get("validation", split_validation_test=True), 359 | "test": task.get("test", split_validation_test=True) 360 | } 361 | from datasets import DatasetDict 362 | raw_datasets = DatasetDict(raw_datasets) 363 | 364 | 365 | 366 | elif data_args.dataset_name is not None: 367 | raise NotImplementedError 368 | 369 | else: 370 | # Loading a dataset from your local files. 371 | # CSV/JSON training and evaluation files are needed. 372 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 373 | 374 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 375 | # when you use `do_predict` without specifying a GLUE benchmark task. 376 | if training_args.do_predict: 377 | if data_args.test_file is not None: 378 | train_extension = data_args.train_file.split(".")[-1] 379 | test_extension = data_args.test_file.split(".")[-1] 380 | assert ( 381 | test_extension == train_extension 382 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 383 | data_files["test"] = data_args.test_file 384 | else: 385 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 386 | 387 | for key in data_files.keys(): 388 | logger.info(f"load a local file for {key}: {data_files[key]}") 389 | 390 | if data_args.train_file.endswith(".csv"): 391 | # Loading a dataset from local csv files 392 | raw_datasets = load_dataset( 393 | "csv", 394 | data_files=data_files, 395 | cache_dir=model_args.cache_dir, 396 | use_auth_token=True if model_args.use_auth_token else None, 397 | ) 398 | else: 399 | # Loading a dataset from local json files 400 | raw_datasets = load_dataset( 401 | "json", 402 | data_files=data_files, 403 | cache_dir=model_args.cache_dir, 404 | use_auth_token=True if model_args.use_auth_token else None, 405 | ) 406 | # See more about loading any type of standard or custom dataset at 407 | # https://huggingface.co/docs/datasets/loading_datasets.html. 408 | 409 | # Labels 410 | if data_args.task_name is not None: 411 | is_regression = data_args.task_name == "stsb" 412 | if not is_regression: 413 | label_list = raw_datasets["train"].features["label"].names 414 | num_labels = len(label_list) 415 | else: 416 | num_labels = 1 417 | else: 418 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 
419 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 420 | if is_regression: 421 | num_labels = 1 422 | else: 423 | # A useful fast method: 424 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 425 | label_list = raw_datasets["train"].unique("label") 426 | label_list.sort() # Let's sort it for determinism 427 | num_labels = len(label_list) 428 | 429 | # Load pretrained model and tokenizer 430 | # 431 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 432 | # download model & vocab. 433 | config = AutoConfig.from_pretrained( 434 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 435 | num_labels=num_labels, 436 | finetuning_task=data_args.task_name, 437 | cache_dir=model_args.cache_dir, 438 | revision=model_args.model_revision, 439 | use_auth_token=True if model_args.use_auth_token else None, 440 | ) 441 | tokenizer = AutoTokenizer.from_pretrained( 442 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 443 | cache_dir=model_args.cache_dir, 444 | use_fast=model_args.use_fast_tokenizer, 445 | revision=model_args.model_revision, 446 | use_auth_token=True if model_args.use_auth_token else None, 447 | ) 448 | model = AutoModelForSequenceClassification.from_pretrained( 449 | model_args.model_name_or_path, 450 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 451 | config=config, 452 | cache_dir=model_args.cache_dir, 453 | revision=model_args.model_revision, 454 | use_auth_token=True if model_args.use_auth_token else None, 455 | ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, 456 | ) 457 | 458 | # Preprocessing the raw_datasets 459 | if data_args.task_name is not None: 460 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 461 | else: 462 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 463 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 464 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 465 | sentence1_key, sentence2_key = "sentence1", "sentence2" 466 | else: 467 | if len(non_label_column_names) >= 2: 468 | sentence1_key, sentence2_key = non_label_column_names[:2] 469 | else: 470 | sentence1_key, sentence2_key = non_label_column_names[0], None 471 | 472 | # Padding strategy 473 | if data_args.pad_to_max_length: 474 | padding = "max_length" 475 | else: 476 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 477 | padding = False 478 | 479 | # Some models have set the order of the labels to use, so let's make sure we do use it. 480 | label_to_id = None 481 | if ( 482 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 483 | and data_args.task_name is not None 484 | and not is_regression 485 | ): 486 | # Some have all caps in their config, some don't. 487 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 488 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 489 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} 490 | else: 491 | logger.warning( 492 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 493 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 
494 | "\nIgnoring the model labels as a result.", 495 | ) 496 | elif data_args.task_name is None and not is_regression: 497 | label_to_id = {v: i for i, v in enumerate(label_list)} 498 | 499 | if label_to_id is not None: 500 | model.config.label2id = label_to_id 501 | model.config.id2label = {id: label for label, id in config.label2id.items()} 502 | elif data_args.task_name is not None and not is_regression: 503 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 504 | model.config.id2label = {id: label for label, id in config.label2id.items()} 505 | 506 | 507 | if data_args.max_seq_length > tokenizer.model_max_length: 508 | logger.warning( 509 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 510 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 511 | ) 512 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 513 | 514 | def preprocess_function(examples): 515 | # Tokenize the texts 516 | args = ( 517 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 518 | ) 519 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 520 | 521 | # Map labels to IDs (not necessary for GLUE tasks) 522 | if label_to_id is not None and "label" in examples: 523 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 524 | return result 525 | 526 | with training_args.main_process_first(desc="dataset map pre-processing"): 527 | raw_datasets = raw_datasets.map( 528 | preprocess_function, 529 | batched=True, 530 | load_from_cache_file=not data_args.overwrite_cache, 531 | desc="Running tokenizer on dataset", 532 | ) 533 | if training_args.do_train: 534 | if "train" not in raw_datasets: 535 | raise ValueError("--do_train requires a train dataset") 536 | train_dataset = raw_datasets["train"] 537 | if data_args.max_train_samples is not None: 538 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 539 | train_dataset = train_dataset.select(range(max_train_samples)) 540 | 541 | if training_args.do_eval: 542 | if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: 543 | raise ValueError("--do_eval requires a validation dataset") 544 | # eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] 545 | eval_dataset = raw_datasets["validation"] 546 | if data_args.max_eval_samples is not None: 547 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 548 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 549 | 550 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: 551 | if "test" not in raw_datasets and "test_matched" not in raw_datasets: 552 | raise ValueError("--do_predict requires a test dataset") 553 | # predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"] 554 | predict_dataset = raw_datasets["test"] 555 | if data_args.max_predict_samples is not None: 556 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 557 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 558 | 559 | # Log a few random samples from the training set: 560 | if training_args.do_train: 561 | for index in random.sample(range(len(train_dataset)), 3): 562 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 563 | 564 | # Get 
the metric function 565 | from datasets import load_metric 566 | if data_args.task_name is not None: 567 | metric = load_metric("./glue.py", data_args.task_name) 568 | else: 569 | metric = load_metric("accuracy") 570 | 571 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 572 | # predictions and label_ids field) and has to return a dictionary string to float. 573 | def compute_metrics(mode, p: EvalPrediction): 574 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 575 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 576 | if data_args.task_name is not None: 577 | result = metric.compute(predictions=preds, references=p.label_ids) 578 | if len(result) > 1: 579 | result["combined_score"] = np.mean(list(result.values())).item() 580 | return result 581 | elif is_regression: 582 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 583 | else: 584 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 585 | 586 | # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if 587 | # we already did the padding. 588 | if data_args.pad_to_max_length: 589 | data_collator = default_data_collator 590 | elif training_args.fp16: 591 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 592 | else: 593 | data_collator = None 594 | 595 | 596 | # Initialize our Trainer 597 | optimizer, lr_scheduler = create_optimizer_and_scheduler(training_args, model, num_training_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size))) 598 | sparse_optimizer = None 599 | sparse_scheduler = None 600 | if training_args.train_sparse: 601 | print("building sparse optimizer and scheduler") 602 | from src.trainer import GATE_PARAM_NAME 603 | valid_param_name = [] 604 | for n, p in model.named_parameters(): 605 | print(n) 606 | if GATE_PARAM_NAME in n: 607 | valid_param_name.append(n) 608 | print("valid param name:", valid_param_name) 609 | sparse_optimizer = SparseAdamW(sparse_lambda=sparse_args.sparse_lambda_2, lambda_schedule=sparse_args.lambda_schedule, max_lambda=sparse_args.max_lambda, lambda_num=sparse_args.lambda_num, params=[p for n, p in model.named_parameters() if GATE_PARAM_NAME in n and p.requires_grad], lr=sparse_args.sparse_lr) 610 | sparse_scheduler = get_linear_schedule_with_warmup(sparse_optimizer, 611 | num_warmup_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size)*training_args.warmup_ratio), 612 | num_training_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size))) 613 | 614 | if training_args.debug_mode: 615 | train_dataset = eval_dataset 616 | 617 | # Initialize our Trainer 618 | trainer = SparseTrainer( 619 | model=model, 620 | args=training_args, 621 | train_dataset=train_dataset if training_args.do_train else None, 622 | eval_dataset=eval_dataset if training_args.do_eval else None, 623 | compute_metrics=compute_metrics, 624 | tokenizer=tokenizer, 625 | data_collator=data_collator, 626 | optimizers = (optimizer, lr_scheduler), 627 | sparse_lambda = sparse_args.sparse_lambda, 628 | sparse_optimizer = (sparse_optimizer, sparse_scheduler) 629 | ) 630 | 631 | # Training 632 | if training_args.do_train: 633 | checkpoint = None 634 | if training_args.resume_from_checkpoint is not None: 635 | checkpoint = training_args.resume_from_checkpoint 636 | elif 
last_checkpoint is not None: 637 | checkpoint = last_checkpoint 638 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 639 | metrics = train_result.metrics 640 | max_train_samples = ( 641 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 642 | ) 643 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 644 | 645 | trainer.save_model() # Saves the tokenizer too for easy upload 646 | 647 | trainer.log_metrics("train", metrics) 648 | trainer.save_metrics("train", metrics) 649 | trainer.save_state() 650 | 651 | sparse_param, total_param = compute_trainable_sparse_param(model) 652 | 653 | 654 | # eval on 1000 samples train set 655 | train_dataset_for_eval = train_dataset.shuffle(seed=42).select(range(1000)) 656 | logger.info("*** Evaluate on training subset ***") 657 | metrics = trainer.evaluate(eval_dataset=train_dataset_for_eval, metric_key_prefix = "eval_train") 658 | trainer.log_metrics("eval_train", metrics) 659 | trainer.save_metrics("eval_train", metrics) 660 | BEST_TRAIN_METRIC = metrics["eval_train_" + "_".join(task_to_best_metric[data_args.task_name].split("_")[1:])] 661 | 662 | 663 | # Evaluation 664 | 665 | if training_args.do_eval: 666 | logger.info("*** Evaluate ***") 667 | 668 | # Loop to handle MNLI double evaluation (matched, mis-matched) 669 | tasks = [data_args.task_name] 670 | eval_datasets = [eval_dataset] 671 | 672 | 673 | for eval_dataset, task in zip(eval_datasets, tasks): 674 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 675 | 676 | max_eval_samples = ( 677 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 678 | ) 679 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 680 | 681 | 682 | 683 | trainer.log_metrics("eval", metrics) 684 | trainer.save_metrics("eval", metrics) 685 | 686 | BEST_EVAL_METRIC = metrics[task_to_best_metric[data_args.task_name]] 687 | 688 | if training_args.do_predict: 689 | logger.info("*** Predict ***") 690 | 691 | # Loop to handle MNLI double evaluation (matched, mis-matched) 692 | tasks = [data_args.task_name] 693 | predict_datasets = [predict_dataset] 694 | 695 | 696 | for predict_dataset, task in zip(predict_datasets, tasks): 697 | metrics = trainer.evaluate(eval_dataset=predict_dataset) 698 | 699 | max_eval_samples = ( 700 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 701 | ) 702 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 703 | 704 | 705 | 706 | trainer.log_metrics("test", metrics) 707 | 708 | trainer.save_metrics("test", metrics) 709 | 710 | logger.info("***** Final Model ******\nNumber of trainable full param: %d\nNumber of trainable sparse param: %d, Ratio: %.4f%%\n**********" % (total_param, sparse_param, sparse_param / total_param * 100)) 711 | 712 | 713 | def compute_metrics_in_schedule(mode, p: EvalPrediction): 714 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 715 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 716 | if data_args.task_name is not None: 717 | result = metric.compute(predictions=preds, references=p.label_ids) 718 | if len(result) > 1: 719 | result["combined_score"] = np.mean(list(result.values())).item() 720 | if mode == "eval": 721 | result["generalization"] = result["_".join(task_to_best_metric[data_args.task_name].split("_")[1:])] / BEST_EVAL_METRIC * 100 722 | elif mode == "eval_train": 723 | result["memorization"] = 
result["_".join(task_to_best_metric[data_args.task_name].split("_")[1:])] / BEST_TRAIN_METRIC * 100 724 | elif mode == "test": 725 | pass 726 | else: 727 | raise NotImplementedError 728 | return result 729 | elif is_regression: 730 | raise NotImplementedError 731 | else: 732 | raise NotImplementedError 733 | 734 | 735 | # schedule 736 | if sparse_args.lambda_schedule is not None: 737 | logger.info("*****Start lambda_2 scheduling***") 738 | from transformers import EarlyStoppingCallback 739 | for _ in range(sparse_args.lambda_num - 1): 740 | training_args.num_train_epochs = 15 741 | training_args.load_best_model_at_end = False 742 | sparse_optimizer.step_lambda() 743 | trainer = SparseTrainer( 744 | model=model, 745 | args=training_args, 746 | train_dataset=train_dataset if training_args.do_train else None, 747 | eval_dataset=[eval_dataset if training_args.do_eval else None, train_dataset_for_eval], 748 | compute_metrics=compute_metrics_in_schedule, 749 | tokenizer=tokenizer, 750 | data_collator=data_collator, 751 | optimizers = (optimizer, lr_scheduler), 752 | sparse_lambda = sparse_args.sparse_lambda, 753 | sparse_optimizer = (sparse_optimizer, sparse_scheduler), 754 | ) 755 | 756 | trainer.train() 757 | 758 | if training_args.do_predict: 759 | logger.info("*** Predict ***") 760 | 761 | # Loop to handle MNLI double evaluation (matched, mis-matched) 762 | tasks = [data_args.task_name] 763 | predict_datasets = [predict_dataset] 764 | 765 | 766 | for predict_dataset, task in zip(predict_datasets, tasks): 767 | metrics = trainer.evaluate(eval_dataset=predict_dataset, metric_key_prefix="test") 768 | 769 | max_eval_samples = ( 770 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 771 | ) 772 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 773 | 774 | 775 | 776 | trainer.log_metrics("test", metrics) 777 | 778 | trainer.save_metrics("test", metrics) 779 | 780 | 781 | 782 | 783 | 784 | sparse_param, total_param = compute_trainable_sparse_param(model) 785 | 786 | logger.info("***** Lambda=%f Final Model ******\nLora rank: %d\nNumber of trainable full param: %d\nNumber of trainable sparse param: %d, Ratio: %.4f%%\n**********" % (sparse_optimizer.sparse_lambda, lora_config.lora_r, total_param, sparse_param, sparse_param / total_param * 100)) 787 | 788 | 789 | 790 | def _mp_fn(index): 791 | # For xla_spawn (TPUs) 792 | main() 793 | 794 | 795 | if __name__ == "__main__": 796 | main() -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. 
Pointers for this are left as comments. 18 | 19 | import logging 20 | import os 21 | import wandb 22 | os.environ['WANDB_MODE'] = 'offline' 23 | import random 24 | import sys 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | 28 | import datasets 29 | import numpy as np 30 | from datasets import load_dataset 31 | 32 | import transformers 33 | from transformers import ( 34 | AutoConfig, 35 | AutoModelForSequenceClassification, 36 | AutoTokenizer, 37 | DataCollatorWithPadding, 38 | EvalPrediction, 39 | HfArgumentParser, 40 | PretrainedConfig, 41 | Trainer, 42 | TrainingArguments, 43 | default_data_collator, 44 | set_seed, 45 | ) 46 | from transformers.trainer_utils import get_last_checkpoint 47 | from transformers.utils import check_min_version 48 | from transformers.utils.versions import require_version 49 | sys.path.append('../') 50 | from src.trainer import SparseTrainer 51 | from src.util import compute_trainable_sparse_param, create_optimizer_and_scheduler 52 | from src.sparse_optimizer import SparseAdamW 53 | from transformers import get_linear_schedule_with_warmup 54 | 55 | 56 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 57 | # check_min_version("4.24.0") 58 | 59 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 60 | 61 | task_to_keys = { 62 | "cola": ("sentence", None), 63 | "mnli": ("premise", "hypothesis"), 64 | "mnli-m": ("premise", "hypothesis"), 65 | "mnli-mm": ("premise", "hypothesis"), 66 | "mrpc": ("sentence1", "sentence2"), 67 | "qnli": ("question", "sentence"), 68 | "qqp": ("question1", "question2"), 69 | "rte": ("sentence1", "sentence2"), 70 | "sst2": ("sentence", None), 71 | "stsb": ("sentence1", "sentence2"), 72 | "wnli": ("sentence1", "sentence2"), 73 | } 74 | 75 | task_to_best_metric = { 76 | "rte": "eval_accuracy", 77 | "mrpc": "eval_f1", 78 | "cola": "eval_matthews_correlation", 79 | "stsb": "eval_pearson", 80 | "sst2": "eval_accuracy", 81 | "qnli": "eval_accuracy", 82 | "mnli": "eval_accuracy", 83 | "mnli-m": "eval_accuracy", 84 | "mnli-mm": "eval_accuracy", 85 | "qqp": "eval_accuracy", 86 | } 87 | 88 | data_path = '/root/xtlv/data/sora_datasets/glue_datasets_from_dn/' 89 | 90 | logger = logging.getLogger(__name__) 91 | 92 | 93 | @dataclass 94 | class DataTrainingArguments: 95 | """ 96 | Arguments pertaining to what data we are going to input our model for training and eval. 97 | Using `HfArgumentParser` we can turn this class 98 | into argparse arguments to be able to specify them on 99 | the command line. 100 | """ 101 | 102 | task_name: Optional[str] = field( 103 | default=None, 104 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 105 | ) 106 | dataset_name: Optional[str] = field( 107 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 108 | ) 109 | dataset_config_name: Optional[str] = field( 110 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 111 | ) 112 | max_seq_length: int = field( 113 | default=128, 114 | metadata={ 115 | "help": ( 116 | "The maximum total input sequence length after tokenization. Sequences longer " 117 | "than this will be truncated, sequences shorter will be padded." 
118 | ) 119 | }, 120 | ) 121 | overwrite_cache: bool = field( 122 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 123 | ) 124 | pad_to_max_length: bool = field( 125 | default=True, 126 | metadata={ 127 | "help": ( 128 | "Whether to pad all samples to `max_seq_length`. " 129 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 130 | ) 131 | }, 132 | ) 133 | max_train_samples: Optional[int] = field( 134 | default=None, 135 | metadata={ 136 | "help": ( 137 | "For debugging purposes or quicker training, truncate the number of training examples to this " 138 | "value if set." 139 | ) 140 | }, 141 | ) 142 | max_eval_samples: Optional[int] = field( 143 | default=None, 144 | metadata={ 145 | "help": ( 146 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 147 | "value if set." 148 | ) 149 | }, 150 | ) 151 | max_predict_samples: Optional[int] = field( 152 | default=None, 153 | metadata={ 154 | "help": ( 155 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 156 | "value if set." 157 | ) 158 | }, 159 | ) 160 | train_file: Optional[str] = field( 161 | default=None, metadata={"help": "A csv or a json file containing the training data."} 162 | ) 163 | validation_file: Optional[str] = field( 164 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 165 | ) 166 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 167 | 168 | def __post_init__(self): 169 | if self.task_name is not None: 170 | self.task_name = self.task_name.lower() 171 | if self.task_name not in task_to_keys.keys(): 172 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) 173 | elif self.dataset_name is not None: 174 | pass 175 | elif self.train_file is None or self.validation_file is None: 176 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 177 | else: 178 | train_extension = self.train_file.split(".")[-1] 179 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 180 | validation_extension = self.validation_file.split(".")[-1] 181 | assert ( 182 | validation_extension == train_extension 183 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 184 | 185 | 186 | @dataclass 187 | class ModelArguments: 188 | """ 189 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
190 | """ 191 | 192 | model_name_or_path: str = field( 193 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 194 | ) 195 | config_name: Optional[str] = field( 196 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 197 | ) 198 | tokenizer_name: Optional[str] = field( 199 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 200 | ) 201 | cache_dir: Optional[str] = field( 202 | default=None, 203 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 204 | ) 205 | use_fast_tokenizer: bool = field( 206 | default=True, 207 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 208 | ) 209 | model_revision: str = field( 210 | default="main", 211 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 212 | ) 213 | use_auth_token: bool = field( 214 | default=False, 215 | metadata={ 216 | "help": ( 217 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 218 | "with private models)." 219 | ) 220 | }, 221 | ) 222 | ignore_mismatched_sizes: bool = field( 223 | default=False, 224 | metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, 225 | ) 226 | 227 | @dataclass 228 | class SparseArguments: 229 | sparse_lambda: Optional[float] = field( 230 | default=1e-3, metadata={"help": "loss penalty term for gate param"} 231 | ) 232 | sparse_lambda_2: Optional[float] = field( 233 | default=1e-3, metadata={"help": "clipping scale for gate param"} 234 | ) 235 | sparse_lr: Optional[float] = field( 236 | default=None, metadata={"help": "lr for gate parameter in sparse lora, default to same as learning rate for other parameters"} 237 | ) 238 | lora_r: Optional[int] = field( 239 | default=16, metadata={"help": "matrix rank in lora"} 240 | ) 241 | lambda_schedule: Optional[str] = field( 242 | default=None, metadata={"help": "scheduling of lambda_2, {linear, log_linear}"} 243 | ) 244 | max_lambda: Optional[float] = field( 245 | default=10, metadata={"help": "maximum value of lambda_2 in scheduling"} 246 | ) 247 | lambda_num: Optional[int] = field( 248 | default=10, metadata={"help": "total number of lambdas in scheduling"} 249 | ) 250 | 251 | @dataclass 252 | class SparseTrainingArguments(TrainingArguments): 253 | train_sparse: Optional[bool] = field( 254 | default=False, metadata={"help": "whether use sparse lora"} 255 | ) 256 | debug_mode: Optional[bool] = field( 257 | default=False, metadata={"help": "debug mode"} 258 | ) 259 | 260 | 261 | def main(): 262 | # See all possible arguments in src/transformers/training_args.py 263 | # or by passing the --help flag to this script. 264 | # We now keep distinct sets of args, for a cleaner separation of concerns. 265 | 266 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SparseTrainingArguments, SparseArguments)) 267 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 268 | # If we pass only one argument to the script and it's the path to a json file, 269 | # let's parse it to get our arguments. 270 | model_args, data_args, training_args, sparse_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 271 | else: 272 | model_args, data_args, training_args, sparse_args = parser.parse_args_into_dataclasses() 273 | 274 | # Sending telemetry. 
Tracking the example usage helps us better allocate resources to maintain them. The 275 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 276 | # send_example_telemetry("run_glue", model_args, data_args) 277 | 278 | task_name_for_get = data_args.task_name 279 | if "mnli" in data_args.task_name: 280 | data_args.task_name = "mnli" 281 | 282 | training_args.metric_for_best_model = task_to_best_metric[data_args.task_name] 283 | 284 | if os.getenv("LOCAL_RANK"): 285 | training_args.local_rank = int(os.environ["LOCAL_RANK"]) 286 | else: 287 | training_args.local_rank = -1 288 | 289 | if training_args.train_sparse: 290 | if sparse_args.sparse_lr is None: 291 | sparse_args.sparse_lr = training_args.learning_rate 292 | 293 | if training_args.debug_mode: 294 | training_args.output_dir += "-debug" 295 | print(f"save model to {training_args.output_dir}") 296 | 297 | # Setup logging 298 | logging.basicConfig( 299 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 300 | datefmt="%m/%d/%Y %H:%M:%S", 301 | handlers=[logging.StreamHandler(sys.stdout)], 302 | ) 303 | 304 | log_level = training_args.get_process_log_level() 305 | logger.setLevel(log_level) 306 | datasets.utils.logging.set_verbosity(log_level) 307 | transformers.utils.logging.set_verbosity(log_level) 308 | transformers.utils.logging.enable_default_handler() 309 | transformers.utils.logging.enable_explicit_format() 310 | 311 | # Log on each process the small summary: 312 | logger.warning( 313 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 314 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 315 | ) 316 | logger.info(f"Training/evaluation parameters {training_args}") 317 | 318 | # Detecting last checkpoint. 319 | last_checkpoint = None 320 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 321 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 322 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 323 | raise ValueError( 324 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 325 | "Use --overwrite_output_dir to overcome." 326 | ) 327 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 328 | logger.info( 329 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 330 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 331 | ) 332 | 333 | # Set seed before initializing model. 334 | set_seed(training_args.seed) 335 | 336 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 337 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 338 | # 339 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 340 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 341 | # label if at least two columns are provided. 342 | # 343 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 344 | # single column. 
You can easily tweak this behavior (see below) 345 | # 346 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 347 | # download the dataset. 348 | if data_args.task_name is not None: 349 | # Downloading and loading a dataset from the hub. 350 | from datasets import load_from_disk 351 | from src.glue_tasks import AutoTask 352 | raw_datasets = load_from_disk(data_path + data_args.task_name) 353 | 354 | task = AutoTask().get(task_name_for_get, None, None) 355 | raw_datasets = { 356 | "train": task.get("train", split_validation_test=True), 357 | "validation": task.get("validation", split_validation_test=True), 358 | "test": task.get("test", split_validation_test=True) 359 | } 360 | from datasets import DatasetDict 361 | raw_datasets = DatasetDict(raw_datasets) 362 | 363 | elif data_args.dataset_name is not None: 364 | raise NotImplementedError 365 | else: 366 | # Loading a dataset from your local files. 367 | # CSV/JSON training and evaluation files are needed. 368 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 369 | 370 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 371 | # when you use `do_predict` without specifying a GLUE benchmark task. 372 | if training_args.do_predict: 373 | if data_args.test_file is not None: 374 | train_extension = data_args.train_file.split(".")[-1] 375 | test_extension = data_args.test_file.split(".")[-1] 376 | assert ( 377 | test_extension == train_extension 378 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 379 | data_files["test"] = data_args.test_file 380 | else: 381 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 382 | 383 | for key in data_files.keys(): 384 | logger.info(f"load a local file for {key}: {data_files[key]}") 385 | 386 | if data_args.train_file.endswith(".csv"): 387 | # Loading a dataset from local csv files 388 | raw_datasets = load_dataset( 389 | "csv", 390 | data_files=data_files, 391 | cache_dir=model_args.cache_dir, 392 | use_auth_token=True if model_args.use_auth_token else None, 393 | ) 394 | else: 395 | # Loading a dataset from local json files 396 | raw_datasets = load_dataset( 397 | "json", 398 | data_files=data_files, 399 | cache_dir=model_args.cache_dir, 400 | use_auth_token=True if model_args.use_auth_token else None, 401 | ) 402 | # See more about loading any type of standard or custom dataset at 403 | # https://huggingface.co/docs/datasets/loading_datasets.html. 404 | 405 | # Labels 406 | if data_args.task_name is not None: 407 | is_regression = data_args.task_name == "stsb" 408 | if not is_regression: 409 | label_list = raw_datasets["train"].features["label"].names 410 | num_labels = len(label_list) 411 | else: 412 | num_labels = 1 413 | else: 414 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 
415 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 416 | if is_regression: 417 | num_labels = 1 418 | else: 419 | # A useful fast method: 420 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 421 | label_list = raw_datasets["train"].unique("label") 422 | label_list.sort() # Let's sort it for determinism 423 | num_labels = len(label_list) 424 | 425 | # Load pretrained model and tokenizer 426 | # 427 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 428 | # download model & vocab. 429 | config = AutoConfig.from_pretrained( 430 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 431 | num_labels=num_labels, 432 | finetuning_task=data_args.task_name, 433 | cache_dir=model_args.cache_dir, 434 | revision=model_args.model_revision, 435 | use_auth_token=True if model_args.use_auth_token else None, 436 | ) 437 | tokenizer = AutoTokenizer.from_pretrained( 438 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 439 | cache_dir=model_args.cache_dir, 440 | use_fast=model_args.use_fast_tokenizer, 441 | revision=model_args.model_revision, 442 | use_auth_token=True if model_args.use_auth_token else None, 443 | ) 444 | model = AutoModelForSequenceClassification.from_pretrained( 445 | model_args.model_name_or_path, 446 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 447 | config=config, 448 | cache_dir=model_args.cache_dir, 449 | revision=model_args.model_revision, 450 | use_auth_token=True if model_args.use_auth_token else None, 451 | ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, 452 | ) 453 | 454 | if training_args.train_sparse: 455 | print("loading from src.lora") 456 | from src.lora import LoraModel, LoraConfig 457 | else: 458 | from opendelta.delta_models import LoraModel, LoraConfig 459 | 460 | import json 461 | lora_config = json.load(open("config/lora_config.json")) 462 | lora_config["lora_r"] = sparse_args.lora_r 463 | lora_config = LoraConfig.from_dict(lora_config) 464 | delta_model = LoraModel.from_config(lora_config, backbone_model=model) 465 | delta_model.freeze_module(set_state_dict = True) 466 | delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=False) 467 | 468 | 469 | # Preprocessing the raw_datasets 470 | if data_args.task_name is not None: 471 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 472 | else: 473 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 474 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 475 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 476 | sentence1_key, sentence2_key = "sentence1", "sentence2" 477 | else: 478 | if len(non_label_column_names) >= 2: 479 | sentence1_key, sentence2_key = non_label_column_names[:2] 480 | else: 481 | sentence1_key, sentence2_key = non_label_column_names[0], None 482 | 483 | # Padding strategy 484 | if data_args.pad_to_max_length: 485 | padding = "max_length" 486 | else: 487 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 488 | padding = False 489 | 490 | # Some models have set the order of the labels to use, so let's make sure we do use it. 
491 | label_to_id = None 492 | if ( 493 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 494 | and data_args.task_name is not None 495 | and not is_regression 496 | ): 497 | # Some have all caps in their config, some don't. 498 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 499 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 500 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} 501 | else: 502 | logger.warning( 503 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 504 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 505 | "\nIgnoring the model labels as a result.", 506 | ) 507 | elif data_args.task_name is None and not is_regression: 508 | label_to_id = {v: i for i, v in enumerate(label_list)} 509 | 510 | if label_to_id is not None: 511 | model.config.label2id = label_to_id 512 | model.config.id2label = {id: label for label, id in config.label2id.items()} 513 | elif data_args.task_name is not None and not is_regression: 514 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 515 | model.config.id2label = {id: label for label, id in config.label2id.items()} 516 | 517 | 518 | if data_args.max_seq_length > tokenizer.model_max_length: 519 | logger.warning( 520 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 521 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 522 | ) 523 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 524 | 525 | def preprocess_function(examples): 526 | # Tokenize the texts 527 | args = ( 528 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 529 | ) 530 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 531 | 532 | # Map labels to IDs (not necessary for GLUE tasks) 533 | if label_to_id is not None and "label" in examples: 534 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 535 | return result 536 | 537 | with training_args.main_process_first(desc="dataset map pre-processing"): 538 | raw_datasets = raw_datasets.map( 539 | preprocess_function, 540 | batched=True, 541 | load_from_cache_file=not data_args.overwrite_cache, 542 | desc="Running tokenizer on dataset", 543 | ) 544 | if training_args.do_train: 545 | if "train" not in raw_datasets: 546 | raise ValueError("--do_train requires a train dataset") 547 | train_dataset = raw_datasets["train"] 548 | if data_args.max_train_samples is not None: 549 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 550 | train_dataset = train_dataset.select(range(max_train_samples)) 551 | 552 | if training_args.do_eval: 553 | if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: 554 | raise ValueError("--do_eval requires a validation dataset") 555 | eval_dataset = raw_datasets["validation"] 556 | if data_args.max_eval_samples is not None: 557 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 558 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 559 | 560 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: 561 | if "test" not in raw_datasets and "test_matched" not in raw_datasets: 562 | raise ValueError("--do_predict requires a test 
dataset") 563 | predict_dataset = raw_datasets["test"] 564 | if data_args.max_predict_samples is not None: 565 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 566 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 567 | 568 | # Log a few random samples from the training set: 569 | if training_args.do_train: 570 | for index in random.sample(range(len(train_dataset)), 3): 571 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 572 | 573 | # Get the metric function 574 | from datasets import load_metric 575 | if data_args.task_name is not None: 576 | metric = load_metric("./glue.py", data_args.task_name) 577 | else: 578 | metric = load_metric("accuracy") 579 | 580 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 581 | # predictions and label_ids field) and has to return a dictionary string to float. 582 | def compute_metrics(mode, p: EvalPrediction): 583 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 584 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 585 | if data_args.task_name is not None: 586 | result = metric.compute(predictions=preds, references=p.label_ids) 587 | if len(result) > 1: 588 | result["combined_score"] = np.mean(list(result.values())).item() 589 | return result 590 | elif is_regression: 591 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 592 | else: 593 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 594 | 595 | # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if 596 | # we already did the padding. 597 | if data_args.pad_to_max_length: 598 | data_collator = default_data_collator 599 | elif training_args.fp16: 600 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 601 | else: 602 | data_collator = None 603 | 604 | 605 | # Initialize our Trainer 606 | optimizer, lr_scheduler = create_optimizer_and_scheduler(training_args, model, num_training_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size))) 607 | sparse_optimizer = None 608 | sparse_scheduler = None 609 | if training_args.train_sparse: 610 | print("building sparse optimizer and scheduler") 611 | from src.trainer import GATE_PARAM_NAME 612 | valid_param_name = [] 613 | for n, p in model.named_parameters(): 614 | print(n) 615 | if GATE_PARAM_NAME in n: 616 | valid_param_name.append(n) 617 | print("valid param name:", valid_param_name) 618 | sparse_optimizer = SparseAdamW(sparse_lambda=sparse_args.sparse_lambda_2, lambda_schedule=sparse_args.lambda_schedule, max_lambda=sparse_args.max_lambda, lambda_num=sparse_args.lambda_num, params=[p for n, p in model.named_parameters() if GATE_PARAM_NAME in n and p.requires_grad], lr=sparse_args.sparse_lr) 619 | sparse_scheduler = get_linear_schedule_with_warmup(sparse_optimizer, 620 | num_warmup_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size)*training_args.warmup_ratio), 621 | num_training_steps=int(training_args.num_train_epochs*(len(train_dataset) / training_args.train_batch_size))) 622 | 623 | if training_args.debug_mode: 624 | train_dataset = eval_dataset 625 | 626 | # Initialize our Trainer 627 | trainer = SparseTrainer( 628 | model=model, 629 | args=training_args, 630 | train_dataset=train_dataset if training_args.do_train else None, 631 | 
eval_dataset=eval_dataset if training_args.do_eval else None, 632 | compute_metrics=compute_metrics, 633 | tokenizer=tokenizer, 634 | data_collator=data_collator, 635 | optimizers = (optimizer, lr_scheduler), 636 | sparse_lambda = sparse_args.sparse_lambda, 637 | sparse_optimizer = (sparse_optimizer, sparse_scheduler) 638 | ) 639 | 640 | # Training 641 | if training_args.do_train: 642 | checkpoint = None 643 | if training_args.resume_from_checkpoint is not None: 644 | checkpoint = training_args.resume_from_checkpoint 645 | elif last_checkpoint is not None: 646 | checkpoint = last_checkpoint 647 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 648 | metrics = train_result.metrics 649 | max_train_samples = ( 650 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 651 | ) 652 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 653 | 654 | trainer.save_model() # Saves the tokenizer too for easy upload 655 | 656 | trainer.log_metrics("train", metrics) 657 | trainer.save_metrics("train", metrics) 658 | trainer.save_state() 659 | 660 | sparse_param, total_param = compute_trainable_sparse_param(model) 661 | 662 | 663 | # eval on 1000 samples train set 664 | train_dataset_for_eval = train_dataset.shuffle(seed=42).select(range(1000)) 665 | logger.info("*** Evaluate on training subset ***") 666 | metrics = trainer.evaluate(eval_dataset=train_dataset_for_eval, metric_key_prefix = "eval_train") 667 | trainer.log_metrics("eval_train", metrics) 668 | trainer.save_metrics("eval_train", metrics) 669 | BEST_TRAIN_METRIC = metrics["eval_train_" + "_".join(task_to_best_metric[data_args.task_name].split("_")[1:])] 670 | 671 | 672 | # Evaluation 673 | if training_args.do_eval: 674 | logger.info("*** Evaluate ***") 675 | 676 | # Loop to handle MNLI double evaluation (matched, mis-matched) 677 | tasks = [data_args.task_name] 678 | eval_datasets = [eval_dataset] 679 | 680 | for eval_dataset, task in zip(eval_datasets, tasks): 681 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 682 | 683 | max_eval_samples = ( 684 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 685 | ) 686 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 687 | 688 | trainer.log_metrics("eval", metrics) 689 | trainer.save_metrics("eval", metrics) 690 | 691 | BEST_EVAL_METRIC = metrics[task_to_best_metric[data_args.task_name]] 692 | 693 | if training_args.do_predict: 694 | logger.info("*** Predict ***") 695 | 696 | # Loop to handle MNLI double evaluation (matched, mis-matched) 697 | tasks = [data_args.task_name] 698 | predict_datasets = [predict_dataset] 699 | 700 | for predict_dataset, task in zip(predict_datasets, tasks): 701 | metrics = trainer.evaluate(eval_dataset=predict_dataset) 702 | 703 | max_eval_samples = ( 704 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 705 | ) 706 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 707 | 708 | trainer.log_metrics("test", metrics) 709 | 710 | trainer.save_metrics("test", metrics) 711 | 712 | logger.info("***** Final Model ******\nLora rank: %d\nNumber of trainable full param: %d\nNumber of trainable sparse param: %d, Ratio: %.4f%%\n**********" % (lora_config.lora_r, total_param, sparse_param, sparse_param / total_param * 100)) 713 | 714 | 715 | def compute_metrics_in_schedule(mode, p: EvalPrediction): 716 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 
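# Descriptive note (not in the upstream script): this metrics function is used during the lambda_2
# schedule below, where scores are normalized against the values recorded after the initial training
# run: "memorization" = train-subset score / BEST_TRAIN_METRIC * 100 and
# "generalization" = dev score / BEST_EVAL_METRIC * 100.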
717 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 718 | if data_args.task_name is not None: 719 | result = metric.compute(predictions=preds, references=p.label_ids) 720 | if len(result) > 1: 721 | result["combined_score"] = np.mean(list(result.values())).item() 722 | if mode == "eval": 723 | result["generalization"] = result["_".join(task_to_best_metric[data_args.task_name].split("_")[1:])] / BEST_EVAL_METRIC * 100 724 | elif mode == "eval_train": 725 | result["memorization"] = result["_".join(task_to_best_metric[data_args.task_name].split("_")[1:])] / BEST_TRAIN_METRIC * 100 726 | elif mode == "test": 727 | pass 728 | else: 729 | raise NotImplementedError 730 | return result 731 | elif is_regression: 732 | raise NotImplementedError 733 | 734 | else: 735 | raise NotImplementedError 736 | 737 | 738 | # schedule 739 | if sparse_args.lambda_schedule is not None: 740 | logger.info("*****Start lambda_2 scheduling***") 741 | from transformers import EarlyStoppingCallback 742 | for _ in range(sparse_args.lambda_num - 1): 743 | training_args.num_train_epochs = 15 744 | training_args.load_best_model_at_end = False 745 | sparse_optimizer.step_lambda() 746 | trainer = SparseTrainer( 747 | model=model, 748 | args=training_args, 749 | train_dataset=train_dataset if training_args.do_train else None, 750 | eval_dataset=[eval_dataset if training_args.do_eval else None, train_dataset_for_eval], 751 | compute_metrics=compute_metrics_in_schedule, 752 | tokenizer=tokenizer, 753 | data_collator=data_collator, 754 | optimizers = (optimizer, lr_scheduler), 755 | sparse_lambda = sparse_args.sparse_lambda, 756 | sparse_optimizer = (sparse_optimizer, sparse_scheduler), 757 | ) 758 | 759 | trainer.train() 760 | 761 | if training_args.do_predict: 762 | logger.info("*** Predict ***") 763 | 764 | # Loop to handle MNLI double evaluation (matched, mis-matched) 765 | tasks = [data_args.task_name] 766 | predict_datasets = [predict_dataset] 767 | 768 | 769 | for predict_dataset, task in zip(predict_datasets, tasks): 770 | metrics = trainer.evaluate(eval_dataset=predict_dataset, metric_key_prefix="test") 771 | 772 | max_eval_samples = ( 773 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 774 | ) 775 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 776 | 777 | 778 | trainer.log_metrics("test", metrics) 779 | 780 | trainer.save_metrics("test", metrics) 781 | 782 | 783 | sparse_param, total_param = compute_trainable_sparse_param(model) 784 | 785 | logger.info("***** Lambda=%f Final Model ******\nLora rank: %d\nNumber of trainable full param: %d\nNumber of trainable sparse param: %d, Ratio: %.4f%%\n**********" % (sparse_optimizer.sparse_lambda, lora_config.lora_r, total_param, sparse_param, sparse_param / total_param * 100)) 786 | 787 | 788 | 789 | def _mp_fn(index): 790 | # For xla_spawn (TPUs) 791 | main() 792 | 793 | 794 | if __name__ == "__main__": 795 | main() --------------------------------------------------------------------------------