├── .gitignore ├── icl.sh ├── large_exp.py ├── README.md ├── finetune_fsdp.sh ├── finetune.sh ├── HiZOO.sh ├── requirements.txt ├── metrics.py ├── lora.py ├── lr_scheduler.py ├── Hessian_smooth_scheduler.py ├── prefix.py ├── exp.py ├── templates.py ├── utils.py ├── tasks.py ├── run.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | ./__pycache__/* 2 | ./__pycache__ 3 | ./result/* 4 | ./result 5 | ./wandb/* 6 | ./wandb -------------------------------------------------------------------------------- /icl.sh: -------------------------------------------------------------------------------- 1 | MODEL=${MODEL:-facebook/opt-13b} 2 | 3 | python run.py --model_name $MODEL --task_name $TASK --output_dir result/tmp --tag icl --num_train 32 --num_eval 1000 --load_float16 --verbose "$@" 4 | -------------------------------------------------------------------------------- /large_exp.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | import time 3 | 4 | def allocate_memory_on_device(device_id, memory_size): 5 | cp.cuda.Device(device_id).use() 6 | memory_pool = cp.cuda.MemoryPool() 7 | cp.cuda.set_allocator(memory_pool.malloc) 8 | num_bytes = memory_size 9 | d_memory = cp.empty(num_bytes, dtype=cp.uint8) 10 | d_memory.fill(1) 11 | device = cp.cuda.Device() 12 | return d_memory 13 | 14 | def run_for_days(days): 15 | 16 | device_id0 = 6 17 | memory_size0 = 30 * 1024 * 1024 * 1024 18 | allocated_memory0 = allocate_memory_on_device(device_id0, memory_size0) 19 | ''' 20 | device_id1 = 0 21 | memory_size1 = 24 * 1024 * 1024 * 1024 22 | allocated_memory1 = allocate_memory_on_device(device_id1, memory_size1) 23 | 24 | device_id2 = 1 25 | memory_size2 = 27 * 1024 * 1024 * 1024 26 | allocated_memory2 = allocate_memory_on_device(device_id2, memory_size2) 27 | ''' 28 | try: 29 | 30 | seconds_to_run = days * 24 * 60 * 60 31 | start_time = time.time() 32 | current_time = start_time 33 | 34 | 35 | while (current_time - start_time) < seconds_to_run: 36 | time.sleep(3600) 37 | current_time = time.time() 38 | #print("Time's up. Exiting program...") 39 | 40 | finally: 41 | 42 | del allocated_memory0 43 | #del allocated_memory1 44 | #del allocated_memory2 45 | #cp.cuda.Device(device_id0).synchronize() 46 | #cp.cuda.Device(device_id1).synchronize() 47 | #cp.cuda.Device(device_id2).synchronize() 48 | 49 | 50 | #device_id = 0 51 | #memory_size = 1024 * 1024 # 1MB 52 | days_to_run = 7 53 | 54 | 55 | run_for_days(days_to_run) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Second-Order Fine-Tuning without Pain for LLMs: A Hessian Informed Zeroth-Order Optimizer ([ICLR 2025](https://arxiv.org/pdf/2402.15173)) 2 | 3 | 4 | In this work, we propose a diagonal 5 | Hessian-informed zeroth-order optimizer (HiZOO) 6 | that works without computing first-order or second-order 7 | derivatives. To our knowledge, this is the first 8 | work that leverages the Hessian to enhance a zeroth-order optimizer for fine-tuning LLMs. What's 9 | more, HiZOO avoids the heavy memory cost 10 | of backpropagation while adding only 11 | one extra forward pass per step. Extensive experiments 12 | on various models (350M∼66B parameters) show that HiZOO improves model convergence, reducing training steps and enhancing 13 | model accuracy.
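## Method sketch (illustrative)

Conceptually, each HiZOO step estimates the gradient along a random direction from two perturbed forward passes (as in MeZO) and spends one extra forward pass maintaining a diagonal Hessian estimate, which then preconditions both the perturbation and the parameter update. The snippet below is only a simplified sketch of that idea, not the code used in this repo (see `run.py` and the training scripts for the real entry points); the function name, the inverse-Hessian preconditioner, the use of the smoothing constant as an EMA coefficient, and the exact form of the Hessian refresh are illustrative assumptions that follow the paper only loosely.

```python
import torch


@torch.no_grad()
def hizoo_style_step(params, closure, lr=1e-6, eps=1e-3, smooth=1e-8, hessian=None):
    """One illustrative Hessian-informed zeroth-order (SPSA-style) step.

    params:  list of parameter tensors, updated in place
    closure: callable that runs a forward pass on the current batch and returns the scalar loss
    hessian: per-parameter diagonal Hessian estimate (same shapes as params)
    """
    if hessian is None:
        hessian = [torch.ones_like(p) for p in params]

    # Diagonal preconditioner: larger estimated curvature -> smaller step in that coordinate.
    sigma = [1.0 / h.clamp_min(1e-8) for h in hessian]
    z = [torch.randn_like(p) for p in params]

    loss0 = closure()  # L(theta): the one extra forward pass compared with MeZO

    for p, zi, si in zip(params, z, sigma):   # theta + eps * Sigma^(1/2) z
        p.add_(eps * si.sqrt() * zi)
    loss_pos = closure()

    for p, zi, si in zip(params, z, sigma):   # theta - eps * Sigma^(1/2) z
        p.sub_(2 * eps * si.sqrt() * zi)
    loss_neg = closure()

    for p, zi, si in zip(params, z, sigma):   # restore theta
        p.add_(eps * si.sqrt() * zi)

    # Scalar projected-gradient estimate along z, as in SPSA/MeZO.
    proj_grad = (loss_pos - loss_neg) / (2 * eps)

    # Second-difference curvature estimate along z; refresh the diagonal Hessian
    # with an exponential moving average controlled by `smooth`.
    curv = (loss_pos + loss_neg - 2 * loss0) / (2 * eps ** 2)
    for h, zi in zip(hessian, z):
        h.mul_(1 - smooth).add_(smooth * abs(curv) * zi * zi)

    # Preconditioned zeroth-order update.
    for p, zi, si in zip(params, z, sigma):
        p.sub_(lr * proj_grad * si.sqrt() * zi)

    return loss0, hessian
```

In the actual scripts, `--learning_rate`, `--zo_eps`, and `--hessian_smooth_type` play the roles of `lr`, `eps`, and `smooth` above.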
14 | 15 | 16 | ## Installation 17 | 18 | ```bash 19 | conda create -n HiZOO python==3.9.19 20 | conda activate HiZOO 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | This environment supports **OPT**, **LLaMA**, **Phi**, and other recent models. 25 | 26 | ## Usage 27 | 28 | Use `run.py` for all functions (zero-shot/ICL/fine-tuning/MeZO/HiZOO): 29 | 30 | ```bash 31 | python run.py {ARGUMENTS} 32 | ``` 33 | 34 | Please read `run.py` for a complete list of arguments. 35 | 36 | We provide an example script below for reproducing our experiments. All our examples sample 1,000 37 | training examples, 500 validation examples, and 1,000 testing examples. 38 | 39 | ```bash 40 | # HiZOO (full-parameter fine-tuning of OPT-13B on the WSC dataset) 41 | CUDA_VISIBLE_DEVICES=0 MODEL=facebook/opt-13b TASK=WSC MODE=ft LR=1e-6 EPS=1e-3 HESSIAN_SMOOTH_TYPE=constant1e-8 bash HiZOO.sh 42 | 43 | ``` 44 | 45 | ## Citation 46 | ```bibtex 47 | @article{zhao2024second, 48 | title={Second-order fine-tuning without pain for llms: A hessian informed zeroth-order optimizer}, 49 | author={Zhao, Yanjun and Dang, Sizhe and Ye, Haishan and Dai, Guang and Qian, Yi and Tsang, Ivor W}, 50 | journal={arXiv preprint arXiv:2402.15173}, 51 | year={2024} 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /finetune_fsdp.sh: -------------------------------------------------------------------------------- 1 | MODEL=${MODEL:-facebook/opt-1.3b} 2 | MODEL_NAME=(${MODEL//\// }) 3 | MODEL_NAME="${MODEL_NAME[-1]}" 4 | 5 | EPOCH=${EPOCH:-5} 6 | BS=${BS:-8} 7 | LR=${LR:-1e-5} 8 | SEED=${SEED:-0} 9 | TRAIN=${TRAIN:-1000} 10 | DEV=${DEV:-500} 11 | EVAL=${EVAL:-1000} 12 | 13 | MODE=${MODE:-ft} 14 | EXTRA_ARGS="" 15 | if [ "$MODE" == "prefix" ]; then 16 | EXTRA_ARGS="--prefix_tuning --num_prefix 5 --no_reparam --prefix_init_by_real_act" 17 | elif [ "$MODE" == "lora" ]; then 18 | EXTRA_ARGS="--lora" 19 | fi 20 | TAG=fsdp-$MODE-$EPOCH-$BS-$LR-$SEED 21 | 22 | TASK_ARGS="" 23 | case $TASK in 24 | # For Copa, ReCoRD, SQuAD, DROP, we set --train_as_classification False; for others, set this flag to True 25 | CB) # It has <1000 training examples. Only use 100 for dev 26 | DEV=100 27 | ;; 28 | Copa) # It has <1000 training examples.
Only use 100 for dev 29 | DEV=100 30 | TASK_ARGS="--train_as_classification False" 31 | ;; 32 | ReCoRD) 33 | TASK_ARGS="--train_as_classification False" 34 | ;; 35 | DROP) 36 | TASK_ARGS="--train_as_classification False" 37 | ;; 38 | SQuAD) 39 | TASK_ARGS="--train_as_classification False" 40 | ;; 41 | esac 42 | 43 | echo $TAG 44 | echo "EPOCH: $EPOCH" 45 | echo "BS (gradient accumulation): $BS" 46 | echo "LR: $LR" 47 | echo "SEED: $SEED" 48 | echo "MODE: $MODE" 49 | echo "Extra args: $EXTRA_ARGS $TASK_ARGS" 50 | 51 | OMP_NUM_THREADS=10 torchrun --nproc_per_node=$NUM_GPU --master_port=$(( RANDOM + 1000 )) run.py \ 52 | --model_name $MODEL \ 53 | --task_name $TASK \ 54 | --output_dir result/$TASK-${MODEL_NAME}-$TAG --tag $TAG --train_set_seed $SEED --num_train $TRAIN --num_dev $DEV --num_eval $EVAL --logging_steps 10 \ 55 | --trainer regular --fp16 --no_auto_device \ 56 | --learning_rate $LR --num_train_epochs $EPOCH --per_device_train_batch_size 1 --gradient_accumulation_steps $BS \ 57 | --load_best_model_at_end --evaluation_strategy epoch --save_strategy epoch --save_total_limit 1 \ 58 | --train_as_classification \ 59 | --fsdp "full_shard auto_wrap" \ 60 | --fsdp_transformer_layer_cls_to_wrap 'OPTDecoderLayer' \ 61 | $EXTRA_ARGS \ 62 | $TASK_ARGS \ 63 | "$@" 64 | -------------------------------------------------------------------------------- /finetune.sh: -------------------------------------------------------------------------------- 1 | MODEL=${MODEL:-facebook/opt-1.3b} 2 | MODEL_NAME=(${MODEL//\// }) 3 | MODEL_NAME="${MODEL_NAME[-1]}" 4 | 5 | EPOCH=${EPOCH:-5} 6 | BS=${BS:-8} 7 | LR=${LR:-1e-5} 8 | SEED=${SEED:-0} 9 | TRAIN=${TRAIN:-1000} 10 | DEV=${DEV:-500} 11 | EVAL=${EVAL:-1000} 12 | 13 | 14 | 15 | MODE=${MODE:-ft} 16 | EXTRA_ARGS="" 17 | if [ "$MODE" == "prefix" ]; then 18 | EXTRA_ARGS="--prefix_tuning --num_prefix 5 --no_reparam --prefix_init_by_real_act" 19 | elif [ "$MODE" == "lora" ]; then 20 | EXTRA_ARGS="--lora" 21 | fi 22 | TAG=$MODE-$EPOCH-$BS-$LR-$SEED- 23 | 24 | TASK_ARGS="" 25 | case $TASK in 26 | # For Copa, ReCoRD, SQuAD, DROP, we set --train_as_classification False; for others, set this flag to True 27 | CB) # It has <1000 training examples. Only use 100 for dev 28 | DEV=100 29 | ;; 30 | Copa) # It has <1000 training examples. 
Only use 100 for dev 31 | DEV=100 32 | TASK_ARGS="--train_as_classification False" 33 | ;; 34 | MultiRC) # Can only fit real bsz = 2 on 80G A100 35 | GA=$(expr $BS / 2) 36 | BS=2 37 | echo "Gradient accumulation: $GA" 38 | TASK_ARGS="--gradient_accumulation_steps $GA" 39 | ;; 40 | ReCoRD) # Can only fit real bsz = 2 on 80G A100 41 | GA=$(expr $BS / 2) 42 | BS=2 43 | echo "Gradient accumulation: $GA" 44 | TASK_ARGS="--gradient_accumulation_steps $GA --train_as_classification False" 45 | ;; 46 | DROP) # Can only fit real bsz = 1 on 80G A100 47 | GA=$(expr $BS / 1) 48 | BS=1 49 | echo "Gradient accumulation: $GA" 50 | TASK_ARGS="--gradient_accumulation_steps $GA --train_as_classification False" 51 | ;; 52 | SQuAD) 53 | TASK_ARGS="--train_as_classification False" 54 | ;; 55 | esac 56 | 57 | echo $TAG 58 | echo "EPOCH: $EPOCH" 59 | echo "BS: $BS" 60 | echo "LR: $LR" 61 | echo "SEED: $SEED" 62 | echo "MODE: $MODE" 63 | echo "Extra args: $EXTRA_ARGS $TASK_ARGS" 64 | 65 | python run.py \ 66 | --model_name $MODEL \ 67 | --task_name $TASK \ 68 | --output_dir result/$TASK-${MODEL_NAME}-$TAG --tag $TAG --train_set_seed $SEED --num_train $TRAIN --num_dev $DEV --num_eval $EVAL --logging_steps 10 \ 69 | --trainer regular --fp16 \ 70 | --learning_rate $LR --num_train_epochs $EPOCH --per_device_train_batch_size $BS \ 71 | --load_best_model_at_end --evaluation_strategy epoch --save_strategy epoch --save_total_limit 1 \ 72 | --train_as_classification \ 73 | $EXTRA_ARGS \ 74 | $TASK_ARGS \ 75 | "$@" 76 | -------------------------------------------------------------------------------- /HiZOO.sh: -------------------------------------------------------------------------------- 1 | MODEL=${MODEL:-opt-13b} 2 | MODEL_NAME=(${MODEL//\// }) 3 | MODEL_NAME="${MODEL_NAME[-1]}" 4 | 5 | BS=${BS:-16} 6 | LR=${LR:-1e-5} 7 | EPS=${EPS:-1e-3} 8 | SEED=${SEED:-0} 9 | TRAIN=${TRAIN:-1000} 10 | DEV=${DEV:-500} 11 | EVAL=${EVAL:-1000} 12 | STEPS=${STEPS:-20000} 13 | EVAL_STEPS=${EVAL_STEPS:-2000} 14 | WARMUP_STEP=${WARMUP_STEP:-0} 15 | DECAY_STEP=${DECAY_STEP:-0} 16 | ZO_LR_SCHEDULER_TYPE=${ZO_LR_SCHEDULER_TYPE:-'constant'} 17 | WEIGHT_DECAY=${WEIGHT_DECAY:-0} 18 | HESSIAN_SMOOTH_TYPE=${HESSIAN_SMOOTH_TYPE:-'constant0'} 19 | 20 | 21 | MODE=${MODE:-ft} 22 | EXTRA_ARGS="" 23 | if [ "$MODE" == "prefix" ]; then 24 | EXTRA_ARGS="--prefix_tuning --num_prefix 5 --no_reparam --prefix_init_by_real_act" 25 | elif [ "$MODE" == "lora" ]; then 26 | EXTRA_ARGS="--lora" 27 | fi 28 | TAG=mezo-$MODE-$STEPS-$BS-$LR-$EPS-$SEED-$HESSIAN_SMOOTH_TYPE 29 | 30 | TASK_ARGS="" 31 | case $TASK in 32 | # For Copa, ReCoRD, SQuAD, DROP, we set --train_as_classification False; for others, set this flag to True 33 | CB) # It has <1000 training examples. Only use 100 for dev 34 | DEV=100 35 | ;; 36 | Copa) # It has <1000 training examples. 
Only use 100 for dev 37 | DEV=100 38 | TASK_ARGS="--train_as_classification False" 39 | ;; 40 | ReCoRD) 41 | TASK_ARGS="--train_as_classification False" 42 | ;; 43 | DROP) 44 | TASK_ARGS="--train_as_classification False" 45 | ;; 46 | SQuAD) 47 | TASK_ARGS="--train_as_classification False" 48 | ;; 49 | esac 50 | 51 | echo $TAG 52 | echo "BS: $BS" 53 | echo "LR: $LR" 54 | echo "EPS: $EPS" 55 | echo "SEED: $SEED" 56 | echo "TRAIN/EVAL STEPS: $STEPS/$EVAL_STEPS" 57 | echo "MODE: $MODE" 58 | echo "Extra args: $EXTRA_ARGS $TASK_ARGS" 59 | 60 | python run.py \ 61 | --model_name $MODEL \ 62 | --task_name $TASK \ 63 | --output_dir result/$TASK-${MODEL_NAME}-$TAG --tag $TAG --train_set_seed $SEED --num_train $TRAIN --num_dev $DEV --num_eval $EVAL --logging_steps 10 \ 64 | --max_steps $STEPS \ 65 | --trainer zo --load_float16 \ 66 | --learning_rate $LR --zo_eps $EPS --per_device_train_batch_size $BS --lr_scheduler_type "constant" \ 67 | --load_best_model_at_end --evaluation_strategy steps --save_strategy steps --save_total_limit 1 \ 68 | --eval_steps $EVAL_STEPS --save_steps $EVAL_STEPS \ 69 | --warmup_step $WARMUP_STEP --decay_step $DECAY_STEP --zo_lr_scheduler_type $ZO_LR_SCHEDULER_TYPE \ 70 | --weight_decay $WEIGHT_DECAY --hessian_smooth_type $HESSIAN_SMOOTH_TYPE \ 71 | --train_as_classification \ 72 | $EXTRA_ARGS \ 73 | $TASK_ARGS \ 74 | "$@" 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.32.1 2 | aiohttp==3.9.5 3 | aiosignal==1.2.0 4 | annotated-types==0.7.0 5 | #apex==0.1 6 | asttokens==2.4.1 7 | async-timeout==4.0.3 8 | attrs==23.1.0 9 | #autoattack==0.1 10 | backcall==0.2.0 11 | beautifulsoup4==4.12.3 12 | bleach==6.1.0 13 | Bottleneck==1.3.7 14 | Brotli==1.0.9 15 | certifi==2024.2.2 16 | cffi==1.16.0 17 | charset-normalizer==2.0.4 18 | click==8.1.7 19 | comm==0.2.2 20 | contourpy==1.2.1 21 | cycler==0.12.1 22 | datasets==2.18.0 23 | decorator==5.1.1 24 | defusedxml==0.7.1 25 | dill==0.3.6 26 | docker-pycreds==0.4.0 27 | docopt==0.6.2 28 | einops==0.8.0 29 | eval_type_backport==0.2.2 30 | exceptiongroup==1.2.1 31 | executing==2.0.1 32 | fastjsonschema==2.20.0 33 | filelock==3.13.1 34 | flash-attention==1.0.0 35 | fonttools==4.52.4 36 | frozenlist==1.4.0 37 | fsspec==2024.2.0 38 | gdown==5.1.0 39 | geotorch==0.3.0 40 | gitdb==4.0.12 41 | GitPython==3.1.44 42 | gmpy2==2.1.2 43 | graphviz==0.20.3 44 | huggingface-hub==0.27.1 45 | idna==3.7 46 | importlib-metadata==7.0.1 47 | importlib_resources==6.4.0 48 | ipython==8.12.3 49 | ipywidgets==8.1.3 50 | jedi==0.19.1 51 | Jinja2==3.1.3 52 | joblib==1.4.0 53 | jsonschema==4.23.0 54 | jsonschema-specifications==2023.12.1 55 | jupyter_client==8.6.3 56 | jupyter_core==5.7.2 57 | jupyterlab_pygments==0.3.0 58 | jupyterlab_widgets==3.0.11 59 | kiwisolver==1.4.5 60 | loralib==0.1.2 61 | MarkupSafe==2.1.3 62 | matplotlib==3.9.0 63 | matplotlib-inline==0.1.7 64 | mistune==3.0.2 65 | mkl-fft==1.3.8 66 | mkl-random==1.2.4 67 | mkl-service==2.4.0 68 | mpmath==1.3.0 69 | multidict==6.0.4 70 | multiprocess==0.70.14 71 | nbclient==0.10.0 72 | nbconvert==7.16.4 73 | nbformat==5.10.4 74 | networkx==3.1 75 | numexpr==2.8.7 76 | numpy==1.26.4 77 | ordereddict==1.1 78 | packaging==23.2 79 | pandas==2.2.1 80 | pandocfilters==1.5.1 81 | parso==0.8.4 82 | pexpect==4.9.0 83 | pickleshare==0.7.5 84 | pillow==10.3.0 85 | pip==24.0 86 | pipreqs==0.5.0 87 | platformdirs==4.3.6 88 | prompt_toolkit==3.0.47 89 | 
protobuf==5.29.2 90 | psutil==6.0.0 91 | ptyprocess==0.7.0 92 | pure-eval==0.2.2 93 | pyarrow==16.1.0 94 | pyarrow-hotfix==0.6 95 | pycparser==2.21 96 | pydantic==2.10.4 97 | pydantic_core==2.27.2 98 | Pygments==2.18.0 99 | pynvml==11.5.3 100 | pyparsing==3.1.2 101 | PySocks==1.7.1 102 | python-dateutil==2.9.0.post0 103 | pytz==2024.1 104 | PyYAML==6.0.1 105 | pyzmq==26.2.0 106 | referencing==0.35.1 107 | regex==2023.10.3 108 | requests==2.32.3 109 | responses==0.13.3 110 | #robustbench==1.1 111 | rpds-py==0.20.0 112 | sacremoses==0.0.43 113 | safetensors==0.4.3 114 | scikit-learn==1.4.2 115 | scipy==1.13.0 116 | seaborn==0.13.2 117 | sentencepiece==0.2.0 118 | sentry-sdk==2.19.2 119 | setproctitle==1.3.4 120 | setuptools==69.5.1 121 | six==1.16.0 122 | smmap==5.0.2 123 | soupsieve==2.5 124 | stack-data==0.6.3 125 | sympy==1.12 126 | threadpoolctl==2.2.0 127 | timm==1.0.3 128 | tinycss2==1.3.0 129 | tokenizers==0.21.0 130 | torch==2.1.0 131 | torchaudio==2.1.0 132 | torchdiffeq==0.2.3 133 | torchvision==0.16.0 134 | torchviz==0.0.2 135 | tornado==6.4.1 136 | tqdm==4.66.4 137 | traitlets==5.14.3 138 | transformers==4.48.0 139 | triton==2.1.0 140 | typing_extensions==4.12.2 141 | tzdata==2023.3 142 | #unsloth==2025.1.5 143 | urllib3==2.2.1 144 | wandb==0.19.2 145 | wcwidth==0.2.13 146 | webencodings==0.5.1 147 | wheel==0.43.0 148 | widgetsnbextension==4.0.11 149 | xxhash==2.0.2 150 | yarg==0.1.9 151 | yarl==1.9.3 152 | zipp==3.17.0 -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import collections 3 | import re 4 | import string 5 | from collections import Counter 6 | 7 | def normalize_answer(s): 8 | """Lower text and remove punctuation, articles and extra whitespace.""" 9 | 10 | def remove_articles(text): 11 | return re.sub(r'\b(a|an|the)\b', ' ', text) 12 | 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | 26 | def calculate_metric(predictions, metric_name): 27 | if metric_name == "accuracy": 28 | if isinstance(predictions[0].correct_candidate, list): 29 | return np.mean([pred.predicted_candidate in pred.correct_candidate for pred in predictions]) 30 | else: 31 | return np.mean([pred.correct_candidate == pred.predicted_candidate for pred in predictions]) 32 | elif metric_name == "em": 33 | # For question answering 34 | return np.mean([any([normalize_answer(ans) == normalize_answer(pred.predicted_candidate) for ans in pred.correct_candidate]) for pred in predictions]) 35 | elif metric_name == "f1": 36 | # For question answering 37 | f1 = [] 38 | for pred in predictions: 39 | all_f1s = [] 40 | if pred.correct_candidate[0] == "CANNOTANSWER" or pred.correct_candidate[0] == "no answer": 41 | f1.append(int(normalize_answer(pred.correct_candidate[0]) == normalize_answer(pred.predicted_candidate))) 42 | else: 43 | for ans in pred.correct_candidate: 44 | prediction_tokens = normalize_answer(pred.predicted_candidate).split() 45 | ground_truth_tokens = normalize_answer(ans).split() 46 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 47 | num_same = sum(common.values()) 48 | if num_same == 0: 49 | all_f1s.append(0) 50 | else: 51 | precision = 1.0 * 
num_same / len(prediction_tokens) 52 | recall = 1.0 * num_same / len(ground_truth_tokens) 53 | all_f1s.append((2 * precision * recall) / (precision + recall)) 54 | f1.append(max(all_f1s)) 55 | 56 | return np.mean(f1) 57 | 58 | 59 | def f1(pred, gold): 60 | """ 61 | This separate F1 function is used as non-differentiable metric for SQuAD 62 | """ 63 | if gold[0] == "CANNOTANSWER" or gold[0] == "no answer": 64 | return int(normalize_answer(gold[0]) == normalize_answer(pred)) 65 | else: 66 | all_f1s = [] 67 | for ans in gold: 68 | prediction_tokens = normalize_answer(pred).split() 69 | ground_truth_tokens = normalize_answer(ans).split() 70 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 71 | num_same = sum(common.values()) 72 | if num_same == 0: 73 | all_f1s.append(0) 74 | else: 75 | precision = 1.0 * num_same / len(prediction_tokens) 76 | recall = 1.0 * num_same / len(ground_truth_tokens) 77 | all_f1s.append((2 * precision * recall) / (precision + recall)) 78 | return np.max(all_f1s) -------------------------------------------------------------------------------- /lora.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 4 | logger = logging.getLogger(__name__) 5 | logger.setLevel(logging.INFO) 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | import math 11 | 12 | def find_module(root_module: nn.Module, key: str): 13 | """ 14 | Find a module with a specific name in a Transformer model 15 | From OpenDelta https://github.com/thunlp/OpenDelta 16 | """ 17 | sub_keys = key.split(".") 18 | parent_module = root_module 19 | for sub_key in sub_keys[:-1]: 20 | parent_module = getattr(parent_module, sub_key) 21 | module = getattr(parent_module, sub_keys[-1]) 22 | return parent_module, sub_keys[-1], module 23 | 24 | 25 | class LoRALinear(nn.Linear): 26 | """ 27 | LoRA implemented in a dense layer 28 | From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py 29 | """ 30 | def __init__( 31 | self, 32 | in_features: int, 33 | out_features: int, 34 | r: int = 0, 35 | lora_alpha: int = 1, 36 | lora_dropout: float = 0., 37 | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) 38 | merge_weights: bool = False, # Not sure if this will affect saving/loading models so just set it to be False 39 | **kwargs 40 | ): 41 | nn.Linear.__init__(self, in_features, out_features, **kwargs) 42 | 43 | self.r = r 44 | self.lora_alpha = lora_alpha 45 | # Optional dropout 46 | if lora_dropout > 0.: 47 | self.lora_dropout = nn.Dropout(p=lora_dropout) 48 | else: 49 | self.lora_dropout = lambda x: x 50 | # Mark the weight as unmerged 51 | self.merged = False 52 | self.merge_weights = merge_weights 53 | self.fan_in_fan_out = fan_in_fan_out 54 | # Actual trainable parameters 55 | if r > 0: 56 | self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) 57 | self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) 58 | self.scaling = self.lora_alpha / self.r 59 | # Freezing the pre-trained weight matrix 60 | self.weight.requires_grad = False 61 | self.reset_parameters() 62 | if fan_in_fan_out: 63 | self.weight.data = self.weight.data.transpose(0, 1) 64 | 65 | def reset_parameters(self): 66 | nn.Linear.reset_parameters(self) 67 | if hasattr(self, 'lora_A'): 68 | # initialize A the same way as the default for nn.Linear and B to zero 69 | 
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) 70 | nn.init.zeros_(self.lora_B) 71 | 72 | def train(self, mode: bool = True): 73 | def T(w): 74 | return w.transpose(0, 1) if self.fan_in_fan_out else w 75 | nn.Linear.train(self, mode) 76 | if mode: 77 | if self.merge_weights and self.merged: 78 | # Make sure that the weights are not merged 79 | if self.r > 0: 80 | self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling 81 | self.merged = False 82 | else: 83 | if self.merge_weights and not self.merged: 84 | # Merge the weights and mark it 85 | if self.r > 0: 86 | self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling 87 | self.merged = True 88 | 89 | def forward(self, x: torch.Tensor): 90 | def T(w): 91 | return w.transpose(0, 1) if self.fan_in_fan_out else w 92 | if self.r > 0 and not self.merged: 93 | result = F.linear(x, T(self.weight), bias=self.bias) 94 | if self.r > 0: 95 | result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling 96 | return result 97 | else: 98 | return F.linear(x, T(self.weight), bias=self.bias) 99 | 100 | 101 | class LoRA: 102 | 103 | def __init__(self, model, r, alpha, float16): 104 | """ 105 | Input: 106 | r, alpha: LoRA hyperparameters 107 | float16: Whether the model parameters are float16 or not 108 | """ 109 | 110 | self.model = model 111 | self.hidden_dim = model.config.hidden_size 112 | self.float16 = float16 113 | 114 | if model.config.model_type == "opt": 115 | attention_name = "attn" 116 | elif model.config.model_type == "roberta": 117 | attention_name = "attention" 118 | else: 119 | raise NotImplementedError 120 | 121 | # Insert LoRA 122 | for key, _ in model.named_modules(): 123 | if key[-len(attention_name):] == attention_name: 124 | logger.info(f"Inject lora to: {key}") 125 | _, _, attn = find_module(model, key) 126 | 127 | if model.config.model_type == "opt": 128 | original_q_weight = attn.q_proj.weight.data 129 | original_q_bias = attn.q_proj.bias.data 130 | original_v_weight= attn.v_proj.weight.data 131 | original_v_bias = attn.v_proj.bias.data 132 | attn.q_proj = LoRALinear(model.config.hidden_size, model.config.hidden_size, r=r, lora_alpha=alpha, bias=model.config.enable_bias).to(original_q_weight.device) 133 | attn.v_proj = LoRALinear(model.config.hidden_size, model.config.hidden_size, r=r, lora_alpha=alpha, bias=model.config.enable_bias).to(original_v_weight.device) 134 | if float16: 135 | attn.q_proj.half() 136 | attn.v_proj.half() 137 | attn.q_proj.weight.data = original_q_weight 138 | attn.q_proj.bias.data = original_q_bias 139 | attn.v_proj.weight.data = original_v_weight 140 | attn.v_proj.bias.data = original_v_bias 141 | else: 142 | raise NotImplementedError 143 | 144 | # Freeze non-LoRA parameters 145 | for n, p in model.named_parameters(): 146 | if "lora" not in n: 147 | p.requires_grad = False -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import warnings 19 | from functools import partial 20 | from typing import Callable, Iterable, Optional, Tuple, Union 21 | 22 | import torch 23 | from torch import nn 24 | #from torch.optim import Optimizer 25 | #from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau 26 | 27 | from transformers.utils import logging 28 | from transformers.utils.versions import require_version 29 | 30 | 31 | logger = logging.get_logger(__name__) 32 | 33 | 34 | def get_constant_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 35 | 36 | return learning_rate 37 | 38 | 39 | def get_reduce_on_plateau_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 40 | 41 | #to be updated 42 | return learning_rate 43 | 44 | def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None): 45 | if current_step < num_warmup_steps: 46 | return float(current_step) / float(max(1, num_warmup_steps)) 47 | shift = timescale - num_warmup_steps 48 | decay = 1.0 / math.sqrt((current_step + shift) / timescale) 49 | return decay 50 | 51 | 52 | def get_inverse_sqrt_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 53 | #to be updated 54 | return learning_rate 55 | 56 | def get_constant_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 57 | 58 | if current_step < num_warmup_steps: 59 | return float(current_step) / float(max(1.0, num_warmup_steps)) * learning_rate 60 | 61 | return learning_rate 62 | 63 | 64 | def get_linear_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 65 | 66 | 67 | if current_step < num_warmup_steps: 68 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 69 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) * learning_rate 70 | 71 | 72 | 73 | 74 | def get_cosine_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 75 | 76 | num_cycles=2 77 | if current_step < num_warmup_steps: 78 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 79 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 80 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) * learning_rate 81 | 82 | 83 | 84 | def get_cosine_with_hard_restarts_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 85 | num_cycles=2 86 | if current_step < num_warmup_steps: 87 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 88 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 89 | if progress >= 1.0: 90 | return 0.0 * learning_rate 91 
| return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) * learning_rate 92 | 93 | 94 | 95 | 96 | def get_polynomial_decay_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 97 | 98 | lr_init = learning_rate 99 | lr_end = 1e-10 100 | power = 3 101 | if current_step < num_warmup_steps: 102 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 103 | elif current_step > num_training_steps: 104 | return lr_end / lr_init * learning_rate # as LambdaLR multiplies by lr_init 105 | else: 106 | lr_range = lr_init - lr_end 107 | decay_steps = num_training_steps - num_warmup_steps 108 | pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps 109 | decay = lr_range * pct_remaining**power + lr_end 110 | return decay / lr_init * learning_rate # as LambdaLR multiplies by lr_init 111 | 112 | 113 | 114 | def get_polynomial_decay_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 115 | 116 | lr_init = learning_rate 117 | lr_end = 1e-10 118 | power = 3 119 | if current_step > num_training_steps: 120 | return lr_end / lr_init * learning_rate # as LambdaLR multiplies by lr_init 121 | else: 122 | lr_range = lr_init - lr_end 123 | decay_steps = num_training_steps - 0 124 | pct_remaining = 1 - (current_step - 0) / decay_steps 125 | decay = lr_range * pct_remaining**power + lr_end 126 | return decay / lr_init * learning_rate # as LambdaLR multiplies by lr_init 127 | 128 | def get_constant_polynomial_decay_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 129 | lr_init = learning_rate 130 | lr_end = 1e-10 131 | power = 3 132 | if current_step < num_decay_steps: 133 | return learning_rate 134 | elif current_step > num_training_steps: 135 | return lr_end / lr_init * learning_rate # as LambdaLR multiplies by lr_init 136 | else: 137 | lr_range = lr_init - lr_end 138 | decay_steps = num_training_steps - num_decay_steps 139 | pct_remaining = 1 - (current_step - num_decay_steps) / decay_steps 140 | decay = lr_range * pct_remaining**power + lr_end 141 | return decay / lr_init * learning_rate / 100 # as LambdaLR multiplies by lr_init 142 | 143 | def get_constants_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 144 | if current_step < num_decay_steps: 145 | return learning_rate 146 | else: 147 | return learning_rate/10 148 | 149 | 150 | 151 | def get_constant_polynomial_decay_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 152 | lr_init = learning_rate 153 | lr_end = 1e-7 154 | power = 3 155 | if current_step < num_warmup_steps: 156 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 157 | elif current_step >= num_warmup_steps and current_step < num_decay_steps: 158 | return learning_rate 159 | elif current_step > num_training_steps: 160 | return lr_end / lr_init * learning_rate # as LambdaLR multiplies by lr_init 161 | else: 162 | lr_range = lr_init - lr_end 163 | decay_steps = num_training_steps - num_decay_steps 164 | pct_remaining = 1 - (current_step - num_decay_steps) / decay_steps 165 | decay = lr_range * pct_remaining**power + lr_end 166 | return decay / lr_init * learning_rate # as LambdaLR multiplies by lr_init 167 | 168 | 169 | 170 | 171 | TYPE_TO_SCHEDULER_FUNCTION = { 172 | 'linear_with_warmup': get_linear_schedule_with_warmup, 173 | 
'cosine_with_warmup': get_cosine_schedule_with_warmup, 174 | 'cosine_with_hard_restarts_with_warmup': get_cosine_with_hard_restarts_schedule_with_warmup, 175 | 'polynomial_decay_with_warmup': get_polynomial_decay_schedule_with_warmup, 176 | 'constant': get_constant_schedule, 177 | 'constant_with_warmup': get_constant_schedule_with_warmup, 178 | 'inverse_sqrt': get_inverse_sqrt_schedule, 179 | 'reduce_on_plateau': get_reduce_on_plateau_schedule, 180 | 'constant_polynomial_decay' : get_constant_polynomial_decay_schedule, 181 | 'constant_polynomial_decay_with_warmup' : get_constant_polynomial_decay_schedule_with_warmup, 182 | 'polynomial_decay' : get_polynomial_decay_schedule, 183 | 'constants': get_constants_schedule, 184 | 185 | } 186 | 187 | 188 | 189 | def zo_lr_scheduler(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 190 | 191 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 192 | 193 | return schedule_func(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps) 194 | -------------------------------------------------------------------------------- /Hessian_smooth_scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import warnings 4 | from functools import partial 5 | from typing import Callable, Iterable, Optional, Tuple, Union 6 | 7 | import torch 8 | from torch import nn 9 | #from torch.optim import Optimizer 10 | #from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau 11 | 12 | from transformers.utils import logging 13 | from transformers.utils.versions import require_version 14 | 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | 20 | 21 | def get_reduce_on_plateau_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 22 | 23 | #to be updated 24 | return learning_rate 25 | 26 | def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None): 27 | if current_step < num_warmup_steps: 28 | return float(current_step) / float(max(1, num_warmup_steps)) 29 | shift = timescale - num_warmup_steps 30 | decay = 1.0 / math.sqrt((current_step + shift) / timescale) 31 | return decay 32 | 33 | 34 | def get_inverse_sqrt_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 35 | #to be updated 36 | return learning_rate 37 | 38 | def get_constant_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 39 | 40 | if current_step < num_warmup_steps: 41 | return float(current_step) / float(max(1.0, num_warmup_steps)) * learning_rate 42 | 43 | return learning_rate 44 | 45 | 46 | def get_linear_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 47 | 48 | 49 | if current_step < num_warmup_steps: 50 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 51 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) * learning_rate 52 | 53 | 54 | 55 | 56 | def get_cosine_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 57 | 58 | num_cycles=2 59 | if current_step < num_warmup_steps: 60 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 61 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - 
num_warmup_steps)) 62 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) * learning_rate 63 | 64 | 65 | 66 | def get_cosine_with_hard_restarts_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 67 | num_cycles=2 68 | if current_step < num_warmup_steps: 69 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 70 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 71 | if progress >= 1.0: 72 | return 0.0 * learning_rate 73 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) * learning_rate 74 | 75 | 76 | 77 | 78 | def get_polynomial_decay_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 79 | 80 | lr_init = learning_rate 81 | lr_end = 1e-10 82 | power = 3 83 | if current_step < num_warmup_steps: 84 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 85 | elif current_step > num_training_steps: 86 | return lr_end / lr_init * learning_rate # as LambdaLR multiplies by lr_init 87 | else: 88 | lr_range = lr_init - lr_end 89 | decay_steps = num_training_steps - num_warmup_steps 90 | pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps 91 | decay = lr_range * pct_remaining**power + lr_end 92 | return decay / lr_init * learning_rate # as LambdaLR multiplies by lr_init 93 | 94 | def get_polynomial_decay_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 95 | 96 | lr_init = learning_rate 97 | lr_end = 1e-10 98 | power = 3 99 | if current_step > num_training_steps: 100 | return lr_end / lr_init * learning_rate # as LambdaLR multiplies by lr_init 101 | else: 102 | lr_range = lr_init - lr_end 103 | decay_steps = num_training_steps - 0 104 | pct_remaining = 1 - (current_step - 0) / decay_steps 105 | decay = lr_range * pct_remaining**power + lr_end 106 | return decay / lr_init * learning_rate # as LambdaLR multiplies by lr_init 107 | 108 | def get_constant_polynomial_decay_schedule(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 109 | lr_init = learning_rate 110 | lr_end = 1e-10 111 | power = 3 112 | if current_step < num_decay_steps: 113 | return learning_rate 114 | elif current_step > num_training_steps: 115 | return lr_end / lr_init * learning_rate # as LambdaLR multiplies by lr_init 116 | else: 117 | lr_range = lr_init - lr_end 118 | decay_steps = num_training_steps - num_decay_steps 119 | pct_remaining = 1 - (current_step - num_decay_steps) / decay_steps 120 | decay = lr_range * pct_remaining**power + lr_end 121 | return decay / lr_init * learning_rate / 100 # as LambdaLR multiplies by lr_init 122 | 123 | 124 | 125 | def get_constant_decay1_schedule(current_step, num_training_steps): 126 | if current_step < 9800: 127 | return 1e-6 128 | else: 129 | return 1e-8 130 | 131 | 132 | def get_constant6_schedule(current_step, num_training_steps): 133 | 134 | return 1e-6 135 | 136 | def get_constant8_schedule(current_step, num_training_steps): 137 | 138 | return 1e-8 139 | 140 | def get_constant10_schedule(current_step, num_training_steps): 141 | 142 | return 1e-10 143 | 144 | def get_constant12_schedule(current_step, num_training_steps): 145 | 146 | return 1e-12 147 | 148 | def get_constant4_schedule(current_step, num_training_steps): 149 | 150 | return 1e-4 151 | 152 | def 
get_constant2_schedule(current_step, num_training_steps): 153 | 154 | return 1e-2 155 | 156 | def get_constant0_schedule(current_step, num_training_steps): 157 | 158 | return 0 159 | 160 | def get_constant_schedule_with_warmup(current_step, num_training_steps): 161 | 162 | if current_step < num_warmup_steps: 163 | return float(current_step) / float(max(1.0, num_warmup_steps)) * 1e-6 164 | 165 | return 0 166 | 167 | def get_constant_polynomial_decay_schedule_with_warmup(learning_rate, name, num_warmup_steps, num_decay_steps, current_step, num_training_steps): 168 | lr_init = learning_rate 169 | lr_end = 1e-7 170 | power = 3 171 | if current_step < num_warmup_steps: 172 | return float(current_step) / float(max(1, num_warmup_steps)) * learning_rate 173 | elif current_step >= num_warmup_steps and current_step < num_decay_steps: 174 | return learning_rate 175 | elif current_step > num_training_steps: 176 | return lr_end / lr_init * learning_rate # as LambdaLR multiplies by lr_init 177 | else: 178 | lr_range = lr_init - lr_end 179 | decay_steps = num_training_steps - num_decay_steps 180 | pct_remaining = 1 - (current_step - num_decay_steps) / decay_steps 181 | decay = lr_range * pct_remaining**power + lr_end 182 | return decay / lr_init * learning_rate # as LambdaLR multiplies by lr_init 183 | 184 | 185 | 186 | 187 | TYPE_TO_SCHEDULER_FUNCTION = { 188 | 'linear_with_warmup': get_linear_schedule_with_warmup, 189 | 'cosine_with_warmup': get_cosine_schedule_with_warmup, 190 | 'cosine_with_hard_restarts_with_warmup': get_cosine_with_hard_restarts_schedule_with_warmup, 191 | 'polynomial_decay_with_warmup': get_polynomial_decay_schedule_with_warmup, 192 | 'inverse_sqrt': get_inverse_sqrt_schedule, 193 | 'reduce_on_plateau': get_reduce_on_plateau_schedule, 194 | 'constant_polynomial_decay' : get_constant_polynomial_decay_schedule, 195 | 'constant_polynomial_decay_with_warmup' : get_constant_polynomial_decay_schedule_with_warmup, 196 | 'polynomial_decay' : get_polynomial_decay_schedule, 197 | 'constant0': get_constant0_schedule, 198 | 'constant1e-6': get_constant6_schedule, 199 | 'constant1e-8': get_constant8_schedule, 200 | 'constant1e-10': get_constant10_schedule, 201 | 'constant1e-12': get_constant12_schedule, 202 | 'constant1e-2': get_constant2_schedule, 203 | 'constant1e-4': get_constant4_schedule, 204 | 'constant_with_warmup': get_constant_schedule_with_warmup, 205 | 'constant_decay1': get_constant_decay1_schedule, 206 | 207 | } 208 | 209 | def Hessian_smooth_scheduler(Hessian_smooth_type, current_step, num_training_steps): 210 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[Hessian_smooth_type] 211 | 212 | return schedule_func(current_step, num_training_steps) -------------------------------------------------------------------------------- /prefix.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 4 | logger = logging.getLogger(__name__) 5 | logger.setLevel(logging.INFO) 6 | 7 | import torch 8 | from torch import nn 9 | 10 | def find_module(root_module: nn.Module, key: str): 11 | """ 12 | Find a module with a specific name in a Transformer model 13 | From OpenDelta https://github.com/thunlp/OpenDelta 14 | """ 15 | sub_keys = key.split(".") 16 | parent_module = root_module 17 | for sub_key in sub_keys[:-1]: 18 | parent_module = getattr(parent_module, sub_key) 19 | module = getattr(parent_module, sub_keys[-1]) 20 | return parent_module, sub_keys[-1], 
module 21 | 22 | 23 | def attn_forward_hook(self, *args, **kwargs): 24 | """ 25 | Replace the original attention forward with this to enable prefix 26 | """ 27 | 28 | def _expand_bsz(x, bsz): 29 | x = x.reshape(x.size(0), self.num_heads, -1).transpose(0,1) # (num_prefix, hidden) -> (num_head, num_prefix, hidden/num_head) 30 | x = x.unsqueeze(0).expand(bsz, *x.shape) # -> (bsz, num_head, num_prefix, hidden/num_head) 31 | return x 32 | 33 | if "hidden_states" in kwargs: 34 | hidden_states = kwargs["hidden_states"] 35 | else: 36 | hidden_states = args[0] 37 | bsz = hidden_states.size(0) 38 | 39 | if 'past_key_value' not in kwargs or kwargs['past_key_value'] is None: 40 | if self.reparam: 41 | prefix_keys = self.prefix_mlp_keys(self.prefix_input_embeds) 42 | prefix_values = self.prefix_mlp_values(self.prefix_input_embeds) 43 | else: 44 | prefix_keys, prefix_values = self.prefix_keys, self.prefix_values 45 | kwargs['past_key_value'] = (_expand_bsz(prefix_keys, bsz), _expand_bsz(prefix_values, bsz)) 46 | 47 | if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None: 48 | am = kwargs['attention_mask'] 49 | kwargs['attention_mask'] = torch.cat([-torch.zeros((*am.shape[:-1], self.num_prefix), dtype=am.dtype, device=am.device), am], dim=-1) 50 | elif len(args) > 1: # attention mask is passed via positional argument 51 | am = args[1] 52 | am = torch.cat([-torch.zeros((*am.shape[:-1], self.num_prefix), dtype=am.dtype, device=am.device), am], dim=-1) 53 | args = (args[0], am) + args[2:] 54 | 55 | return self.original_forward(*args, **kwargs) 56 | 57 | 58 | def prepare_inputs_for_generation( 59 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs): 60 | """ 61 | Replace the original "prepare_inputs_for_generation" with this to pass prefix correctly 62 | """ 63 | original_input_len = input_ids.size(-1) 64 | if past_key_values: 65 | input_ids = input_ids[:, -1:] 66 | 67 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 68 | if inputs_embeds is not None and past_key_values is None: 69 | model_inputs = {"inputs_embeds": inputs_embeds} 70 | else: 71 | model_inputs = {"input_ids": input_ids} 72 | 73 | if past_key_values is not None: 74 | # Check if we should add extra to attention mask 75 | if past_key_values[0][0].size(2) != attention_mask.size(1) - 1: 76 | num_prefix = past_key_values[0][0].size(2) - (attention_mask.size(1) - 1) 77 | attention_mask = torch.cat([torch.ones((attention_mask.size(0), num_prefix), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask], dim=-1) 78 | 79 | model_inputs.update( 80 | { 81 | "past_key_values": past_key_values, 82 | "use_cache": kwargs.get("use_cache"), 83 | "attention_mask": attention_mask, 84 | } 85 | ) 86 | return model_inputs 87 | 88 | 89 | class PrefixTuning: 90 | 91 | def __init__(self, model, num_prefix, reparam=True, embed_dim=512, mid_dim=512, float16=False, init_by_real_act=False): 92 | """ 93 | Inputs: 94 | num_prefix: number of prefix tokens 95 | reparam: use reparameterization trick (not used in MeZO) 96 | embed_dim, mid_dim: hyperparameters for reparameterization trick (not used in MeZO) 97 | float15: whether the model parameters are float15 98 | init_by_real_act: init prefix tokens by real activations 99 | """ 100 | 101 | self.model = model 102 | self.num_prefix = num_prefix 103 | self.hidden_dim = model.config.hidden_size 104 | self.float16 = float16 105 | 106 | # Reparameterization 107 | self.reparam = reparam 108 | self.embed_dim = embed_dim 
109 | self.mid_dim = mid_dim 110 | 111 | input_embeds = None # For reparameterization 112 | if model.config.model_type == "opt": 113 | attention_name = "attn" 114 | first_layer_name = "layers.0" 115 | layer_name = "layers." 116 | elif model.config.model_type == "roberta": 117 | attention_name = "attention" 118 | first_layer_name = "layer.0" 119 | layer_name = "layer." 120 | else: 121 | raise NotImplementedError 122 | 123 | if init_by_real_act: 124 | # Initialize prefix with real words' activations 125 | assert not reparam 126 | 127 | # Randomly sample input tokens 128 | input_tokens = torch.randint(low=0, high=model.config.vocab_size, size=(1, num_prefix), dtype=torch.long).cuda() 129 | if model.config.model_type == "opt": 130 | with torch.no_grad(): 131 | # Get the real activations 132 | real_key_values = model(input_ids=input_tokens, use_cache=True).past_key_values 133 | else: 134 | raise NotImplementedError 135 | 136 | # Insert prefix 137 | for key, _ in model.named_modules(): 138 | if key[-len(attention_name):] == attention_name: 139 | layer_id = int(key.split(layer_name)[1].split(".")[0]) 140 | logger.info(f"Inject prefix to: {key}") 141 | _, _, attn = find_module(model, key) 142 | 143 | # Replace the old forward functions 144 | attn.original_forward = attn.forward 145 | attn.forward = attn_forward_hook.__get__(attn, type(attn)) 146 | if not hasattr(attn, "num_heads"): 147 | attn.num_heads = model.config.num_attention_heads 148 | first = first_layer_name in key 149 | self.add_prefix(attn, first=first, input_embeds=input_embeds) 150 | 151 | if first and self.reparam: 152 | input_embeds = attn.prefix_input_embeds 153 | if init_by_real_act: 154 | logger.info(f"Reinitialize with actual activation: {key} (layer {layer_id})") 155 | keys = real_key_values[layer_id][0].squeeze(0).transpose(0, 1).reshape(num_prefix, -1) 156 | values = real_key_values[layer_id][1].squeeze(0).transpose(0, 1).reshape(num_prefix, -1) 157 | attn.prefix_keys.data = keys.to(attn.prefix_keys.data.device) 158 | attn.prefix_values.data = values.to(attn.prefix_values.data.device) 159 | 160 | # Freeze non-prefix parameters 161 | for n, p in model.named_parameters(): 162 | if "prefix" not in n: 163 | p.requires_grad = False 164 | 165 | # Replace the old prepare_inputs_for_generation function 166 | model.prepare_inputs_for_generation = prepare_inputs_for_generation.__get__(model, type(model)) 167 | 168 | 169 | def add_prefix(self, module, first, input_embeds=None): 170 | device = module.k_proj.weight.data.device 171 | module.num_prefix = self.num_prefix 172 | module.reparam = self.reparam 173 | if self.reparam: 174 | if first: 175 | # For the first layer we inject the embeddings 176 | logger.info("For prefix+reparameterization, inject the embeddings in the first layer.") 177 | module.prefix_input_embeds = nn.Parameter(torch.randn(self.num_prefix, self.embed_dim, device=device, dtype=self.model.dtype), requires_grad=True) 178 | else: 179 | assert input_embeds is not None 180 | module.prefix_input_embeds = input_embeds 181 | module.prefix_mlp_keys = nn.Sequential( 182 | nn.Linear(self.embed_dim, self.mid_dim), 183 | nn.Tanh(), 184 | nn.Linear(self.mid_dim, self.hidden_dim) 185 | ).to(device) 186 | module.prefix_mlp_values = nn.Sequential( 187 | nn.Linear(self.embed_dim, self.mid_dim), 188 | nn.Tanh(), 189 | nn.Linear(self.mid_dim, self.hidden_dim) 190 | ).to(device) 191 | if self.float16: 192 | module.prefix_mlp_keys = module.prefix_mlp_keys.half() 193 | module.prefix_mlp_values = module.prefix_mlp_values.half() 194 | else: 
195 | module.prefix_keys = nn.Parameter(torch.randn(self.num_prefix, self.hidden_dim, device=device, dtype=self.model.dtype), requires_grad=True) 196 | module.prefix_values = nn.Parameter(torch.randn(self.num_prefix, self.hidden_dim, device=device, dtype=self.model.dtype), requires_grad=True) 197 | -------------------------------------------------------------------------------- /exp.py: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.utils.data import Dataset, DataLoader 8 | import numpy as np 9 | from scipy.signal import hilbert 10 | from scipy.fft import fft, ifft 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | # Signal parameters 15 | dt = 0.002 16 | Nt = 500 17 | time = np.arange(Nt) * dt 18 | 19 | # Frequency axis 20 | df = 0.5 21 | Nf = 201 22 | fre_focus = np.arange(Nf) * df 23 | 24 | # Window size 25 | sigma_t = 0.03 26 | 27 | # Function to compute WFT 28 | def WFT_wxk(s, dt, fre_focus, sigma_t): 29 | time = np.arange(len(s)) * dt 30 | s = hilbert(s) 31 | s_fre = fft(s) 32 | 33 | alpha = 1 / (2 * sigma_t ** 2) 34 | Nt = len(s) 35 | fre = np.concatenate((np.arange(0, Nt // 2), np.arange(-Nt // 2, 0))) / Nt / dt * 2 * np.pi 36 | 37 | WFT_G = np.zeros((len(fre_focus), Nt), dtype=complex) 38 | 39 | for mf, f in enumerate(fre_focus): 40 | G = np.exp(-(fre - f * 2 * np.pi) ** 2 / (4 * alpha)) 41 | G = G / np.sqrt(alpha / np.pi) 42 | WFT_G[mf, :] = ifft(s_fre * G) 43 | 44 | return WFT_G 45 | 46 | def WFTI_wxk(WFT_G, dt, fre_focus, sigma_t): 47 | Nf, Nt = WFT_G.shape 48 | alpha = 1 / (2 * sigma_t**2) 49 | time = np.arange(Nt) * dt 50 | WFTI_s = np.zeros(Nt, dtype=complex) 51 | for mf in range(Nf): 52 | G = np.sqrt(np.pi / alpha) * np.exp(-(time - fre_focus[mf] / (2 * np.pi))**2 / (4 * alpha)) 53 | WFTI_s += fft(WFT_G[mf, :]) * G 54 | WFTI_s = ifft(WFTI_s) 55 | return WFTI_s 56 | 57 | 58 | 59 | # Generate data and labels for 10 samples 60 | num_samples = 1500 61 | all_data_WFT = [] 62 | all_signal_WFT = [] 63 | 64 | for _ in range(num_samples): 65 | signal = np.cos(2 * np.pi * 25 * time) + np.cos(2 * np.pi * 50 * time) 66 | data = signal + 0.1 * np.random.randn(Nt) 67 | signal_WFT = WFT_wxk(signal, dt, fre_focus, sigma_t) 68 | data_WFT = WFT_wxk(data, dt, fre_focus, sigma_t) 69 | all_data_WFT.append(data_WFT) 70 | all_signal_WFT.append(signal_WFT) 71 | 72 | # Transform data into the correct format for CNN input 73 | data_real = np.real(all_data_WFT).reshape(num_samples, 1, Nf, Nt) 74 | data_imag = np.imag(all_data_WFT).reshape(num_samples, 1, Nf, Nt) 75 | labels_real = np.real(all_signal_WFT).reshape(num_samples, 1, Nf, Nt) 76 | labels_imag = np.imag(all_signal_WFT).reshape(num_samples, 1, Nf, Nt) 77 | 78 | class ComplexDataset(Dataset): 79 | def __init__(self, data_real, data_imag, labels_real, labels_imag): 80 | self.data_real = torch.from_numpy(data_real).float() 81 | self.data_imag = torch.from_numpy(data_imag).float() 82 | self.labels_real = torch.from_numpy(labels_real).float() 83 | self.labels_imag = torch.from_numpy(labels_imag).float() 84 | 85 | def __len__(self): 86 | return self.data_real.size(0) 87 | 88 | def __getitem__(self, idx): 89 | return (self.data_real[idx], self.data_imag[idx]), (self.labels_real[idx], self.labels_imag[idx]) 90 | 91 | # Create dataset 92 | dataset = ComplexDataset(data_real, data_imag, labels_real, labels_imag) 93 | train_loader = DataLoader(dataset, 
batch_size=2, shuffle=True) 94 | 95 | # Define complex convolution layer 96 | class ComplexConv2d(nn.Module): 97 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): 98 | super(ComplexConv2d, self).__init__() 99 | self.real_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 100 | self.imag_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 101 | 102 | def forward(self, x): 103 | real, imag = x 104 | return self.real_conv(real) - self.imag_conv(imag), self.imag_conv(real) + self.real_conv(imag) 105 | 106 | # Define complex convolution network 五层 107 | class ComplexConvNet(nn.Module): 108 | def __init__(self): 109 | super(ComplexConvNet, self).__init__() 110 | self.conv1 = ComplexConv2d(1, 16, 3, padding=1) 111 | self.conv2 = ComplexConv2d(16, 32, 3, padding=1) 112 | self.conv3 = ComplexConv2d(32, 64, 3, padding=1) 113 | self.conv4 = ComplexConv2d(64, 64, 3, padding=1) 114 | self.conv5 = ComplexConv2d(64, 1, 3, padding=1) 115 | self.dropout = nn.Dropout(0.5) 116 | 117 | def forward(self, x): 118 | real, imag = self.conv1(x) 119 | real, imag = F.relu(real), F.relu(imag) 120 | real, imag = self.conv2((real, imag)) 121 | real, imag = F.relu(real), F.relu(imag) 122 | real, imag = self.conv3((real, imag)) 123 | real, imag = F.relu(real), F.relu(imag) 124 | real, imag = self.dropout(real), self.dropout(imag) 125 | real, imag = self.conv4((real, imag)) 126 | real, imag = F.relu(real), F.relu(imag) 127 | real, imag = self.conv5((real, imag)) 128 | return real, imag 129 | 130 | # Initialize network and optimizer 131 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 132 | model = ComplexConvNet().to(device) 133 | criterion = nn.MSELoss() 134 | optimizer = optim.Adam(model.parameters(), lr=0.0001) 135 | 136 | # Train the network 137 | losses = [] 138 | num_epochs = 200 139 | for epoch in range(num_epochs): 140 | total_loss = 0 141 | for (noisy_real, noisy_imag), (clean_real, clean_imag) in train_loader: 142 | noisy_real, noisy_imag = noisy_real.to(device), noisy_imag.to(device) 143 | clean_real, clean_imag = clean_real.to(device), clean_imag.to(device) 144 | optimizer.zero_grad() 145 | output_real, output_imag = model((noisy_real, noisy_imag)) 146 | loss = criterion(output_real, clean_real) + criterion(output_imag, clean_imag) 147 | loss.backward() 148 | optimizer.step() 149 | total_loss += loss.item() 150 | avg_loss = total_loss / len(train_loader) 151 | losses.append(avg_loss) 152 | print(f'Epoch {epoch+1}, Loss: {avg_loss}') 153 | 154 | # Plot the training loss 155 | plt.figure(figsize=(10, 5)) 156 | plt.plot(losses, label='Training Loss') 157 | plt.xlabel('Epoch') 158 | plt.ylabel('Loss') 159 | plt.title('Loss During Training') 160 | plt.legend() 161 | plt.show() 162 | 163 | # Function to calculate RMSE 164 | def rmse(predictions, targets): 165 | return torch.sqrt(((predictions - targets) ** 2).mean()) 166 | 167 | 168 | # 生成测试信号及其WFT 169 | test_time = np.arange(Nt) * dt 170 | clean_signal = np.cos(2 * np.pi * 25 * test_time) + np.cos(2 * np.pi * 50 * test_time) 171 | clean_signal_WFT = WFT_wxk(clean_signal, dt, fre_focus, sigma_t) 172 | test_data = clean_signal+ 0.1 * np.random.randn(Nt) 173 | test_data_WFT = WFT_wxk(test_data, dt, fre_focus, sigma_t) 174 | test_data_real = np.real(test_data_WFT).reshape(1, 1, Nf, Nt) 175 | test_data_imag = np.imag(test_data_WFT).reshape(1, 1, Nf, Nt) 176 | test_data_real = torch.from_numpy(test_data_real).float().to(device) 177 | test_data_imag = 
torch.from_numpy(test_data_imag).float().to(device) 178 | clean_labels_real = np.real(clean_signal_WFT).reshape(1, 1, Nf, Nt) 179 | clean_labels_imag = np.imag(clean_signal_WFT).reshape(1, 1, Nf, Nt) 180 | clean_labels_real = torch.from_numpy(clean_labels_real).float().to(device) 181 | clean_labels_imag = torch.from_numpy(clean_labels_imag).float().to(device) 182 | 183 | 184 | # Model evaluation 185 | model.eval() 186 | with torch.no_grad(): 187 | output_real, output_imag = model((test_data_real, test_data_imag)) 188 | 189 | # Calculate RMSE for the real and imaginary parts 190 | rmse_real = rmse(output_real, clean_labels_real) 191 | rmse_imag = rmse(output_imag, clean_labels_imag) 192 | print(f'Test RMSE (Real part): {rmse_real.item()}') 193 | print(f'Test RMSE (Imaginary part): {rmse_imag.item()}') 194 | 195 | 196 | output_magnitude = np.abs(output_real.cpu().numpy().squeeze() + 1j * output_imag.cpu().numpy().squeeze()) 197 | clean_magnitude = np.abs(clean_labels_real.cpu().numpy().squeeze() + 1j * clean_labels_imag.cpu().numpy().squeeze()) 198 | 199 | # 时间轴的刻度和标签 200 | time_ticks = np.linspace(0, Nt-1, num=5) # 创建5个刻度点 201 | time_labels = np.linspace(0, 1, num=5) # 创建对应0到1之间的5个标签 202 | 203 | # 频率轴的刻度和标签 204 | # 此处我们需要从fre_focus数组中找到25Hz和50Hz对应的索引 205 | frequency_indices = [np.argmin(np.abs(fre_focus - frequency)) for frequency in [25, 50]] 206 | frequency_labels = ['25 Hz', '50 Hz'] # 创建对应的标签 207 | 208 | # 创建画布和子图 209 | fig, axs = plt.subplots(1, 3, figsize=(15, 15)) 210 | # 设定jet色图 211 | #cmap = plt.get_cmap('jet') 212 | # 含噪信号的时频图 213 | axs[0].imshow(np.abs(test_data_WFT), extent=[time[0], time[-1], fre_focus[0], fre_focus[-1]], aspect='auto', origin='lower', cmap='viridis') 214 | axs[0].set_title('Test Noisy Signal (Time-Frequency)') 215 | axs[0].set_xlabel('Time [s]') 216 | axs[0].set_ylabel('Frequency [Hz]') 217 | 218 | 219 | # 测试输出信号的时频图 220 | axs[1].imshow(output_magnitude, extent=[time[0], time[-1], fre_focus[0], fre_focus[-1]], aspect='auto', origin='lower',cmap='viridis') 221 | axs[1].set_title('Test Output Signal (Time-Frequency)') 222 | axs[1].set_xlabel('Time [s]') 223 | axs[1].set_ylabel('Frequency [Hz]') 224 | 225 | # 干净信号的时频图 226 | axs[2].imshow(clean_magnitude, extent=[time[0], time[-1], fre_focus[0], fre_focus[-1]], aspect='auto', origin='lower', cmap='viridis') 227 | axs[2].set_title('Test Clean Signal (Time-Frequency)') 228 | axs[2].set_xlabel('Time [s]') 229 | axs[2].set_ylabel('Frequency [Hz]') 230 | 231 | plt.tight_layout() 232 | plt.show() 233 | 234 | 235 | # 在模型评估后计算测试输出与测试输入的复数差值 236 | diff_input = (output_real - test_data_real) + 1j * (output_imag - test_data_imag) 237 | diff_input_magnitude = np.abs(diff_input.cpu().numpy().squeeze()) 238 | 239 | # 计算测试输出与干净信号的复数差值 240 | diff_clean = (output_real - clean_labels_real) + 1j * (output_imag - clean_labels_imag) 241 | diff_clean_magnitude = np.abs(diff_clean.cpu().numpy().squeeze()) 242 | 243 | # 计算干净信号的模值均值和两个差值的模值均值 244 | clean_signal_magnitude_mean = np.mean(np.abs(clean_labels_real.cpu().numpy().squeeze() + 1j * clean_labels_imag.cpu().numpy().squeeze())) 245 | diff_input_magnitude_mean = np.mean(diff_input_magnitude) 246 | diff_clean_magnitude_mean = np.mean(diff_clean_magnitude) 247 | 248 | print(f'Clean Signal Magnitude Mean: {clean_signal_magnitude_mean}') 249 | print(f'Test Output vs Test Input Difference Magnitude Mean: {diff_input_magnitude_mean}') 250 | print(f'Test Output vs Clean Signal Difference Magnitude Mean: {diff_clean_magnitude_mean}') 251 | 252 | 253 | 254 | # 创建画布和子图 255 | fig, axs = 
plt.subplots(1, 2, figsize=(15, 5)) # one row, two columns 256 | 257 | # Difference image (test output vs test input) 258 | axs[0].imshow(diff_input_magnitude, extent=[time[0], time[-1], fre_focus[0], fre_focus[-1]], aspect='auto', origin='lower', cmap='viridis') 259 | axs[0].set_title('Difference (Test Output vs Test Input)') 260 | axs[0].set_xlabel('Time [s]') 261 | axs[0].set_ylabel('Frequency [Hz]') 262 | 263 | # Difference image (test output vs clean signal) 264 | axs[1].imshow(diff_clean_magnitude, extent=[time[0], time[-1], fre_focus[0], fre_focus[-1]], aspect='auto', origin='lower', cmap='viridis') 265 | axs[1].set_title('Difference (Test Output vs Clean Signal)') 266 | axs[1].set_xlabel('Time [s]') 267 | axs[1].set_ylabel('Frequency [Hz]') 268 | 269 | plt.tight_layout() 270 | plt.show() 271 | 272 | 273 | 274 | # Make sure output_real and output_imag are two-dimensional 275 | output_real_2d = output_real.cpu().numpy().squeeze() 276 | output_imag_2d = output_imag.cpu().numpy().squeeze() 277 | 278 | # Use the inverse windowed Fourier transform to bring the test input, test output and clean signal from the frequency domain back to the time domain 279 | test_input_WFTI = WFTI_wxk(test_data_real.cpu().numpy().squeeze() + 1j * test_data_imag.cpu().numpy().squeeze(), dt, fre_focus, sigma_t) 280 | test_output_WFTI = WFTI_wxk(output_real_2d + 1j * output_imag_2d, dt, fre_focus, sigma_t) 281 | clean_signal_WFTI = WFTI_wxk(clean_labels_real.cpu().numpy().squeeze() + 1j * clean_labels_imag.cpu().numpy().squeeze(), dt, fre_focus, sigma_t) 282 | 283 | # Take the real part for plotting 284 | test_input_WFTI_real = np.real(test_input_WFTI) 285 | test_output_WFTI_real = np.real(test_output_WFTI) 286 | clean_signal_WFTI_real = np.real(clean_signal_WFTI) 287 | 288 | # Plot the time-domain signals 289 | plt.figure(figsize=(15, 5)) 290 | 291 | # Inverse windowed Fourier transform of the test input signal 292 | plt.subplot(1, 3, 1) 293 | plt.plot(time, test_input_WFTI_real, label='Test Input WFTI') 294 | plt.title('Test Input Signal in Time Domain') 295 | plt.xlabel('Time (s)') 296 | plt.ylabel('Amplitude') 297 | plt.legend() 298 | 299 | # Inverse windowed Fourier transform of the test output signal 300 | plt.subplot(1, 3, 2) 301 | plt.plot(time, test_output_WFTI_real, label='Test Output WFTI') 302 | plt.title('Test Output Signal in Time Domain') 303 | plt.xlabel('Time (s)') 304 | plt.ylabel('Amplitude') 305 | plt.legend() 306 | 307 | # Inverse windowed Fourier transform of the clean signal 308 | plt.subplot(1, 3, 3) 309 | plt.plot(time, clean_signal_WFTI_real, label='Clean Signal WFTI') 310 | plt.title('Clean Signal in Time Domain') 311 | plt.xlabel('Time (s)') 312 | plt.ylabel('Amplitude') 313 | plt.legend() 314 | 315 | plt.tight_layout() 316 | plt.show() 317 | 318 | 319 | 320 | # Compute the time-domain noise as the difference between the test input signal and the clean signal 321 | noise_signal_WFTI = test_input_WFTI_real - clean_signal_WFTI_real 322 | 323 | # Compute the signal-to-noise ratio of the clean data with respect to this noise 324 | signal_power = np.mean(clean_signal_WFTI_real ** 2) 325 | noise_power = np.mean(noise_signal_WFTI ** 2) 326 | snr = 10 * np.log10(signal_power / noise_power) 327 | 328 | # Print the signal-to-noise ratio 329 | print(f'Signal-to-Noise Ratio (SNR): {snr:.2f} dB') 330 | 331 | # Visualize the noise data 332 | plt.figure(figsize=(15, 5)) 333 | plt.subplot(1, 3, 1) 334 | plt.plot(time, noise_signal_WFTI, label='Noise Signal WFTI') 335 | plt.title('Noise Signal in Time Domain') 336 | plt.xlabel('Time (s)') 337 | plt.ylabel('Amplitude') 338 | plt.legend() 339 | 340 | plt.tight_layout() 341 | plt.show() -------------------------------------------------------------------------------- /templates.py: -------------------------------------------------------------------------------- 1 | class Template: 2 | def encode(self, sample): 3 | """ 4 | Return the prompted version of the example (without the answer/candidate) 5 | """ 6 | raise NotImplementedError 7 | 8 | def verbalize(self, sample, candidate): 9 | """ 10 | Return the prompted version 
of the example (with the answer/candidate) 11 | """ 12 | return candidate 13 | 14 | def encode_sfc(self, sample): 15 | """ 16 | Same as encode, but for SFC (calibration) -- this usually means the input is not included 17 | """ 18 | return "" 19 | 20 | def verbalize_sfc(self, sample, candidate): 21 | """ 22 | Same as verbalize, but for SFC (calibration) -- this usually means the input is not included 23 | """ 24 | return candidate 25 | 26 | 27 | class SST2Template(Template): 28 | verbalizer = {0: "terrible", 1: "great"} 29 | def encode(self, sample): 30 | text = sample.data["sentence"].strip() 31 | return f"{text} It was" 32 | 33 | def verbalize(self, sample, candidate): 34 | text = sample.data["sentence"].strip() 35 | return f"{text} It was {self.verbalizer[candidate]}" 36 | 37 | def encode_sfc(self, sample): 38 | return f" It was" 39 | 40 | def verbalize_sfc(self, sample, candidate): 41 | return f" It was {self.verbalizer[candidate]}" 42 | 43 | 44 | class CopaTemplate(Template): 45 | capitalization: str = "correct" 46 | effect_conj: str = " so " 47 | cause_conj: str = " because " 48 | 49 | def get_conjucture(self, sample): 50 | if sample.data["question"] == "effect": 51 | conjunction = self.effect_conj 52 | elif sample.data["question"] == "cause": 53 | conjunction = self.cause_conj 54 | else: 55 | raise NotImplementedError 56 | return conjunction 57 | 58 | def get_prompt(self, sample): 59 | premise = sample.data["premise"].rstrip() 60 | if premise.endswith("."): # TODO Add other scripts with different punctuation 61 | premise = premise[:-1] 62 | conjunction = self.get_conjucture(sample) 63 | prompt = premise + conjunction 64 | if self.capitalization == "upper": 65 | prompt = prompt.upper() 66 | elif self.capitalization == "lower": 67 | prompt = prompt.lower() 68 | return prompt 69 | 70 | def encode(self, sample): 71 | prompt = self.get_prompt(sample) 72 | return prompt 73 | 74 | def capitalize(self, c): 75 | if self.capitalization == "correct": 76 | words = c.split(" ") 77 | if words[0] != "I": 78 | words[0] = words[0].lower() 79 | return " ".join(words) 80 | elif self.capitalization == "bug": 81 | return c 82 | elif self.capitalization == "upper": 83 | return c.upper() 84 | elif self.capitalization == "lower": 85 | return c.lower() 86 | else: 87 | raise NotImplementedError 88 | 89 | def verbalize(self, sample, candidate): 90 | prompt = self.get_prompt(sample) 91 | return prompt + self.capitalize(candidate) 92 | 93 | def encode_sfc(self, sample): 94 | conjunction = self.get_conjucture(sample) 95 | return conjunction.strip() 96 | 97 | def verbalize_sfc(self, sample, candidate): 98 | conjunction = self.get_conjucture(sample) 99 | sfc_prompt = conjunction.strip() + " " + self.capitalize(candidate) 100 | return sfc_prompt 101 | 102 | 103 | class BoolQTemplate(Template): 104 | def encode(self, sample): 105 | passage = sample.data["passage"] 106 | question = sample.data["question"] 107 | if not question.endswith("?"): 108 | question = question + "?" 109 | question = question[0].upper() + question[1:] 110 | return f"{passage} {question}" 111 | 112 | def verbalize(self, sample, candidate): 113 | passage = sample.data["passage"] 114 | question = sample.data["question"] 115 | if not question.endswith("?"): 116 | question = question + "?" 
117 | question = question[0].upper() + question[1:] 118 | return f"{passage} {question} {candidate}" 119 | 120 | def encode_sfc(self, sample): 121 | return "" 122 | 123 | def verbalize_sfc(self, sample, candidate): 124 | return candidate 125 | 126 | 127 | class BoolQTemplateV2(Template): 128 | def encode(self, sample): 129 | passage = sample.data["passage"] 130 | question = sample.data["question"] 131 | if not question.endswith("?"): 132 | question = question + "?" 133 | question = question[0].upper() + question[1:] 134 | return f"{passage} {question}\\n\\n" 135 | 136 | def verbalize(self, sample, candidate): 137 | passage = sample.data["passage"] 138 | question = sample.data["question"] 139 | if not question.endswith("?"): 140 | question = question + "?" 141 | question = question[0].upper() + question[1:] 142 | return f"{passage} {question}\\n\\n{candidate}" 143 | 144 | def encode_sfc(self, sample): 145 | return "" 146 | 147 | def verbalize_sfc(self, sample, candidate): 148 | return candidate 149 | 150 | 151 | class BoolQTemplateV3(Template): 152 | def encode(self, sample): 153 | passage = sample.data["passage"] 154 | question = sample.data["question"] 155 | if not question.endswith("?"): 156 | question = question + "?" 157 | question = question[0].upper() + question[1:] 158 | return f"{passage} {question}\n" 159 | 160 | def verbalize(self, sample, candidate): 161 | passage = sample.data["passage"] 162 | question = sample.data["question"] 163 | if not question.endswith("?"): 164 | question = question + "?" 165 | question = question[0].upper() + question[1:] 166 | return f"{passage} {question}\n{candidate}" 167 | 168 | def encode_sfc(self, sample): 169 | return "" 170 | 171 | def verbalize_sfc(self, sample, candidate): 172 | return candidate 173 | 174 | 175 | class MultiRCTemplate(Template): 176 | # From PromptSource 1 177 | verbalizer = {0: "No", 1: "Yes"} 178 | 179 | def encode(self, sample): 180 | paragraph = sample.data["paragraph"] 181 | question = sample.data["question"] 182 | answer = sample.data["answer"] 183 | return f"{paragraph}\nQuestion: {question}\nI found this answer \"{answer}\". Is that correct? Yes or No?\n" 184 | 185 | def verbalize(self, sample, candidate): 186 | paragraph = sample.data["paragraph"] 187 | question = sample.data["question"] 188 | answer = sample.data["answer"] 189 | return f"{paragraph}\nQuestion: {question}\nI found this answer \"{answer}\". Is that correct? Yes or No?\n{self.verbalizer[candidate]}" 190 | 191 | def encode_sfc(self, sample): 192 | return f"" 193 | 194 | def verbalize_sfc(self, sample, candidate): 195 | return f"{self.verbalizer[candidate]}" 196 | 197 | 198 | class CBTemplate(Template): 199 | # From PromptSource 1 200 | verbalizer = {0: "Yes", 1: "No", 2: "Maybe"} 201 | 202 | def encode(self, sample): 203 | premise = sample.data["premise"] 204 | hypothesis = sample.data["hypothesis"] 205 | return f"Suppose {premise} Can we infer that \"{hypothesis}\"? Yes, No, or Maybe?\n" 206 | 207 | def verbalize(self, sample, candidate): 208 | premise = sample.data["premise"] 209 | hypothesis = sample.data["hypothesis"] 210 | return f"Suppose {premise} Can we infer that \"{hypothesis}\"? 
Yes, No, or Maybe?\n{self.verbalizer[candidate]}" 211 | 212 | def encode_sfc(self, sample): 213 | return f"" 214 | 215 | def verbalize_sfc(self, sample, candidate): 216 | return f"{self.verbalizer[candidate]}" 217 | 218 | 219 | class WICTemplate(Template): 220 | # From PromptSource 1 221 | verbalizer = {0: "No", 1: "Yes"} 222 | 223 | def encode(self, sample): 224 | sent1 = sample.data["sentence1"] 225 | sent2 = sample.data["sentence2"] 226 | word = sample.data["word"] 227 | return f"Does the word \"{word}\" have the same meaning in these two sentences? Yes, No?\n{sent1}\n{sent2}\n" 228 | 229 | def verbalize(self, sample, candidate): 230 | sent1 = sample.data["sentence1"] 231 | sent2 = sample.data["sentence2"] 232 | word = sample.data["word"] 233 | return f"Does the word \"{word}\" have the same meaning in these two sentences? Yes, No?\n{sent1}\n{sent2}\n{self.verbalizer[candidate]}" 234 | 235 | def encode_sfc(self, sample): 236 | return f"" 237 | 238 | def verbalize_sfc(self, sample, candidate): 239 | return f"{self.verbalizer[candidate]}" 240 | 241 | 242 | class WSCTemplate(Template): 243 | # From PromptSource 1 244 | verbalizer = {0: "No", 1: "Yes"} 245 | 246 | def encode(self, sample): 247 | text = sample.data['text'] 248 | span1 = sample.data['span1_text'] 249 | span2 = sample.data['span2_text'] 250 | return f"{text}\nIn the previous sentence, does the pronoun \"{span2.lower()}\" refer to {span1}? Yes or No?\n" 251 | 252 | def verbalize(self, sample, candidate): 253 | text = sample.data['text'] 254 | span1 = sample.data['span1_text'] 255 | span2 = sample.data['span2_text'] 256 | return f"{text}\nIn the previous sentence, does the pronoun \"{span2.lower()}\" refer to {span1}? Yes or No?\n{self.verbalizer[candidate]}" 257 | 258 | def encode_sfc(self, sample): 259 | return f"" 260 | 261 | def verbalize_sfc(self, sample, candidate): 262 | return f"{self.verbalizer[candidate]}" 263 | 264 | 265 | class ReCoRDTemplate(Template): 266 | # From PromptSource 1 but modified 267 | 268 | def encode(self, sample): 269 | passage = sample.data['passage'] 270 | query = sample.data['query'] 271 | return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer:" 272 | 273 | def verbalize(self, sample, candidate): 274 | passage = sample.data['passage'] 275 | query = sample.data['query'] 276 | return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer: {candidate}" 277 | 278 | def encode_sfc(self, sample): 279 | return f"Answer:" 280 | 281 | def verbalize_sfc(self, sample, candidate): 282 | return f"Answer: {candidate}" 283 | 284 | 285 | class ReCoRDTemplateGPT3(Template): 286 | # From PromptSource 1 but modified 287 | 288 | def encode(self, sample): 289 | passage = sample.data['passage'].replace("@highlight\n", "- ") 290 | return f"{passage}\n-" 291 | 292 | def verbalize(self, sample, candidate): 293 | passage = sample.data['passage'].replace("@highlight\n", "- ") 294 | query = sample.data['query'].replace("@placeholder", candidate[0] if isinstance(candidate, list) else candidate) 295 | return f"{passage}\n- {query}" 296 | 297 | # passage = sample.data['passage'] 298 | # query = sample.data['query'] 299 | # return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer: {candidate}" 300 | 301 | def encode_sfc(self, sample): 302 | return f"-" 303 | 304 | def verbalize_sfc(self, sample, candidate): 305 | query = sample.data['query'].replace("@placeholder", candidate[0] if isinstance(candidate, list) else candidate) 306 | return f"- {query}" 307 | 308 | 309 | class 
RTETemplate(Template): 310 | # From PromptSource 1 311 | verbalizer={0: "Yes", 1: "No"} 312 | 313 | def encode(self, sample): 314 | premise = sample.data['premise'] 315 | hypothesis = sample.data['hypothesis'] 316 | return f"{premise}\nDoes this mean that \"{hypothesis}\" is true? Yes or No?\n" 317 | 318 | def verbalize(self, sample, candidate): 319 | premise = sample.data['premise'] 320 | hypothesis = sample.data['hypothesis'] 321 | return f"{premise}\nDoes this mean that \"{hypothesis}\" is true? Yes or No?\n{self.verbalizer[candidate]}" 322 | 323 | def encode_sfc(self, sample): 324 | return f"" 325 | 326 | def verbalize_sfc(self, sample, candidate): 327 | return f"{self.verbalizer[candidate]}" 328 | 329 | 330 | class SQuADv2Template(Template): 331 | 332 | def encode(self, sample): 333 | question = sample.data['question'].strip() 334 | title = sample.data['title'] 335 | context = sample.data['context'] 336 | answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one 337 | 338 | return f"Title: {title}\nContext: {context}\nQuestion: {question}\nAnswer:" 339 | 340 | def verbalize(self, sample, candidate): 341 | question = sample.data['question'].strip() 342 | title = sample.data['title'] 343 | context = sample.data['context'] 344 | answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one 345 | 346 | return f"Title: {title}\nContext: {context}\nQuestion: {question}\nAnswer: {answer}\n" 347 | 348 | 349 | def encode_sfc(self, sample): 350 | raise NotImplementedError 351 | 352 | def verbalize_sfc(self, sample, candidate): 353 | raise NotImplementedError 354 | 355 | 356 | class DROPTemplate(Template): 357 | 358 | def encode(self, sample): 359 | question = sample.data['question'].strip() 360 | # title = sample.data['title'] 361 | context = sample.data['context'] 362 | answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one 363 | 364 | return f"Passage: {context}\nQuestion: {question}\nAnswer:" 365 | 366 | def verbalize(self, sample, candidate): 367 | question = sample.data['question'].strip() 368 | # title = sample.data['title'] 369 | context = sample.data['context'] 370 | answer = sample.data['answers'][0] # there are multiple answers. 
for the prompt we only take the first one 371 | 372 | return f"Passage: {context}\nQuestion: {question}\nAnswer: {answer}\n" 373 | 374 | 375 | def encode_sfc(self, sample): 376 | raise NotImplementedError 377 | 378 | def verbalize_sfc(self, sample, candidate): 379 | raise NotImplementedError 380 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import contextlib 4 | from typing import Optional, Union 5 | import numpy as np 6 | from dataclasses import dataclass, is_dataclass, asdict 7 | import logging 8 | import time 9 | from torch.nn import CrossEntropyLoss 10 | import torch.nn.functional as F 11 | from transformers.modeling_outputs import CausalLMOutputWithPast 12 | import torch 13 | from transformers.utils import PaddingStrategy 14 | from transformers import PreTrainedTokenizerBase 15 | from transformers.data.data_collator import DataCollatorMixin 16 | import transformers 17 | from typing import Optional, Union, List, Dict, Any 18 | import signal 19 | from subprocess import call 20 | from collections.abc import Mapping 21 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 22 | InputDataClass = NewType("InputDataClass", Any) 23 | from dataclasses import dataclass 24 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 25 | 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def forward_wrap_with_option_len(self, input_ids=None, labels=None, option_len=None, num_options=None, return_dict=None, **kwargs): 32 | """ 33 | This is to replace the original forward function of Transformer models to enable: 34 | (1) Partial target sequence: loss will only be calculated on part of the sequence 35 | (2) Classification-style training: a classification loss (CE) will be calculated over several options 36 | Input: 37 | - input_ids, labels: same as the original forward function 38 | - option_len: a list of int indicating the option lengths, and loss will be calculated only on the 39 | last option_len tokens 40 | - num_options: a list of int indicating the number of options for each example (this will be #label 41 | words for classification tasks and #choices for multiple choice tasks), and a classification loss 42 | will be calculated. 
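        Illustrative example (hypothetical values, not from the repo): with a batch of two verbalized SST-2 prompts ending in " terrible" / " great" and option_len=[1, 1], every label except the final option token of each sequence is set to -100, so the loss covers the option tokens only; if num_options=[2, 2] is also given, the per-option average log-probabilities are then compared with a classification (cross-entropy) loss.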
43 | """ 44 | outputs = self.original_forward(input_ids=input_ids, **kwargs) 45 | if labels is None: 46 | return outputs 47 | logits = outputs.logits 48 | 49 | loss = None 50 | # Shift so that tokens < n predict n 51 | shift_logits = logits[..., :-1, :].contiguous() 52 | # Here we use input_ids (which should always = labels) bc sometimes labels are correct candidate IDs 53 | shift_labels = torch.clone(input_ids)[..., 1:].contiguous() 54 | shift_labels[shift_labels == self.config.pad_token_id] = -100 55 | 56 | # Apply option len (do not calculate loss on the non-option part) 57 | for _i, _len in enumerate(option_len): 58 | shift_labels[_i, :-_len] = -100 59 | 60 | # Calculate the loss 61 | loss_fct = CrossEntropyLoss(ignore_index=-100) 62 | if num_options is not None: 63 | # Train as a classification tasks 64 | log_probs = F.log_softmax(shift_logits, dim=-1) 65 | mask = shift_labels != -100 # Option part 66 | shift_labels[~mask] = 0 # So that it doesn't mess up with indexing 67 | 68 | selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1) # (bsz x num_options, len) 69 | selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1) # (bsz x num_options) 70 | 71 | if any([x != num_options[0] for x in num_options]): 72 | # Multi choice tasks with different number of options 73 | loss = 0 74 | start_id = 0 75 | count = 0 76 | while start_id < len(num_options): 77 | end_id = start_id + num_options[start_id] 78 | _logits = selected_log_probs[start_id:end_id].unsqueeze(0) # (1, num_options) 79 | _labels = labels[start_id:end_id][0].unsqueeze(0) # (1) 80 | loss = loss_fct(_logits, _labels) + loss 81 | count += 1 82 | start_id = end_id 83 | loss = loss / count 84 | else: 85 | num_options = num_options[0] 86 | selected_log_probs = selected_log_probs.view(-1, num_options) # (bsz, num_options) 87 | labels = labels.view(-1, num_options)[:, 0] # Labels repeat so we only take the first one 88 | loss = loss_fct(selected_log_probs, labels) 89 | else: 90 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) 91 | 92 | if not return_dict: 93 | output = (logits,) + outputs[1:] 94 | return (loss,) + output if loss is not None else output 95 | 96 | return CausalLMOutputWithPast( 97 | loss=loss, 98 | logits=logits, 99 | past_key_values=outputs.past_key_values, 100 | hidden_states=outputs.hidden_states, 101 | attentions=outputs.attentions, 102 | ) 103 | 104 | 105 | def encode_prompt(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False, generation=False, generation_with_gold=False, max_new_tokens=None): 106 | """ 107 | Encode prompts for eval_sample 108 | Input: 109 | - task, template: task and template class 110 | - train_samples, eval_sample: demonstrations and the actual sample 111 | - tokenizer, max_length: tokenizer and max length 112 | - sfc: generate prompts for calibration (surface form competition; https://arxiv.org/abs/2104.08315) 113 | - icl_sfc: generate prompts for ICL version calibration 114 | - generation: whether it is an generation task 115 | - generation_with_gold: whether to include the generation-task gold answers (for training) 116 | - max_new_tokens: max number of new tokens to generate so that we can save enough space 117 | (only for generation tasks) 118 | Output: 119 | - encodings: a list of N lists of tokens. N is the number of options for classification/multiple-choice. 120 | - option_lens: a list of N integers indicating the number of option tokens. 
121 | """ 122 | 123 | # Demonstrations for ICL 124 | train_prompts = [template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples] 125 | train_prompts = task.train_sep.join(train_prompts).strip() 126 | 127 | # sfc or icl_sfc indicates that this example is used for calibration 128 | if sfc or icl_sfc: 129 | encode_fn = template.encode_sfc; verbalize_fn = template.verbalize_sfc 130 | else: 131 | encode_fn = template.encode; verbalize_fn = template.verbalize 132 | 133 | unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ') 134 | if not generation: 135 | # We generate one prompt for each candidate (different classes in classification) 136 | # or different choices in multiple-choice tasks 137 | verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates] 138 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt)) 139 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts] 140 | 141 | if sfc: 142 | # Without demonstrations 143 | final_prompts = verbalized_eval_prompts 144 | else: 145 | # With demonstrations 146 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts] 147 | else: 148 | assert not sfc and not icl_sfc, "Generation tasks do not support SFC" 149 | if generation_with_gold: 150 | verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)] 151 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt)) 152 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts] 153 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts] 154 | else: 155 | option_lens = [0] 156 | final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')] 157 | 158 | # Tokenize 159 | encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts] 160 | 161 | # Truncate (left truncate as demonstrations are less important) 162 | if generation and max_new_tokens is not None: 163 | max_length = max_length - max_new_tokens 164 | 165 | if any([len(encoding) > max_length for encoding in encodings]): 166 | logger.warn("Exceed max length") 167 | encodings = [encoding[0:1] + encoding[1:][-(max_length-1):] for encoding in encodings] 168 | return encodings, option_lens 169 | 170 | 171 | 172 | @dataclass 173 | class ICLCollator: 174 | """ 175 | Collator for ICL 176 | """ 177 | tokenizer: PreTrainedTokenizerBase 178 | 179 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 180 | if not isinstance(features[0], Mapping): 181 | features = [vars(f) for f in features] 182 | first = features[0] 183 | batch = {} 184 | 185 | pad_id = self.tokenizer.pad_token_id 186 | 187 | pad_ids = {"input_ids": pad_id, "attention_mask": 0, "sfc_input_ids": pad_id, "sfc_attention_mask": 0, "labels": pad_id} 188 | for key in first: 189 | pp = pad_ids[key] 190 | lens = [len(f[key]) for f in features] 191 | max_len = max(lens) 192 | feature = np.stack([np.pad(f[key], (0, max_len - lens[i]), "constant", constant_values=(0, pp)) for i, f in enumerate(features)]) 193 | padded_feature = torch.from_numpy(feature).long() 194 | batch[key] = padded_feature 195 | 196 | return batch 197 | 198 | 199 | 
@dataclass 200 | class DataCollatorWithPaddingAndNesting: 201 | """ 202 | Collator for training 203 | """ 204 | 205 | tokenizer: PreTrainedTokenizerBase 206 | padding: Union[bool, str, PaddingStrategy] = True 207 | max_length: Optional[int] = None 208 | pad_to_multiple_of: Optional[int] = None 209 | return_tensors: str = "pt" 210 | 211 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 212 | features = [ff for f in features for ff in f] 213 | batch = self.tokenizer.pad( 214 | features, 215 | padding=self.padding, 216 | max_length=self.max_length, 217 | pad_to_multiple_of=self.pad_to_multiple_of, 218 | return_tensors=self.return_tensors, 219 | ) 220 | if "label" in batch: 221 | batch["labels"] = batch["label"] 222 | del batch["label"] 223 | if "label_ids" in batch: 224 | batch["labels"] = batch["label_ids"] 225 | del batch["label_ids"] 226 | return batch 227 | 228 | 229 | @dataclass 230 | class NondiffCollator(DataCollatorMixin): 231 | """ 232 | Collator for non-differentiable objectives 233 | """ 234 | tokenizer: PreTrainedTokenizerBase 235 | padding: Union[bool, str, PaddingStrategy] = True 236 | max_length: Optional[int] = None 237 | pad_to_multiple_of: Optional[int] = None 238 | label_pad_token_id: int = -100 239 | return_tensors: str = "pt" 240 | 241 | def torch_call(self, features): 242 | import torch 243 | 244 | label_name = "label" if "label" in features[0].keys() else "labels" 245 | labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None 246 | 247 | no_labels_features = [{k: v for k, v in feature.items() if k != label_name and k != "gold"} for feature in features] 248 | 249 | batch = self.tokenizer.pad( 250 | no_labels_features, 251 | padding=self.padding, 252 | max_length=self.max_length, 253 | pad_to_multiple_of=self.pad_to_multiple_of, 254 | return_tensors="pt", 255 | ) 256 | 257 | if labels is None: 258 | return batch 259 | 260 | sequence_length = batch["input_ids"].shape[1] 261 | padding_side = self.tokenizer.padding_side 262 | 263 | def to_list(tensor_or_iterable): 264 | if isinstance(tensor_or_iterable, torch.Tensor): 265 | return tensor_or_iterable.tolist() 266 | return list(tensor_or_iterable) 267 | 268 | if padding_side == "right": 269 | batch[label_name] = [ 270 | to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels 271 | ] 272 | else: 273 | batch[label_name] = [ 274 | [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels 275 | ] 276 | 277 | batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64) 278 | if "gold" in features[0]: 279 | batch["gold"] = [feature["gold"] for feature in features] 280 | 281 | return batch 282 | 283 | 284 | class SIGUSR1Callback(transformers.TrainerCallback): 285 | """ 286 | This callback is used to save the model when a SIGUSR1 signal is received 287 | (SLURM stop signal or a keyboard interruption signal). 
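    Typical usage (as in run.py): add this callback via trainer.add_callback(SIGUSR1Callback()) when --save_on_interrupt is enabled, so the current step is checkpointed and training stops cleanly before the job is killed.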
288 | """ 289 | 290 | def __init__(self) -> None: 291 | super().__init__() 292 | self.signal_received = False 293 | signal.signal(signal.SIGUSR1, self.handle_signal) 294 | signal.signal(signal.SIGINT, self.handle_signal) 295 | logger.warn("Handler registered") 296 | 297 | def handle_signal(self, signum, frame): 298 | self.signal_received = True 299 | logger.warn("Signal received") 300 | 301 | def on_step_end(self, args, state, control, **kwargs): 302 | if self.signal_received: 303 | control.should_save = True 304 | control.should_training_stop = True 305 | 306 | def on_train_end(self, args, state, control, **kwargs): 307 | if self.signal_received: 308 | exit(0) 309 | 310 | 311 | @dataclass 312 | class Prediction: 313 | correct_candidate: Union[int, str] 314 | predicted_candidate: Union[int, str] 315 | 316 | 317 | @contextlib.contextmanager 318 | def count_time(name): 319 | logger.info("%s..." % name) 320 | start_time = time.time() 321 | try: 322 | yield 323 | finally: 324 | logger.info("Done with %.2fs" % (time.time() - start_time)) 325 | 326 | 327 | @contextlib.contextmanager 328 | def temp_seed(seed): 329 | state = np.random.get_state() 330 | np.random.seed(seed) 331 | try: 332 | yield 333 | finally: 334 | np.random.set_state(state) 335 | 336 | 337 | class EnhancedJSONEncoder(json.JSONEncoder): 338 | def default(self, o): 339 | if is_dataclass(o): 340 | return asdict(o) 341 | return super().default(o) 342 | 343 | 344 | def write_predictions_to_file(final_preds, output): 345 | with open(output, "w") as f: 346 | for pred in final_preds: 347 | f.write(json.dumps(pred, cls=EnhancedJSONEncoder) + "\n") 348 | 349 | 350 | def write_metrics_to_file(metrics, output): 351 | json.dump(metrics, open(output, "w"), cls=EnhancedJSONEncoder, indent=4) -------------------------------------------------------------------------------- /tasks.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from utils import temp_seed 3 | import json 4 | import os 5 | from datasets import load_dataset 6 | from dataclasses import dataclass 7 | from typing import List, Union 8 | import string 9 | import random 10 | import datasets 11 | import sys 12 | import numpy as np 13 | import logging 14 | 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | 18 | 19 | def get_task(task_name): 20 | aa = task_name.split("__") 21 | if len(aa) == 2: 22 | task_group, subtask = aa 23 | else: 24 | task_group = aa[0] 25 | subtask = None 26 | class_ = getattr(sys.modules[__name__], f"{task_group}Dataset") 27 | instance = class_(subtask) 28 | return instance 29 | 30 | 31 | @dataclass 32 | class Sample: 33 | id: int = None 34 | data: dict = None 35 | correct_candidate: Union[str, List[str]] = None 36 | candidates: List[str] = None 37 | 38 | 39 | class Dataset: 40 | mixed_set = False 41 | train_sep = "\n\n" 42 | generation = False # whether this is a generation task 43 | 44 | def __init__(self, subtask=None, **kwargs) -> None: 45 | self.subtask = subtask 46 | 47 | def get_task_name(self): 48 | return self.subtask 49 | 50 | def load_dataset(): 51 | raise NotImplementedError 52 | 53 | def get_template(self, template_version=0): 54 | templates = {0: Template} 55 | return templates[template_version] 56 | 57 | def build_sample(self, example): 58 | return 59 | 60 | def sample_train_sets(self, num_train=32, num_dev=None, num_eval=None, num_train_sets=None, seed=None): 61 | if seed is not None: 62 | # one train/demo set using the designated seed 63 | seeds = 
[seed] 64 | elif num_train_sets is not None: 65 | # num_train_sets train/demo sets 66 | seeds = list(range(num_train_sets)) 67 | else: 68 | # one train/demo set per evaluation sample 69 | assert num_dev is None # not supported 70 | len_valid_samples = len(self.samples["valid"]) if num_eval is None else num_eval 71 | with temp_seed(0): 72 | seeds = np.random.randint(0, 10000, len_valid_samples) 73 | 74 | train_samples = [] 75 | for i, set_seed in enumerate(seeds): 76 | if self.mixed_set: 77 | raise NotImplementedError 78 | train_samples.append(self.sample_subset(data_split="valid", seed=set_seed, num=num_train, exclude=i)) 79 | else: 80 | if num_dev is not None: 81 | train_samples.append(self.sample_subset(data_split="train", seed=set_seed, num=num_train+num_dev)) # dev set is included at the end of train set 82 | if num_train + num_dev > len(self.samples["train"]): 83 | logger.warn("num_train + num_dev > available training examples") 84 | else: 85 | train_samples.append(self.sample_subset(data_split="train", seed=set_seed, num=num_train)) 86 | if num_dev is not None: 87 | logger.info(f"Sample train set {len(train_samples[-1])}/{len(self.samples['train'])}") 88 | logger.info(f"... including dev set {num_dev} samples") 89 | return train_samples 90 | 91 | def sample_subset(self, data_split="train", seed=0, num=100, exclude=None): 92 | with temp_seed(seed): 93 | samples = self.samples[data_split] 94 | lens = len(samples) 95 | index = np.random.permutation(lens).tolist()[:num if exclude is None else num+1] 96 | if exclude is not None and exclude in index: 97 | index.remove(exclude) 98 | else: 99 | index = index[:num] 100 | return [samples[i] for i in index] 101 | 102 | @property 103 | def valid_samples(self): 104 | return self.samples["valid"] 105 | 106 | 107 | class SST2Dataset(Dataset): 108 | train_sep = "\n\n" 109 | def __init__(self, subtask=None, **kwargs) -> None: 110 | self.load_dataset(subtask, **kwargs) 111 | 112 | def load_dataset(self, path, **kwargs): 113 | d = load_dataset('glue', 'sst2') 114 | train_d = d["train"] 115 | validation_d = d["validation"] 116 | 117 | train_samples = [self.build_sample(example) for example in train_d] 118 | valid_samples = [self.build_sample(example) for example in validation_d] 119 | 120 | self.samples = {"train": train_samples, "valid": valid_samples} 121 | 122 | # for generative tasks, candidates are [] 123 | def build_sample(self, example): 124 | label = int(example["label"]) 125 | return Sample(id=example["idx"], data=example, correct_candidate=label, candidates=[0, 1]) 126 | 127 | def get_template(self, template_version=0): 128 | return {0: SST2Template}[template_version]() 129 | 130 | 131 | class CopaDataset(Dataset): 132 | train_sep = "\n\n" 133 | mixed_set = False 134 | 135 | def __init__(self, subtask=None, **kwargs) -> None: 136 | self.load_dataset(subtask, **kwargs) 137 | 138 | def load_dataset(self, path, **kwargs): 139 | train_examples = load_dataset('super_glue', "copa")["train"] 140 | valid_examples = load_dataset('super_glue', "copa")["validation"] 141 | 142 | train_samples = [self.build_sample(example) for example in train_examples] 143 | valid_samples = [self.build_sample(example) for example in valid_examples] 144 | self.samples = {"train": train_samples, "valid": valid_samples} 145 | 146 | # for generative tasks, candidates are [] 147 | def build_sample(self, example): 148 | sample = \ 149 | Sample( 150 | id=example["idx"], 151 | data=example, 152 | candidates=[example["choice1"], example["choice2"]], 153 | 
correct_candidate=example[f"choice{example['label'] + 1}"], 154 | ) 155 | 156 | return sample 157 | 158 | def get_template(self, template_version=0): 159 | return {0: CopaTemplate}[template_version]() 160 | 161 | 162 | class BoolQDataset(Dataset): 163 | def __init__(self, subtask=None, **kwargs) -> None: 164 | self.load_dataset(subtask, **kwargs) 165 | 166 | def load_dataset(self, path, **kwargs): 167 | d = load_dataset("boolq") 168 | train_set = d["train"] 169 | valid_set = d["validation"] 170 | 171 | train_samples = [self.build_sample(example) for example in train_set] 172 | valid_samples = [self.build_sample(example) for example in valid_set] 173 | self.samples = {"train": train_samples, "valid": valid_samples} 174 | 175 | def build_sample(self, example): 176 | sample = \ 177 | Sample( 178 | data=example, 179 | candidates=["Yes", "No"], 180 | correct_candidate="Yes" if example["answer"] else "No", 181 | ) 182 | 183 | return sample 184 | 185 | def get_template(self, template_version=2): 186 | return {0: BoolQTemplate, 1: BoolQTemplateV2, 2: BoolQTemplateV3}[template_version]() 187 | 188 | 189 | class MultiRCDataset(Dataset): 190 | 191 | def __init__(self, subtask=None, **kwargs) -> None: 192 | self.load_dataset(subtask, **kwargs) 193 | 194 | def load_dataset(self, path, **kwargs): 195 | d = load_dataset("super_glue", "multirc") 196 | train_set = d["train"] 197 | valid_set = d["validation"] 198 | 199 | train_samples = [self.build_sample(example) for example in train_set] 200 | valid_samples = [self.build_sample(example) for example in valid_set] 201 | self.samples = {"train": train_samples, "valid": valid_samples} 202 | 203 | def build_sample(self, example): 204 | sample = \ 205 | Sample( 206 | data=example, 207 | candidates=[0, 1], 208 | correct_candidate=example['label'] 209 | ) 210 | 211 | return sample 212 | 213 | def get_template(self, template_version=0): 214 | return {0: MultiRCTemplate}[template_version]() 215 | 216 | 217 | class CBDataset(Dataset): 218 | 219 | def __init__(self, subtask=None, **kwargs) -> None: 220 | self.load_dataset(subtask, **kwargs) 221 | 222 | def load_dataset(self, path, **kwargs): 223 | d = load_dataset("super_glue", "cb") 224 | train_set = d["train"] 225 | valid_set = d["validation"] 226 | 227 | train_samples = [self.build_sample(example) for example in train_set] 228 | valid_samples = [self.build_sample(example) for example in valid_set] 229 | self.samples = {"train": train_samples, "valid": valid_samples} 230 | 231 | def build_sample(self, example): 232 | sample = \ 233 | Sample( 234 | data=example, 235 | candidates=[0, 1, 2], 236 | correct_candidate=example['label'] 237 | ) 238 | 239 | return sample 240 | 241 | def get_template(self, template_version=0): 242 | return {0: CBTemplate}[template_version]() 243 | 244 | 245 | class WICDataset(Dataset): 246 | 247 | def __init__(self, subtask=None, **kwargs) -> None: 248 | self.load_dataset(subtask, **kwargs) 249 | 250 | def load_dataset(self, path, **kwargs): 251 | d = load_dataset("super_glue", "wic") 252 | train_set = d["train"] 253 | valid_set = d["validation"] 254 | 255 | train_samples = [self.build_sample(example) for example in train_set] 256 | valid_samples = [self.build_sample(example) for example in valid_set] 257 | self.samples = {"train": train_samples, "valid": valid_samples} 258 | 259 | def build_sample(self, example): 260 | sample = \ 261 | Sample( 262 | data=example, 263 | candidates=[0, 1], 264 | correct_candidate=example['label'] 265 | ) 266 | 267 | return sample 268 | 269 | def 
get_template(self, template_version=0): 270 | return {0: WICTemplate}[template_version]() 271 | 272 | 273 | class WSCDataset(Dataset): 274 | 275 | def __init__(self, subtask=None, **kwargs) -> None: 276 | self.load_dataset(subtask, **kwargs) 277 | 278 | def load_dataset(self, path, **kwargs): 279 | d = load_dataset("super_glue", "wsc.fixed") 280 | train_set = d["train"] 281 | valid_set = d["validation"] 282 | 283 | train_samples = [self.build_sample(example) for example in train_set] 284 | valid_samples = [self.build_sample(example) for example in valid_set] 285 | self.samples = {"train": train_samples, "valid": valid_samples} 286 | 287 | def build_sample(self, example): 288 | sample = \ 289 | Sample( 290 | data=example, 291 | candidates=[0, 1], 292 | correct_candidate=example['label'] 293 | ) 294 | 295 | return sample 296 | 297 | def get_template(self, template_version=0): 298 | return {0: WSCTemplate}[template_version]() 299 | 300 | 301 | class ReCoRDDataset(Dataset): 302 | 303 | def __init__(self, subtask=None, **kwargs) -> None: 304 | self.load_dataset(subtask, **kwargs) 305 | 306 | def load_dataset(self, path, **kwargs): 307 | d = load_dataset("super_glue", "record") 308 | train_set = d["train"] 309 | valid_set = d["validation"] 310 | 311 | train_samples = [self.build_sample(example) for example in train_set] 312 | valid_samples = [self.build_sample(example) for example in valid_set] 313 | self.samples = {"train": train_samples, "valid": valid_samples} 314 | 315 | def build_sample(self, example): 316 | sample = \ 317 | Sample( 318 | data=example, 319 | candidates=example['entities'], 320 | correct_candidate=example['answers'] 321 | ) 322 | 323 | return sample 324 | 325 | def get_template(self, template_version=0): 326 | return {0: ReCoRDTemplateGPT3}[template_version]() 327 | 328 | 329 | class RTEDataset(Dataset): 330 | 331 | def __init__(self, subtask=None, **kwargs) -> None: 332 | self.load_dataset(subtask, **kwargs) 333 | 334 | def load_dataset(self, path, **kwargs): 335 | d = load_dataset("super_glue", "rte") 336 | train_set = d["train"] 337 | valid_set = d["validation"] 338 | 339 | train_samples = [self.build_sample(example) for example in train_set] 340 | valid_samples = [self.build_sample(example) for example in valid_set] 341 | self.samples = {"train": train_samples, "valid": valid_samples} 342 | 343 | def build_sample(self, example): 344 | sample = \ 345 | Sample( 346 | data=example, 347 | candidates=[0, 1], 348 | correct_candidate=example['label'] 349 | ) 350 | 351 | return sample 352 | 353 | def get_template(self, template_version=0): 354 | return {0: RTETemplate}[template_version]() 355 | 356 | 357 | class SQuADDataset(Dataset): 358 | metric_name = "f1" 359 | generation = True 360 | 361 | def __init__(self, subtask=None, **kwargs) -> None: 362 | self.load_dataset() 363 | 364 | def load_dataset(self): 365 | dataset = load_dataset("squad") 366 | train_examples = dataset["train"] 367 | valid_examples = dataset["validation"] 368 | 369 | train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)] 370 | valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)] 371 | self.samples = {"train": train_samples, "valid": valid_samples} 372 | 373 | # for generative tasks, candidates are [] 374 | def build_sample(self, example, idx): 375 | answers = example['answers']['text'] 376 | assert len(answers) > 0 377 | return Sample( 378 | id=idx, 379 | data={ 380 | "title": example['title'], 381 | "context": 
example['context'], 382 | "question": example['question'], 383 | "answers": answers 384 | }, 385 | candidates=None, 386 | correct_candidate=answers 387 | ) 388 | 389 | def get_template(self, template_version=0): 390 | return {0: SQuADv2Template}[template_version]() 391 | 392 | 393 | class DROPDataset(Dataset): 394 | metric_name = "f1" 395 | generation = True 396 | 397 | def __init__(self, subtask=None, **kwargs) -> None: 398 | self.load_dataset() 399 | 400 | def load_dataset(self): 401 | dataset = load_dataset("drop") 402 | train_examples = dataset["train"] 403 | valid_examples = dataset["validation"] 404 | 405 | train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)] 406 | valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)] 407 | self.samples = {"train": train_samples, "valid": valid_samples} 408 | 409 | # for generative tasks, candidates are [] 410 | def build_sample(self, example, idx): 411 | answers = example['answers_spans']['spans'] 412 | assert len(answers) > 0 413 | return Sample( 414 | id=idx, 415 | data={ 416 | "context": example['passage'], 417 | "question": example['question'], 418 | "answers": answers 419 | }, 420 | candidates=None, 421 | correct_candidate=answers 422 | ) 423 | 424 | def get_template(self, template_version=0): 425 | return {0: DROPTemplate}[template_version]() 426 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 4 | logger = logging.getLogger(__name__) 5 | logger.setLevel(logging.INFO) 6 | 7 | import argparse 8 | import time 9 | import tasks 10 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, Trainer, HfArgumentParser, Trainer, TrainingArguments, DataCollatorWithPadding, DataCollatorForTokenClassification 11 | from typing import Union, Optional 12 | import torch 13 | from torch.nn.parameter import Parameter 14 | import numpy as np 15 | from dataclasses import dataclass, is_dataclass, asdict 16 | from tqdm import tqdm 17 | from tasks import get_task 18 | import json 19 | import torch.nn.functional as F 20 | from torch.utils.data import Dataset 21 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP 22 | from metrics import calculate_metric 23 | from utils import * 24 | from trainer import OurTrainer 25 | import random 26 | @dataclass 27 | class OurArguments(TrainingArguments): 28 | # dataset and sampling strategy 29 | task_name: str = "SST2" # task name should match the string before Dataset in the Dataset class name. 
We support the following task_name: SST2, RTE, CB, BoolQ, WSC, WIC, MultiRC, Copa, ReCoRD, SQuAD, DROP 30 | 31 | # Number of examples 32 | num_train: int = 0 # ICL mode: number of demonstrations; training mode: number of training samples 33 | num_dev: int = None # (only enabled with training) number of development samples 34 | num_eval: int = None # number of evaluation samples 35 | num_train_sets: int = None # how many sets of training samples/demos to sample; if None and train_set_seed is None, then we will sample one set for each evaluation sample 36 | train_set_seed: int = None # designated seed to sample training samples/demos 37 | result_file: str = None # file name for saving performance; if None, then use the task name, model name, and config 38 | do_grad_scaling: bool = False 39 | 40 | # Model loading 41 | model_name: str = "facebook/opt-125m" # HuggingFace model name 42 | load_float16: bool = False # load model parameters as float16 43 | load_bfloat16: bool = False # load model parameters as bfloat16 44 | load_int8: bool = False # load model parameters as int8 45 | max_length: int = 2048 # max length the model can take 46 | no_auto_device: bool = False # do not load model by auto device; should turn this on when using FSDP 47 | 48 | # Calibration 49 | sfc: bool = False # whether to use SFC calibration 50 | icl_sfc: bool = False # whether to use SFC calibration for ICL samples 51 | 52 | # Training 53 | trainer: str = "none" 54 | ## options 55 | ## - none: no training -- for zero-shot or in-context learning (ICL) 56 | ## - regular: regular huggingface trainer -- for fine-tuning 57 | ## - zo: zeroth-order (MeZO) training 58 | only_train_option: bool = True # whether to only train the option part of the input 59 | train_as_classification: bool = False # take the log likelihood of all options and train as classification 60 | 61 | # MeZO 62 | zo_eps: float = 1e-3 # eps in MeZO 63 | 64 | ###############diag 65 | warmup_step: int = 0 66 | decay_step: int = 0 67 | zo_lr_scheduler_type: str = 'constant' 68 | weight_decay: float = 0 69 | hessian_smooth_type: str = 'constant0' 70 | 71 | # Prefix tuning 72 | prefix_tuning: bool = False # whether to use prefix tuning 73 | num_prefix: int = 5 # number of prefixes to use 74 | no_reparam: bool = True # do not use reparameterization trick 75 | prefix_init_by_real_act: bool = True # initialize prefix by real activations of random words 76 | 77 | # LoRA 78 | lora: bool = False # whether to use LoRA 79 | lora_alpha: int = 16 # alpha in LoRA 80 | lora_r: int = 8 # r in LoRA 81 | 82 | # Generation 83 | sampling: bool = False # whether to use sampling 84 | temperature: float = 1.0 # temperature for generation 85 | num_beams: int = 1 # number of beams for generation 86 | top_k: int = None # top-k for generation 87 | top_p: float = 0.95 # top-p for generation 88 | max_new_tokens: int = 50 # max number of new tokens to generate 89 | eos_token: str = "\n" # end of sentence token 90 | 91 | # Saving 92 | save_model: bool = False # whether to save the model 93 | no_eval: bool = False # whether to skip evaluation 94 | tag: str = "" # saving tag 95 | 96 | # Linear probing 97 | linear_probing: bool = False # whether to do linear probing 98 | lp_early_stopping: bool = False # whether to do early stopping in linear probing 99 | head_tuning: bool = False # head tuning: only tune the LM head 100 | 101 | # Untie emb/lm_head weights 102 | untie_emb: bool = False # untie the embeddings and LM head 103 | 104 | # Display 105 | verbose: bool = False # verbose output 106 | 
107 | # Non-diff objective 108 | non_diff: bool = False # use non-differentiable objective (only support F1 for SQuAD for now) 109 | 110 | # Auto saving when interrupted 111 | save_on_interrupt: bool = False # save model when interrupted (useful for long training) 112 | 113 | 114 | def parse_args(): 115 | parser = argparse.ArgumentParser() 116 | parser = HfArgumentParser(OurArguments) 117 | args = parser.parse_args_into_dataclasses()[0] 118 | print(args) 119 | return args 120 | 121 | 122 | def set_seed(seed: int): 123 | random.seed(seed) 124 | np.random.seed(seed) 125 | torch.manual_seed(seed) 126 | torch.cuda.manual_seed_all(seed) 127 | 128 | 129 | class Framework: 130 | 131 | def __init__(self, args, task): 132 | self.args = args 133 | self.task = task 134 | self.model, self.tokenizer = self.load_model() 135 | 136 | 137 | def load_model(self): 138 | """ 139 | Load HuggingFace models 140 | """ 141 | with count_time("Loading model with FP%d" % (16 if self.args.load_float16 else 32)): 142 | free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3) 143 | config = AutoConfig.from_pretrained(self.args.model_name, trust_remote_code=True) 144 | if self.args.untie_emb: 145 | # Untie embeddings/LM head 146 | logger.warn("Untie embeddings and LM head") 147 | config.tie_word_embeddings = False 148 | if self.args.head_tuning: 149 | # Head tuning 150 | from ht_opt import OPTForCausalLM 151 | model = OPTForCausalLM.from_pretrained( 152 | self.args.model_name, 153 | config=config, 154 | ) 155 | elif self.args.no_auto_device: 156 | # No auto device (use for FSDP) 157 | model = AutoModelForCausalLM.from_pretrained( 158 | self.args.model_name, 159 | config=config, 160 | ) 161 | else: 162 | # Auto device loading 163 | torch_dtype = torch.float32 164 | if self.args.load_float16: 165 | torch_dtype = torch.float16 166 | elif self.args.load_bfloat16: 167 | torch_dtype = torch.bfloat16 168 | model = AutoModelForCausalLM.from_pretrained( 169 | self.args.model_name, 170 | config=config, 171 | device_map='auto', 172 | torch_dtype=torch_dtype, 173 | max_memory={i: f'{free_in_GB-5}GB' for i in range(torch.cuda.device_count())}, 174 | load_in_8bit=self.args.load_int8, trust_remote_code=True 175 | ) 176 | 177 | 178 | model.eval() 179 | 180 | # Load tokenizer 181 | tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, use_fast=False) 182 | 183 | tokenizer.pad_token_id = 0 184 | 185 | # HF tokenizer bug fix 186 | if "opt" in self.args.model_name: 187 | tokenizer.bos_token_id = 0 188 | 189 | # Prefix tuning/LoRA 190 | if self.args.prefix_tuning: 191 | from prefix import PrefixTuning 192 | PrefixTuning(model, num_prefix=self.args.num_prefix, reparam=not self.args.no_reparam, float16=self.args.load_float16, init_by_real_act=self.args.prefix_init_by_real_act) 193 | if self.args.lora: 194 | from lora import LoRA 195 | LoRA(model, r=self.args.lora_r, alpha=self.args.lora_alpha, float16=self.args.load_float16) 196 | 197 | if self.args.head_tuning: 198 | if model.config.model_type == "opt": 199 | head_name = "lm_head" if self.args.untie_emb else "embed_tokens" 200 | else: 201 | raise NotImplementedError 202 | for n, p in model.named_parameters(): 203 | if head_name not in n: 204 | p.requires_grad = False 205 | else: 206 | logger.info(f"Only tuning {n}") 207 | 208 | return model, tokenizer 209 | 210 | 211 | def forward(self, input_ids, option_len=None, generation=False): 212 | """ 213 | Given input_ids and the length of the option, return the log-likelihood of each token in the option. 
214 | For generation tasks, return the generated text. 215 | This function is only for inference 216 | """ 217 | input_ids = torch.tensor([input_ids]).to(self.model.device) 218 | 219 | if generation: 220 | args = self.args 221 | # Autoregressive generation 222 | outputs = self.model.generate( 223 | input_ids, do_sample=args.sampling, temperature=args.temperature, 224 | num_beams=args.num_beams, top_p=args.top_p, top_k=args.top_k, max_new_tokens=min(args.max_new_tokens, args.max_length - input_ids.size(1)), 225 | num_return_sequences=1, eos_token_id=[self.tokenizer.encode(args.eos_token, add_special_tokens=False)[0], self.tokenizer.eos_token_id], 226 | ) 227 | # For generation, directly return the text output 228 | output_text = self.tokenizer.decode(outputs[0][input_ids.size(1):], skip_special_tokens=True).strip() 229 | return output_text 230 | else: 231 | with torch.inference_mode(): 232 | self.model.eval() 233 | logits = self.model(input_ids=input_ids).logits 234 | labels = input_ids[0, 1:] 235 | logits = logits[0, :-1] 236 | log_probs = F.log_softmax(logits, dim=-1) 237 | 238 | selected_log_probs = log_probs[torch.arange(len(labels)).to(labels.device), labels] 239 | selected_log_probs = selected_log_probs.cpu().detach() 240 | # Only return the option (candidate) part 241 | return selected_log_probs[-option_len:] 242 | 243 | 244 | def one_step_pred(self, train_samples, eval_sample, verbose=False): 245 | """ 246 | Return the prediction on the eval sample. In ICL, use train_samples as demonstrations 247 | """ 248 | verbose = verbose or self.args.verbose 249 | if verbose: 250 | logger.info("========= Example =========") 251 | logger.info(f"Candidate: {eval_sample.candidates}") 252 | logger.info(f"Correct candidate: {eval_sample.correct_candidate}") 253 | 254 | 255 | # Encode (add prompt and tokenize) the sample; if multiple-choice/classification, encode all candidates (options) 256 | encoded_candidates, option_lens = encode_prompt( 257 | self.task, self.task.get_template(), train_samples, eval_sample, self.tokenizer, max_length=self.args.max_length, 258 | generation=self.task.generation, max_new_tokens=self.args.max_new_tokens 259 | ) 260 | 261 | # Calibration 262 | if self.args.sfc or self.args.icl_sfc: 263 | sfc_encoded_candidates, sfc_option_lens = encode_prompt(self.task, self.task.get_template(), 264 | train_samples, eval_sample, self.tokenizer, max_length=self.args.max_length, 265 | sfc=self.args.sfc, icl_sfc=self.args.icl_sfc, generation=self.task.generation, 266 | max_new_tokens=self.args.max_new_tokens 267 | ) 268 | 269 | outputs = [] 270 | if self.task.generation: 271 | # For generation tasks, return the autoregressively-generated text 272 | output_text = self.forward(encoded_candidates[0], generation=True) 273 | if verbose: 274 | logger.info("=== Prompt ===") 275 | logger.info(self.tokenizer.decode(encoded_candidates[0])) 276 | logger.info(f"Output: {output_text}") 277 | return Prediction(correct_candidate=eval_sample.correct_candidate, predicted_candidate=output_text) 278 | else: 279 | # For classification/multiple-choice, calculate the probabilities of all candidates 280 | for candidate_id, encoded_candidate in enumerate(encoded_candidates): 281 | selected_log_probs = self.forward(encoded_candidate, option_len=option_lens[candidate_id]) 282 | if verbose: 283 | if candidate_id == 0: 284 | logger.info("=== Candidate %d ===" % candidate_id) 285 | logger.info(self.tokenizer.decode(encoded_candidate)) 286 | else: 287 | logger.info("=== Candidate %d (without context)===" % 
candidate_id) 288 | logger.info(self.tokenizer.decode(encoded_candidate).split(self.task.train_sep)[-1]) 289 | logger.info(f"Log probabilities of the option tokens: {selected_log_probs}") 290 | 291 | if self.args.sfc or self.args.icl_sfc: 292 | sfc_selected_log_probs = self.forward(sfc_encoded_candidates[candidate_id], option_len=sfc_option_lens[candidate_id]) 293 | if verbose: 294 | logger.info("=== Candidate %d (without context) SFC ===" % candidate_id) 295 | logger.info(self.tokenizer.decode(sfc_encoded_candidates[candidate_id]).split(self.task.train_sep)[-1]) 296 | logger.info(f"Log probabilities of the option tokens: {sfc_selected_log_probs}") 297 | 298 | outputs.append({"log_probs": selected_log_probs, "sfc_log_probs": sfc_selected_log_probs if self.args.sfc or self.args.icl_sfc else None}) 299 | 300 | if self.args.sfc or self.args.icl_sfc: 301 | # Calibrated probabilities (surface form competition; https://arxiv.org/pdf/2104.08315.pdf) 302 | # log p(candidate | input) = log p_lm(candidate | input) - log p_lm(candidate | sfc prompt) 303 | scores = [x['log_probs'].sum().item() - x['sfc_log_probs'].sum().item() for x in outputs] 304 | else: 305 | # (Default) length-normalized log probabilities 306 | # log p(candidate | input) = log p_lm(candidate | input) / |candidate #tokens| 307 | scores = [x['log_probs'].mean().item() for x in outputs] 308 | 309 | if verbose: 310 | logger.info(f"Prediction scores: {scores}") 311 | 312 | if isinstance(eval_sample.correct_candidate, list): 313 | # For some datasets there are multiple correct answers 314 | correct_candidate_id = [eval_sample.candidates.index(c) for c in eval_sample.correct_candidate] 315 | else: 316 | correct_candidate_id = eval_sample.candidates.index(eval_sample.correct_candidate) 317 | 318 | return Prediction(correct_candidate=correct_candidate_id, predicted_candidate=int(np.argmax(scores))) 319 | 320 | 321 | def evaluate(self, train_samples, eval_samples, one_train_set_per_eval_sample=False): 322 | """ 323 | Evaluate function. If one_train_set_per_eval_sample is True, then each eval sample has its own training (demonstration) set. 
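        Returns a dict keyed by the task's metric name (getattr(self.task, "metric_name", "accuracy")), e.g. {"accuracy": ...} for classification tasks or {"f1": ...} for generation tasks such as SQuAD/DROP.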
324 | """ 325 | if one_train_set_per_eval_sample: 326 | logger.info(f"There are {len(eval_samples)} validation samples and one train set per eval sample") 327 | else: 328 | logger.info(f"There are {len(train_samples)} training samples and {len(eval_samples)} validation samples") 329 | 330 | # Prediction loop 331 | predictions = [] 332 | for eval_id, eval_sample in enumerate(tqdm(eval_samples)): 333 | predictions.append( 334 | self.one_step_pred(train_samples[eval_id] if one_train_set_per_eval_sample else train_samples, eval_sample, verbose=(eval_id < 3)) 335 | ) 336 | 337 | # Calculate metrics 338 | metric_name = getattr(self.task, "metric_name", "accuracy") 339 | metrics = {metric_name: calculate_metric(predictions, metric_name)} 340 | return metrics 341 | 342 | 343 | def train(self, train_samples, eval_samples): 344 | """ 345 | Training function 346 | """ 347 | # Set tokenizer to left padding (so that all the options are right aligned) 348 | self.tokenizer.padding_side = "left" 349 | 350 | class HFDataset(Dataset): 351 | 352 | def __init__(self, data): 353 | self.data = data 354 | 355 | def __len__(self): 356 | return len(self.data) 357 | 358 | def __getitem__(self, idx): 359 | return self.data[idx] 360 | 361 | 362 | def _convert(samples): 363 | """ 364 | Convert samples to HF-compatible dataset 365 | """ 366 | data = [] 367 | for sample in samples: 368 | encoded_candidates, option_lens = encode_prompt( 369 | self.task, self.task.get_template(), [], sample, self.tokenizer, 370 | max_length=self.args.max_length, generation=self.task.generation, generation_with_gold=True, 371 | max_new_tokens=self.args.max_new_tokens 372 | ) 373 | if self.task.generation: 374 | correct_candidate_id = 0 375 | elif isinstance(sample.correct_candidate, list): 376 | correct_candidate_id = sample.candidates.index(sample.correct_candidate[0]) 377 | else: 378 | correct_candidate_id = sample.candidates.index(sample.correct_candidate) 379 | 380 | if self.args.non_diff: 381 | # For non-differentiable objective, there is no teacher forcing thus the 382 | # current answer part is removed 383 | encoded_candidates[correct_candidate_id] = encoded_candidates[correct_candidate_id][:-option_lens[correct_candidate_id]] 384 | 385 | if self.args.train_as_classification: 386 | # For classification, we provide the label as the correct candidate id 387 | data.append([{"input_ids": encoded_candidates[_i], "labels": correct_candidate_id, "option_len": option_lens[_i], "num_options": len(sample.candidates)} for _i in range(len(encoded_candidates))]) 388 | elif self.args.only_train_option: 389 | # Otherwise, it is just LM-style teacher forcing 390 | if self.args.non_diff: 391 | # For non-differentiable objective, we need to provide the gold answer to calculate F1/acc 392 | data.append({"input_ids": encoded_candidates[correct_candidate_id], "labels": encoded_candidates[correct_candidate_id], "option_len": option_lens[correct_candidate_id], "gold": sample.correct_candidate}) 393 | else: 394 | data.append({"input_ids": encoded_candidates[correct_candidate_id], "labels": encoded_candidates[correct_candidate_id], "option_len": option_lens[correct_candidate_id]}) 395 | else: 396 | data.append({"input_ids": encoded_candidates[correct_candidate_id], "labels": encoded_candidates[correct_candidate_id]}) 397 | return data 398 | 399 | with count_time("Tokenizing training samples"): 400 | train_dataset = HFDataset(_convert(train_samples)) 401 | eval_dataset = HFDataset(_convert(eval_samples)) 402 | 403 | if self.args.only_train_option and not 
self.args.non_diff: 404 | # If --only_train_option and not with a non-differentiable objective, we wrap the forward function 405 | self.model.original_forward = self.model.forward 406 | self.model.forward = forward_wrap_with_option_len.__get__(self.model, type(self.model)) 407 | 408 | if self.args.non_diff: 409 | collator = NondiffCollator 410 | else: 411 | collator = DataCollatorForTokenClassification 412 | 413 | trainer = OurTrainer( 414 | model=self.model, 415 | args=self.args, 416 | train_dataset=train_dataset, 417 | eval_dataset=eval_dataset, 418 | tokenizer=self.tokenizer, 419 | data_collator=DataCollatorWithPaddingAndNesting(self.tokenizer, pad_to_multiple_of=8) if self.args.train_as_classification else collator(self.tokenizer, pad_to_multiple_of=8), 420 | ) 421 | if self.args.save_on_interrupt: 422 | trainer.add_callback(SIGUSR1Callback()) 423 | 424 | # Resume training from a last checkpoint 425 | last_checkpoint = None 426 | 427 | from transformers.trainer_utils import get_last_checkpoint 428 | if os.path.isdir(self.args.output_dir) and not self.args.overwrite_output_dir: 429 | last_checkpoint = get_last_checkpoint(self.args.output_dir) 430 | if last_checkpoint is not None and self.args.resume_from_checkpoint is None: 431 | logger.info( 432 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 433 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 434 | ) 435 | if self.args.resume_from_checkpoint is not None: 436 | last_checkpoint = self.args.resume_from_checkpoint 437 | 438 | trainer.train(resume_from_checkpoint=last_checkpoint) 439 | 440 | # Explicitly save the model 441 | if self.args.save_model: 442 | logger.warn("Save model..") 443 | trainer.save_model() 444 | 445 | # FSDP compatibility 446 | self.model = trainer.model 447 | 448 | # Reset the forward function for evaluation 449 | if self.args.only_train_option and not self.args.non_diff: 450 | if type(self.model) == FSDP: 451 | logger.info("This is an FSDP model now. 
Be careful when assigning back the original forward function") 452 | self.model._fsdp_wrapped_module.forward = self.model._fsdp_wrapped_module.original_forward 453 | else: 454 | self.model.forward = self.model.original_forward 455 | 456 | 457 | def result_file_tag(args): 458 | """ 459 | Get the result file tag 460 | """ 461 | save_model_name = args.model_name.split("/")[-1] 462 | sfc_tag = "-sfc" if args.sfc else "" 463 | icl_sfc_tag = "-icl_sfc" if args.icl_sfc else "" 464 | sample_eval_tag = "-sampleeval%d" % args.num_eval if args.num_eval is not None else "" 465 | sample_train_tag = "-ntrain%d" % args.num_train if args.num_train > 0 else "" 466 | sample_dev_tag = "-ndev%d" % args.num_dev if args.num_dev is not None else "" 467 | customized_tag = f"-{args.tag}" if len(args.tag) > 0 else "" 468 | return f"{args.task_name}-{save_model_name}" + sfc_tag + icl_sfc_tag + sample_eval_tag + sample_train_tag + sample_dev_tag + customized_tag 469 | 470 | 471 | def main(): 472 | args = parse_args() 473 | 474 | set_seed(args.seed) 475 | task = get_task(args.task_name) 476 | train_sets = task.sample_train_sets(num_train=args.num_train, num_dev=args.num_dev, num_eval=args.num_eval, num_train_sets=args.num_train_sets, seed=args.train_set_seed) 477 | 478 | # Initialize trainer and load model 479 | framework = Framework(args, task) 480 | if args.train_set_seed is not None or args.num_train_sets is not None: 481 | # Eval samples share one (or multiple) training set(s) 482 | for train_set_id, train_samples in enumerate(train_sets): 483 | train_set_seed = train_set_id if args.train_set_seed is None else args.train_set_seed 484 | 485 | # Sample eval samples 486 | if args.num_eval is not None: 487 | eval_samples = task.sample_subset(data_split="valid", seed=train_set_seed, num=args.num_eval) 488 | else: 489 | eval_samples = task.valid_samples 490 | 491 | if args.trainer != "none": 492 | if args.num_dev is not None: 493 | # Dev samples 494 | dev_samples = train_samples[-args.num_dev:] 495 | train_samples = train_samples[:-args.num_dev] 496 | else: 497 | dev_samples = None 498 | 499 | # Training 500 | framework.train(train_samples, dev_samples if dev_samples is not None else eval_samples) 501 | 502 | if not args.no_eval: 503 | metrics = framework.evaluate([], eval_samples) # No in-context learning if there is training 504 | if dev_samples is not None: 505 | dev_metrics = framework.evaluate([], dev_samples) 506 | for m in dev_metrics: 507 | metrics["dev_" + m] = dev_metrics[m] 508 | else: 509 | assert args.num_dev is None 510 | # Zero-shot / in-context learning 511 | metrics = framework.evaluate(train_samples, eval_samples) 512 | 513 | if not args.no_eval: 514 | logger.info("===== Train set %d =====" % train_set_seed) 515 | logger.info(metrics) 516 | if args.local_rank <= 0: 517 | write_metrics_to_file(metrics, "result/" + result_file_tag(args) + f"-trainset{train_set_id}.json" if args.result_file is None else args.result_file) 518 | 519 | else: 520 | # For each eval sample, there is a training set. 
no training is allowed 521 | # This is for in-context learning (ICL) 522 | assert args.trainer == "none" 523 | if args.num_eval is not None: 524 | eval_samples = task.sample_subset(data_split="valid", seed=0, num=args.num_eval) 525 | else: 526 | eval_samples = task.valid_samples 527 | 528 | metrics = framework.evaluate(train_sets, eval_samples, one_train_set_per_eval_sample=True) 529 | logger.info(metrics) 530 | 531 | if __name__ == "__main__": 532 | main() 533 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import functools 3 | import glob 4 | import inspect 5 | import math 6 | import os 7 | import random 8 | import re 9 | import shutil 10 | import sys 11 | import time 12 | import warnings 13 | from collections.abc import Mapping 14 | from pathlib import Path 15 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 16 | import copy 17 | from metrics import f1 18 | import numpy as np 19 | from tqdm.auto import tqdm 20 | from transformers import Trainer 21 | from sklearn.linear_model import LinearRegression, LogisticRegression, LogisticRegressionCV 22 | 23 | # Integrations must be imported before ML frameworks: 24 | from transformers.integrations import ( 25 | get_reporting_integration_callbacks, 26 | hp_params, 27 | is_optuna_available, 28 | is_ray_tune_available, 29 | is_sigopt_available, 30 | is_wandb_available, 31 | run_hp_search_optuna, 32 | run_hp_search_ray, 33 | run_hp_search_sigopt, 34 | run_hp_search_wandb, 35 | ) 36 | 37 | import numpy as np 38 | import torch 39 | import torch.distributed as dist 40 | from packaging import version 41 | from torch import nn 42 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 43 | from torch.utils.data.distributed import DistributedSampler 44 | from huggingface_hub import Repository 45 | 46 | from transformers import __version__ 47 | from transformers.configuration_utils import PretrainedConfig 48 | from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator 49 | from transformers.debug_utils import DebugOption, DebugUnderflowOverflow 50 | from transformers.dependency_versions_check import dep_version_check 51 | from transformers.modelcard import TrainingSummary 52 | from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model 53 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES 54 | from transformers.optimization import Adafactor, get_scheduler 55 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 56 | from transformers.trainer_callback import ( 57 | CallbackHandler, 58 | DefaultFlowCallback, 59 | PrinterCallback, 60 | ProgressCallback, 61 | TrainerCallback, 62 | TrainerControl, 63 | TrainerState, 64 | ExportableState, 65 | ) 66 | from transformers.trainer_pt_utils import ( 67 | DistributedLengthGroupedSampler, 68 | DistributedSamplerWithLoop, 69 | DistributedTensorGatherer, 70 | IterableDatasetShard, 71 | LabelSmoother, 72 | LengthGroupedSampler, 73 | SequentialDistributedSampler, 74 | ShardSampler, 75 | distributed_broadcast_scalars, 76 | distributed_concat, 77 | find_batch_size, 78 | get_module_class_from_name, 79 | get_parameter_names, 80 | nested_concat, 81 | nested_detach, 82 | nested_numpify, 83 | nested_truncate, 84 | nested_xla_mesh_reduce, 85 | 
reissue_pt_warnings, 86 | ) 87 | from transformers.trainer_utils import ( 88 | PREFIX_CHECKPOINT_DIR, 89 | BestRun, 90 | EvalLoopOutput, 91 | EvalPrediction, 92 | FSDPOption, 93 | HPSearchBackend, 94 | HubStrategy, 95 | IntervalStrategy, 96 | PredictionOutput, 97 | RemoveColumnsCollator, 98 | TrainerMemoryTracker, 99 | TrainOutput, 100 | default_compute_objective, 101 | denumpify_detensorize, 102 | enable_full_determinism, 103 | find_executable_batch_size, 104 | get_last_checkpoint, 105 | has_length, 106 | number_of_arguments, 107 | seed_worker, 108 | set_seed, 109 | speed_metrics, 110 | ) 111 | from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments 112 | from transformers.utils import ( 113 | CONFIG_NAME, 114 | WEIGHTS_INDEX_NAME, 115 | WEIGHTS_NAME, 116 | find_labels, 117 | get_full_repo_name, 118 | is_apex_available, 119 | is_datasets_available, 120 | is_in_notebook, 121 | is_ipex_available, 122 | is_sagemaker_dp_enabled, 123 | is_sagemaker_mp_enabled, 124 | is_torch_tensorrt_fx_available, 125 | is_torch_tpu_available, 126 | is_torchdynamo_available, 127 | logging, 128 | ) 129 | from transformers.utils.generic import ContextManagers 130 | from lr_scheduler import zo_lr_scheduler 131 | from Hessian_smooth_scheduler import Hessian_smooth_scheduler 132 | 133 | DEFAULT_CALLBACKS = [DefaultFlowCallback] 134 | DEFAULT_PROGRESS_CALLBACK = ProgressCallback 135 | 136 | if is_in_notebook(): 137 | from .utils.notebook import NotebookProgressCallback 138 | 139 | DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback 140 | 141 | if is_apex_available(): 142 | from apex import amp 143 | 144 | if is_datasets_available(): 145 | import datasets 146 | 147 | if is_torch_tpu_available(check_device=False): 148 | import torch_xla.core.xla_model as xm 149 | import torch_xla.debug.metrics as met 150 | import torch_xla.distributed.parallel_loader as pl 151 | 152 | if is_sagemaker_mp_enabled(): 153 | import smdistributed.modelparallel.torch as smp 154 | from smdistributed.modelparallel import __version__ as SMP_VERSION 155 | 156 | IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") 157 | 158 | from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat 159 | else: 160 | IS_SAGEMAKER_MP_POST_1_10 = False 161 | 162 | if TYPE_CHECKING: 163 | import optuna 164 | 165 | logger = logging.get_logger(__name__) 166 | # Name of the files used for checkpointing 167 | TRAINING_ARGS_NAME = "training_args.bin" 168 | TRAINER_STATE_NAME = "trainer_state.json" 169 | OPTIMIZER_NAME = "optimizer.pt" 170 | SCHEDULER_NAME = "scheduler.pt" 171 | SCALER_NAME = "scaler.pt" 172 | 173 | 174 | class OurTrainer(Trainer): 175 | 176 | from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state 177 | 178 | def _inner_training_loop( 179 | self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None 180 | ): 181 | """ 182 | We overload the original training loop to add linear probing and MeZO. Search key word "MeZO added" 183 | for those updates. 
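HiZOO additions: self.Hessian_matrix holds a diagonal Hessian estimate, initialized to ones for every trainable parameter, and each training step calls zo_Hessian_step_update instead of a backward pass. Using a shared random seed for z, that step perturbs the parameters by +/- eps * z / sqrt(H), updates the estimate with the moving average H <- (1 - alpha) * H + alpha * |L(theta + dz) + L(theta - dz) - 2 * L(theta)| * H * z^2 / (2 * eps^2) (alpha given by Hessian_smooth_scheduler), and applies the preconditioned zeroth-order update theta <- theta - lr * ((L(theta + dz) - L(theta - dz)) / (2 * eps) * z / sqrt(H) + weight_decay * theta), with lr given by zo_lr_scheduler.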
184 | """ 185 | self._train_batch_size = batch_size 186 | self.do_grad_scaling = self.args.do_grad_scaling 187 | self.best_eval_loss = 100.0 188 | train_dataloader = self.get_train_dataloader() 189 | # Setting up training control variables: 190 | # number of training epochs: num_train_epochs 191 | # number of training steps per epoch: num_update_steps_per_epoch 192 | # total number of training steps to execute: max_steps 193 | total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size #16*1*1 194 | 195 | len_dataloader = None 196 | if has_length(train_dataloader): 197 | len_dataloader = len(train_dataloader) 198 | num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps 199 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) 200 | num_examples = self.num_examples(train_dataloader) 201 | if args.max_steps > 0: 202 | max_steps = args.max_steps #20000 203 | num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( 204 | args.max_steps % num_update_steps_per_epoch > 0 205 | ) #318 206 | # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's 207 | # the best we can do. 208 | num_train_samples = args.max_steps * total_train_batch_size 209 | else: 210 | max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) 211 | num_train_epochs = math.ceil(args.num_train_epochs) 212 | num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs 213 | elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size 214 | max_steps = args.max_steps 215 | # Setting a very large number of epochs so we go as many times as necessary over the iterator. 216 | num_train_epochs = sys.maxsize 217 | num_update_steps_per_epoch = max_steps 218 | num_examples = total_train_batch_size * args.max_steps 219 | num_train_samples = args.max_steps * total_train_batch_size 220 | else: 221 | raise ValueError( 222 | "args.max_steps must be set to a positive value if dataloader does not have a length, was" 223 | f" {args.max_steps}" 224 | ) 225 | 226 | if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: 227 | if self.args.n_gpu > 1: 228 | # nn.DataParallel(model) replicates the model, creating new variables and module 229 | # references registered here no longer work on other gpus, breaking the module 230 | raise ValueError( 231 | "Currently --debug underflow_overflow is not supported under DP. Please use DDP" 232 | " (torch.distributed.launch)." 233 | ) 234 | else: 235 | debug_overflow = DebugUnderflowOverflow(self.model) # noqa 236 | 237 | delay_optimizer_creation = False 238 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 239 | 240 | self.state = TrainerState() 241 | self.state.is_hyper_param_search = trial is not None 242 | 243 | # Activate gradient checkpointing if needed 244 | if args.gradient_checkpointing: 245 | self.model.gradient_checkpointing_enable() 246 | 247 | model = self._wrap_model(self.model_wrapped) 248 | 249 | if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: 250 | self._load_from_checkpoint(resume_from_checkpoint, model) 251 | 252 | # for the rest of this function `model` is the outside model, whether it was wrapped or not 253 | if model is not self.model: 254 | self.model_wrapped = model 255 | 256 | if delay_optimizer_creation: 257 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 258 | self._load_optimizer_and_scheduler(resume_from_checkpoint) 259 | # Train! 
260 | logger.info("***** Running training *****") 261 | logger.info(f" Num examples = {num_examples}") 262 | logger.info(f" Num Epochs = {num_train_epochs}") 263 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 264 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") 265 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 266 | logger.info(f" Total optimization steps = {max_steps}") 267 | logger.info( 268 | f" Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}" 269 | ) 270 | 271 | self.state.epoch = 0 272 | start_time = time.time() 273 | epochs_trained = 0 274 | steps_trained_in_current_epoch = 0 275 | steps_trained_progress_bar = None 276 | 277 | # Check if continuing training from a checkpoint 278 | if resume_from_checkpoint is not None and os.path.isfile( 279 | os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) 280 | ): 281 | self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) 282 | epochs_trained = self.state.global_step // num_update_steps_per_epoch 283 | if not args.ignore_data_skip: 284 | steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) 285 | steps_trained_in_current_epoch *= args.gradient_accumulation_steps 286 | else: 287 | steps_trained_in_current_epoch = 0 288 | 289 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 290 | logger.info(f" Continuing training from epoch {epochs_trained}") 291 | logger.info(f" Continuing training from global step {self.state.global_step}") 292 | if not args.ignore_data_skip: 293 | logger.info( 294 | f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " 295 | "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " 296 | "flag to your launch command, but you will resume the training on data already seen by your model." 297 | ) 298 | if self.is_local_process_zero() and not args.disable_tqdm: 299 | steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) 300 | steps_trained_progress_bar.set_description("Skipping the first batches") 301 | 302 | # Update the references 303 | self.callback_handler.model = self.model 304 | self.callback_handler.optimizer = self.optimizer 305 | self.callback_handler.lr_scheduler = self.lr_scheduler 306 | self.callback_handler.train_dataloader = train_dataloader 307 | if self.hp_name is not None and self._trial is not None: 308 | # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial 309 | # parameter to Train when using DDP. 310 | self.state.trial_name = self.hp_name(self._trial) 311 | if trial is not None: 312 | assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial 313 | self.state.trial_params = hp_params(assignments) 314 | else: 315 | self.state.trial_params = None 316 | # This should be the same if the state has been saved but in case the training arguments changed, it's safer 317 | # to set this after the load. 
318 | self.state.max_steps = max_steps 319 | self.state.num_train_epochs = num_train_epochs 320 | self.state.is_local_process_zero = self.is_local_process_zero() 321 | self.state.is_world_process_zero = self.is_world_process_zero() 322 | 323 | # tr_loss is a tensor to avoid synchronization of TPUs through .item() 324 | tr_loss = torch.tensor(0.0).to(args.device) 325 | # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses 326 | self._total_loss_scalar = 0.0 327 | self._globalstep_last_logged = self.state.global_step 328 | model.zero_grad() 329 | 330 | self.control = self.callback_handler.on_train_begin(args, self.state, self.control) 331 | 332 | # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. 333 | if not args.ignore_data_skip: 334 | for epoch in range(epochs_trained): 335 | is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( 336 | train_dataloader.sampler, RandomSampler 337 | ) 338 | if not is_random_sampler: 339 | # We just need to begin an iteration to create the randomization of the sampler. 340 | # That was before PyTorch 1.11 however... 341 | for _ in train_dataloader: 342 | break 343 | else: 344 | # Otherwise we need to call the whooooole sampler cause there is some random operation added 345 | # AT THE VERY END! 346 | _ = list(train_dataloader.sampler) 347 | 348 | self.Hessian_matrix = {} 349 | for name, param in model.named_parameters(): 350 | if param.requires_grad: 351 | self.Hessian_matrix[name] = torch.ones(size=param.data.size(), device=param.data.device, dtype=param.data.dtype) 352 | for epoch in range(epochs_trained, num_train_epochs): 353 | 354 | 355 | zo_learning_rate = zo_lr_scheduler(self.args.learning_rate, self.args.zo_lr_scheduler_type, self.args.warmup_step, self.args.decay_step, self.state.global_step, int(num_train_epochs)) 356 | Hessian_smooth = Hessian_smooth_scheduler(self.args.hessian_smooth_type, self.state.global_step, int(num_train_epochs)) 357 | 358 | if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): 359 | train_dataloader.sampler.set_epoch(epoch) 360 | elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): 361 | train_dataloader.dataset.set_epoch(epoch) 362 | 363 | if is_torch_tpu_available(): 364 | parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) 365 | epoch_iterator = parallel_loader 366 | else: 367 | epoch_iterator = train_dataloader 368 | 369 | # Reset the past mems state at the beginning of each epoch if necessary. 
370 | if args.past_index >= 0: 371 | self._past = None 372 | 373 | steps_in_epoch = ( 374 | len(epoch_iterator) 375 | if len_dataloader is not None 376 | else args.max_steps * args.gradient_accumulation_steps 377 | ) 378 | self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) 379 | 380 | if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: 381 | self._load_rng_state(resume_from_checkpoint) 382 | 383 | step = -1 384 | for step, inputs in enumerate(epoch_iterator): 385 | # Skip past any already trained steps if resuming training 386 | if steps_trained_in_current_epoch > 0: 387 | steps_trained_in_current_epoch -= 1 388 | if steps_trained_progress_bar is not None: 389 | steps_trained_progress_bar.update(1) 390 | if steps_trained_in_current_epoch == 0: 391 | self._load_rng_state(resume_from_checkpoint) 392 | continue 393 | elif steps_trained_progress_bar is not None: 394 | steps_trained_progress_bar.close() 395 | steps_trained_progress_bar = None 396 | 397 | if step % args.gradient_accumulation_steps == 0: 398 | self.control = self.callback_handler.on_step_begin(args, self.state, self.control) 399 | 400 | tr_loss_step = self.zo_Hessian_step_update(model, inputs, zo_learning_rate, Hessian_smooth) 401 | 402 | if ( 403 | args.logging_nan_inf_filter 404 | and not is_torch_tpu_available() 405 | and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) 406 | ): 407 | # if loss is nan or inf simply add the average of previous logged losses 408 | tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) 409 | else: 410 | tr_loss += tr_loss_step 411 | 412 | self.current_flos += float(self.floating_point_ops(inputs)) 413 | 414 | 415 | self.state.global_step += 1 416 | self.state.epoch = epoch + (step + 1) / steps_in_epoch 417 | self.control = self.callback_handler.on_step_end(args, self.state, self.control) 418 | self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) 419 | 420 | if self.control.should_epoch_stop or self.control.should_training_stop: 421 | break 422 | 423 | self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) 424 | 425 | if DebugOption.TPU_METRICS_DEBUG in self.args.debug: 426 | if is_torch_tpu_available(): 427 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 428 | xm.master_print(met.metrics_report()) 429 | else: 430 | logger.warning( 431 | "You enabled PyTorch/XLA debug metrics but you don't have a TPU " 432 | "configured. Check your training configuration if this is unexpected." 433 | ) 434 | if self.control.should_training_stop: 435 | break 436 | 437 | if args.past_index and hasattr(self, "_past"): 438 | # Clean the state at the end of training 439 | delattr(self, "_past") 440 | 441 | logger.info("\n\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\n\n") 442 | 443 | self._load_best_model() 444 | 445 | # add remaining tr_loss 446 | self._total_loss_scalar += tr_loss.item() 447 | train_loss = self._total_loss_scalar / self.state.global_step 448 | 449 | metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) 450 | self.store_flos() 451 | metrics["total_flos"] = self.state.total_flos 452 | metrics["train_loss"] = train_loss 453 | 454 | self.is_in_train = False 455 | 456 | self._memory_tracker.stop_and_update_metrics(metrics) 457 | 458 | self.log(metrics) 459 | 460 | run_dir = self._get_output_dir(trial) 461 | checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) 462 | 463 | # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint. 464 | if self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: 465 | for checkpoint in checkpoints_sorted: 466 | if checkpoint != self.state.best_model_checkpoint: 467 | logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") 468 | shutil.rmtree(checkpoint) 469 | logger.info(f"Best model checkpoint [{self.state.best_model_checkpoint}]") 470 | self.control = self.callback_handler.on_train_end(args, self.state, self.control) 471 | 472 | return TrainOutput(self.state.global_step, train_loss, metrics) 473 | 474 | 475 | def efficient_Hessian_perturb_parameters(self, model: nn.Module, random_seed, Hessian_matrix=None, scaling_factor=1): 476 | torch.manual_seed(random_seed) 477 | for name, param in self.named_parameters_to_optim: 478 | z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype) 479 | param.data = param.data + scaling_factor / torch.sqrt(Hessian_matrix[name]) * z * self.args.zo_eps 480 | return model 481 | 482 | def zo_Hessian_step_update(self, model, inputs, zo_learning_rate, Hessian_smooth): 483 | 484 | self.named_parameters_to_optim = [] 485 | for name, param in model.named_parameters(): 486 | if param.requires_grad: 487 | self.named_parameters_to_optim.append((name, param)) 488 | 489 | random_seed = np.random.randint(1000000000) 490 | with torch.no_grad(): 491 | loss_original = self.zo_forward(model, inputs) 492 | 493 | # first function evaluation 494 | model = self.efficient_Hessian_perturb_parameters(model, random_seed, self.Hessian_matrix, scaling_factor=1) 495 | loss1 = self.zo_forward(model, inputs) 496 | 497 | # second function evaluation 498 | model = self.efficient_Hessian_perturb_parameters(model, random_seed, self.Hessian_matrix, scaling_factor=-2) 499 | loss2 = self.zo_forward(model, inputs) 500 | 501 | model = self.efficient_Hessian_perturb_parameters(model, random_seed, self.Hessian_matrix, scaling_factor=1) 502 | 503 | torch.manual_seed(random_seed) 504 | for name, param in self.named_parameters_to_optim: 505 | z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype) 506 | 507 | Hessian_temp = self.Hessian_matrix[name] * z * z 508 | Hessian_estimator = (torch.abs(loss1+loss2-2 * loss_original)* Hessian_temp * Hessian_smooth /(2 * self.args.zo_eps*self.args.zo_eps)) 509 | 510 | self.Hessian_matrix[name] = ((1-Hessian_smooth) * self.Hessian_matrix[name] + Hessian_estimator) 511 | 512 | grad = (loss1-loss2)/(2 * self.args.zo_eps) * z / torch.sqrt(self.Hessian_matrix[name]) 513 | param.data = param.data - zo_learning_rate * (grad + self.args.weight_decay * 
param.data) 514 | 515 | loss_out = self.zo_forward(model, inputs) 516 | return loss_out 517 | 518 | 519 | 520 | def zo_perturb_parameters(self, random_seed=None, scaling_factor=1): 521 | """ 522 | Perturb the parameters with random vector z. 523 | Input: 524 | - random_seed: random seed for MeZO in-place perturbation (if it's None, we will use self.zo_random_seed) 525 | - scaling_factor: theta = theta + scaling_factor * z * eps 526 | """ 527 | 528 | # Set the random seed to ensure that we sample the same z for perturbation/update 529 | torch.manual_seed(random_seed if random_seed is not None else self.zo_random_seed) 530 | 531 | for name, param in self.named_parameters_to_optim: 532 | z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype) 533 | param.data = param.data + scaling_factor * z * self.args.zo_eps 534 | 535 | 536 | def zo_forward(self, model, inputs): 537 | """ 538 | Get (no gradient) loss from the model. Dropout is turned off too. 539 | """ 540 | model.eval() 541 | with torch.inference_mode(): 542 | inputs = self._prepare_inputs(inputs) 543 | with self.compute_loss_context_manager(): 544 | loss = self.compute_loss(model, inputs) 545 | if self.args.n_gpu > 1: 546 | # Warning: this is copied from the original Huggingface Trainer. Untested. 547 | loss = loss.mean() # mean() to average on multi-gpu parallel training 548 | return loss.detach() 549 | 550 | def zo_step(self, model, inputs): 551 | """ 552 | Estimate gradient by MeZO. Return the loss from f(theta + z) 553 | """ 554 | args = self.args 555 | 556 | # What parameters to optimize 557 | self.named_parameters_to_optim = [] 558 | for name, param in model.named_parameters(): 559 | if param.requires_grad: 560 | self.named_parameters_to_optim.append((name, param)) 561 | 562 | # Sample the random seed for sampling z 563 | self.zo_random_seed = np.random.randint(1000000000) 564 | 565 | # First function evaluation 566 | self.zo_perturb_parameters(scaling_factor=1) 567 | loss1 = self.zo_forward(model, inputs) 568 | 569 | # Second function evaluation 570 | self.zo_perturb_parameters(scaling_factor=-2) 571 | loss2 = self.zo_forward(model, inputs) 572 | 573 | self.projected_grad = ((loss1 - loss2) / (2 * self.args.zo_eps)).item() 574 | 575 | # No gradient accumulation support 576 | assert self.args.gradient_accumulation_steps == 1 577 | 578 | # Reset model back to its parameters at start of step 579 | self.zo_perturb_parameters(scaling_factor=1) 580 | 581 | return loss1 582 | 583 | 584 | def zo_update(self, model): 585 | args = self.args 586 | # Reset the random seed for sampling zs 587 | torch.manual_seed(self.zo_random_seed) 588 | for name, param in self.named_parameters_to_optim: 589 | # Resample z 590 | z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype) 591 | if "bias" not in name and "layer_norm" not in name and "layernorm" not in name: 592 | param.data = param.data - self._get_learning_rate() * (self.projected_grad * z + args.weight_decay * param.data) 593 | else: 594 | param.data = param.data - self._get_learning_rate() * (self.projected_grad * z) 595 | self.lr_scheduler.step() 596 | 597 | 598 | ############## Misc overload functions ############## 599 | def _set_signature_columns_if_needed(self): 600 | """ 601 | We overload this function for non-differentiable objective training to pass "gold" -- the gold text for the task 602 | """ 603 | if self._signature_columns is None: 604 | # Inspect model forward signature to keep only 
the arguments it accepts. 605 | signature = inspect.signature(self.model.forward) 606 | self._signature_columns = list(signature.parameters.keys()) 607 | # Labels may be named label or label_ids, the default data collator handles that. 608 | self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) 609 | self._signature_columns += ["gold"] 610 | 611 | def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): 612 | if output_dir is None: 613 | output_dir = self.args.output_dir 614 | 615 | if is_torch_tpu_available(): 616 | self._save_tpu(output_dir) 617 | elif self.args.should_save: 618 | self._save(output_dir) 619 | 620 | # Push to the Hub when `save_model` is called by the user. 621 | if self.args.push_to_hub and not _internal_call: 622 | self.push_to_hub(commit_message="Model save") 623 | 624 | def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): 625 | 626 | if self.state.global_step%10==0: 627 | logs: Dict[str, float] = {} 628 | tr_loss_scalar = self._nested_gather(tr_loss).mean().item() 629 | tr_loss -= tr_loss 630 | logs["train_loss"] = round(tr_loss_scalar / (10), 4) 631 | logs["learning_rate"] = self._get_learning_rate() 632 | self._total_loss_scalar += tr_loss_scalar 633 | self._globalstep_last_logged = self.state.global_step 634 | self.store_flos() 635 | self.log(logs) 636 | 637 | metrics = None 638 | if self.state.global_step%self.args.save_steps==0: 639 | metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) 640 | self._report_to_hp_search(trial, self.state.global_step, metrics) 641 | 642 | # Run delayed LR scheduler now that metrics are populated 643 | if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 644 | metric_to_check = self.args.metric_for_best_model 645 | if not metric_to_check.startswith("eval_"): 646 | metric_to_check = f"eval_{metric_to_check}" 647 | self.lr_scheduler.step(metrics[metric_to_check]) 648 | 649 | if metrics['eval_loss']