├── README.md ├── image_classification ├── LRScheduler.py ├── bypass_bn.py ├── config │ ├── cifar10 │ │ ├── resnet20 │ │ │ ├── adahessian.sh │ │ │ ├── adamw.sh │ │ │ ├── msassha.sh │ │ │ ├── sam_sgd.sh │ │ │ ├── sassha.sh │ │ │ ├── sgd.sh │ │ │ ├── shampoo.sh │ │ │ └── sophiah.sh │ │ └── resnet32 │ │ │ ├── adahessian.sh │ │ │ ├── adamw.sh │ │ │ ├── msassha.sh │ │ │ ├── sam_sgd.sh │ │ │ ├── sassha.sh │ │ │ ├── sgd.sh │ │ │ ├── shampoo.sh │ │ │ └── sophiah.sh │ ├── cifar100 │ │ ├── resnet32 │ │ │ ├── adahessian.sh │ │ │ ├── adamw.sh │ │ │ ├── msassha.sh │ │ │ ├── sam_sgd.sh │ │ │ ├── sassha.sh │ │ │ ├── sgd.sh │ │ │ ├── shampoo.sh │ │ │ └── sophiah.sh │ │ └── wideresnet │ │ │ ├── adahessian.sh │ │ │ ├── adamw.sh │ │ │ ├── msassha.sh │ │ │ ├── sam_sgd.sh │ │ │ ├── sassha.sh │ │ │ ├── sgd.sh │ │ │ ├── shampoo.sh │ │ │ └── sophiah.sh │ └── imagenet │ │ ├── resnet50 │ │ ├── adahessian.sh │ │ ├── adamw.sh │ │ ├── msassha.sh │ │ ├── sam_sgd.sh │ │ ├── sassha.sh │ │ ├── sgd.sh │ │ └── sophiah.sh │ │ └── vit_small │ │ ├── adahessian.sh │ │ ├── adamw.sh │ │ ├── msassha.sh │ │ ├── sam_adamw.sh │ │ ├── sassha.sh │ │ ├── sgd.sh │ │ └── sophiah.sh ├── models │ ├── __init__.py │ ├── resnet.py │ ├── simple_vit.py │ └── wide_resnet.py ├── train.py └── utils.py ├── label_noise ├── bypass_bn.py ├── config │ ├── cifar10 │ │ ├── adahessian │ │ │ ├── noise20.sh │ │ │ ├── noise40.sh │ │ │ └── noise60.sh │ │ ├── msassha │ │ │ ├── noise20.sh │ │ │ ├── noise40.sh │ │ │ └── noise60.sh │ │ ├── sam_sgd │ │ │ ├── noise20.sh │ │ │ ├── noise40.sh │ │ │ └── noise60.sh │ │ ├── sassha │ │ │ ├── noise20.sh │ │ │ ├── noise40.sh │ │ │ └── noise60.sh │ │ ├── sgd │ │ │ ├── noise20.sh │ │ │ ├── noise40.sh │ │ │ └── noise60.sh │ │ ├── shampoo │ │ │ ├── noise20.sh │ │ │ ├── noise40.sh │ │ │ └── noise60.sh │ │ └── sophiah │ │ │ ├── noise20.sh │ │ │ ├── noise40.sh │ │ │ └── noise60.sh │ └── cifar100 │ │ ├── adahessian │ │ ├── noise20.sh │ │ ├── noise40.sh │ │ └── noise60.sh │ │ ├── msassha │ │ ├── noise20.sh │ │ ├── noise40.sh │ │ └── noise60.sh │ │ ├── sam_sgd │ │ ├── noise20.sh │ │ ├── noise40.sh │ │ └── noise60.sh │ │ ├── sassha │ │ ├── noise20.sh │ │ ├── noise40.sh │ │ └── noise60.sh │ │ ├── sgd │ │ ├── noise20.sh │ │ ├── noise40.sh │ │ └── noise60.sh │ │ ├── shampoo │ │ ├── noise20.sh │ │ ├── noise40.sh │ │ └── noise60.sh │ │ └── sophiah │ │ ├── noise20.sh │ │ ├── noise40.sh │ │ └── noise60.sh ├── models │ └── resnet.py ├── train.py └── utils.py ├── language_tasks ├── README.md ├── config │ ├── adahessian │ │ ├── mnli.sh │ │ ├── mrpc.sh │ │ ├── qnli.sh │ │ ├── qqp.sh │ │ ├── rte.sh │ │ ├── sst2.sh │ │ └── stsb.sh │ ├── adamw │ │ ├── mnli.sh │ │ ├── mrpc.sh │ │ ├── qnli.sh │ │ ├── qqp.sh │ │ ├── rte.sh │ │ ├── sst2.sh │ │ └── stsb.sh │ ├── msassha │ │ ├── mnli.sh │ │ ├── mrpc.sh │ │ ├── qnli.sh │ │ ├── qqp.sh │ │ ├── rte.sh │ │ ├── sst2.sh │ │ └── stsb.sh │ ├── sam_adamw │ │ ├── mnli.sh │ │ ├── mrpc.sh │ │ ├── qnli.sh │ │ ├── qqp.sh │ │ ├── rte.sh │ │ ├── sst2.sh │ │ └── stsb.sh │ ├── sassha │ │ ├── mnli.sh │ │ ├── mrpc.sh │ │ ├── qnli.sh │ │ ├── qqp.sh │ │ ├── rte.sh │ │ ├── sst2.sh │ │ └── stsb.sh │ ├── sophiah │ │ ├── mnli.sh │ │ ├── mrpc.sh │ │ ├── qnli.sh │ │ ├── qqp.sh │ │ ├── rte.sh │ │ ├── sst2.sh │ │ └── stsb.sh │ └── wandb │ │ └── rte_sophia.sh └── finetune.py ├── optimizers ├── __init__.py ├── adahessian.py ├── hessian_scheduler.py ├── msassha.py ├── sam.py ├── sassha.py ├── shampoo.py └── sophiaH.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # 
SASSHA: Sharpness-aware Adaptive Second-order Optimization with Stable Hessian Approximation 2 | 3 | This repository contains the PyTorch source code for the arXiv paper [SASSHA: Sharpness-aware Adaptive Second-order Optimization With Stable Hessian Approximation](https://arxiv.org/abs/2502.18153) by Dahun Shin*, [Dongyeop Lee](https://edong6768.github.io/)*, Jinseok Chung, and [Namhoon Lee](https://namhoonlee.github.io/). 4 | 5 | ## Introduction 6 | 7 | SASSHA is a novel second-order method designed to enhance generalization by explicitly reducing the sharpness of the solution while stabilizing the computation of approximate Hessians along the optimization trajectory. 8 | 9 | This PyTorch implementation supports various tasks, including image classification, fine-tuning, and label-noise experiments. 10 | 11 | For a detailed explanation of the SASSHA algorithm, please refer to [our paper](https://arxiv.org/pdf/2502.18153). 12 | 13 | 14 | ## Getting Started 15 | 16 | First, clone our repository to your local system: 17 | ```bash 18 | git clone https://github.com/LOG-postech/Sassha.git 19 | cd Sassha 20 | ``` 21 | 22 | We recommend using Anaconda to set up the environment and install all necessary dependencies: 23 | 24 | ```bash 25 | conda create -n "sassha" python=3.9 26 | conda activate sassha 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | Ensure you are using Python 3.9 or later. 31 | 32 | Navigate to the example folder of your choice. For instance, to run an image classification experiment: 33 | 34 | ```bash 35 | cd image_classification 36 | ``` 37 | 38 | Now, train the model with the following command: 39 | ```bash 40 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \ 41 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \ 42 | --optimizer sassha \ 43 | --lr 0.3 --wd 1e-4 --rho 0.2 --lazy_hessian 10 --seed 0 \ 44 | --project_name sassha \ 45 | {enter/your/imagenet-folder/with/train_and_val_data} 46 | ``` 47 | 48 | Here, replace `{enter/your/imagenet-folder/with/train_and_val_data}` with the path to your ImageNet folder containing the train and val data. 49 | 50 | ### Distributed Training (Single node, multiple GPUs) 51 | SASSHA is fully compatible with multi-GPU environments for distributed training. Use the following command to train a model across multiple GPUs on a single node: 52 | ```bash 53 | python train.py --dist-url 'tcp://127.0.0.1:23456' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 \ 54 | --workers 4 --dataset imagenet -a vit_b_32 --epochs 90 -b 1024 \ 55 | --LRScheduler cosine --warmup_epochs 8 \ 56 | --optimizer sassha \ 57 | --lr 0.6 --wd 2e-4 --rho 0.25 --lazy_hessian 10 --eps 1e-6 --seed 0 \ 58 | --project_name sassha \ 59 | {enter/your/imagenet-folder/with/train_and_val_data} 60 | ``` 61 | Ensure that NCCL is properly configured on your system and that your GPUs are available before running the script. 62 | 63 | ### Reproducing Paper Results 64 | Configurations used in [our paper](https://arxiv.org/pdf/2502.18153) are provided as shell scripts in each example folder. 65 | 66 | ### Environments 67 | - CUDA 11.6.2 68 | - Python 3.9 69 | 70 | ## General Usage 71 | 72 | SASSHA can be imported and used as follows: 73 | 74 | ```python 75 | from optimizers import SASSHA 76 | 77 | ... 78 | 79 | # Initialize your model and optimizer 80 | model = YourModel() 81 | optimizer = SASSHA(model.parameters(), ...) 82 | 83 
84 | 85 | # training loop 86 | for input, output in data: 87 | 88 | # first forward-backward pass 89 | loss = loss_function(output, model(input)) 90 | loss.backward() 91 | optimizer.perturb_weights(zero_grad=True) 92 | 93 | # second forward-backward pass 94 | loss_function(output, model(input)).backward(create_graph=True) 95 | optimizer.unperturb() 96 | optimizer.step() 97 | optimizer.zero_grad() 98 | 99 | ... 100 | ``` 101 | 102 | ## Citation 103 | ```bibtex 104 | @article{shin2025sassha, 105 | title={SASSHA: Sharpness-aware Adaptive Second-order Optimization With Stable Hessian Approximation}, 106 | author={Shin, Dahun and Lee, Dongyeop and Chung, Jinseok and Lee, Namhoon}, 107 | journal={arXiv preprint arXiv:2502.18153}, 108 | year={2025} 109 | } 110 | ``` -------------------------------------------------------------------------------- /image_classification/LRScheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | class CosineAnnealingWarmupRestarts(_LRScheduler): 6 | """ 7 | optimizer (Optimizer): Wrapped optimizer. 8 | first_cycle_steps (int): First cycle step size. 9 | cycle_mult(float): Cycle steps magnification. Default: -1. 10 | max_lr(float): First cycle's max learning rate. Default: 0.1. 11 | min_lr(float): Min learning rate. Default: 0.001. 12 | warmup_steps(int): Linear warmup step size. Default: 0. 13 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 14 | last_epoch (int): The index of last epoch. Default: -1. 15 | """ 16 | 17 | def __init__(self, 18 | optimizer : torch.optim.Optimizer, 19 | first_cycle_steps : int, 20 | cycle_mult : float = 1., 21 | max_lr : float = 0.1, 22 | min_lr : float = 0.001, 23 | warmup_steps : int = 0, 24 | gamma : float = 1., 25 | last_epoch : int = -1 26 | ): 27 | assert warmup_steps < first_cycle_steps 28 | 29 | self.first_cycle_steps = first_cycle_steps # first cycle step size 30 | self.cycle_mult = cycle_mult # cycle steps magnification 31 | self.base_max_lr = max_lr # first max learning rate 32 | self.max_lr = max_lr # max learning rate in the current cycle 33 | self.min_lr = min_lr # min learning rate 34 | self.warmup_steps = warmup_steps # warmup step size 35 | self.gamma = gamma # decrease rate of max learning rate by cycle 36 | 37 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 38 | self.cycle = 0 # cycle count 39 | self.step_in_cycle = last_epoch # step size of the current cycle 40 | 41 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) 42 | 43 | # set learning rate min_lr 44 | self.init_lr() 45 | 46 | def init_lr(self): 47 | self.base_lrs = [] 48 | for param_group in self.optimizer.param_groups: 49 | param_group['lr'] = self.min_lr 50 | self.base_lrs.append(self.min_lr) 51 | 52 | def get_lr(self): 53 | if self.step_in_cycle == -1: 54 | return self.base_lrs 55 | elif self.step_in_cycle < self.warmup_steps: 56 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs] 57 | else: 58 | return [base_lr + (self.max_lr - base_lr) \ 59 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \ 60 | / (self.cur_cycle_steps - self.warmup_steps))) / 2 61 | for base_lr in self.base_lrs] 62 | 63 | def step(self, epoch=None): 64 | if epoch is None: 65 | epoch = self.last_epoch + 1 66 | self.step_in_cycle = self.step_in_cycle + 1 67 | if self.step_in_cycle >= self.cur_cycle_steps: 68 | 
self.cycle += 1 69 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 70 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 71 | else: 72 | if epoch >= self.first_cycle_steps: 73 | if self.cycle_mult == 1.: 74 | self.step_in_cycle = epoch % self.first_cycle_steps 75 | self.cycle = epoch // self.first_cycle_steps 76 | else: 77 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 78 | self.cycle = n 79 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 80 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 81 | else: 82 | self.cur_cycle_steps = self.first_cycle_steps 83 | self.step_in_cycle = epoch 84 | 85 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 86 | self.last_epoch = math.floor(epoch) 87 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 88 | param_group['lr'] = lr -------------------------------------------------------------------------------- /image_classification/bypass_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.batchnorm import _BatchNorm 4 | 5 | 6 | def disable_running_stats(model): 7 | def _disable(module): 8 | if isinstance(module, _BatchNorm): 9 | module.backup_momentum = module.momentum 10 | module.momentum = 0 11 | 12 | model.apply(_disable) 13 | 14 | def enable_running_stats(model): 15 | def _enable(module): 16 | if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"): 17 | module.momentum = module.backup_momentum 18 | 19 | model.apply(_enable) 20 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet20/adahessian.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer adahessian \ 4 | --lr 0.15 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet20/adamw.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer adamw \ 4 | --lr 0.01 --wd 5e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet20/msassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer msassha \ 4 | --lr 0.3 --wd 1e-3 --rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet20/sam_sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer samsgd \ 4 | --lr 0.1 
--wd 5e-4 --rho 0.15 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet20/sassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer sassha \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.15 --min_lr 0.0015 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --eps 1e-6 --seed 0, 1, 2 \ 6 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet20/sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer sgd \ 4 | --lr 0.1 --wd 5e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet20/shampoo.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer shampoo \ 4 | --lr 0.8 --wd 5e-3 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet20/sophiah.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer sophiah \ 4 | --lr 1e-3 --wd 1e-3 --lazy_hessian 1 --clip_threshold 0.01 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet32/adahessian.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer adahessian \ 4 | --lr 0.15 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet32/adamw.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer adamw \ 4 | --lr 0.01 --wd 5e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet32/msassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer msassha \ 4 | --lr 0.15 --wd 1e-3 --rho 0.6 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet32/sam_sgd.sh: 
-------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer samsgd \ 4 | --lr 0.1 --wd 5e-4 --rho 0.15 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet32/sassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer sassha \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.15 --min_lr 0.0015 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 6 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet32/sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer sgd \ 4 | --lr 0.1 --wd 5e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet32/shampoo.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer shampoo \ 4 | --lr 0.4 --wd 1e-2 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar10/resnet32/sophiah.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer sophiah \ 4 | --lr 1e-3 --wd 2e-4 --lazy_hessian 1 --clip_threshold 0.1 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/resnet32/adahessian.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer adahessian \ 4 | --lr 0.15 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/resnet32/adamw.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer adamw \ 4 | --lr 0.01 --wd 5e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/resnet32/msassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer msassha \ 4 | --lr 0.3 
--wd 1e-3 --rho 0.6 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/resnet32/sam_sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer samsgd \ 4 | --lr 0.1 --wd 5e-4 --rho 0.2 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/resnet32/sassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer sassha \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.3 --min_lr 0.003 --wd 1e-3 --rho 0.25 --lazy_hessian 10 --eps 1e-6 --seed 0, 1, 2 \ 6 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/resnet32/sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer sgd \ 4 | --lr 0.1 --wd 5e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/resnet32/shampoo.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer shampoo \ 4 | --lr 1 --wd 5e-3 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/resnet32/sophiah.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \ 3 | --optimizer sophiah \ 4 | --lr 1e-3 --wd 2e-4 --lazy_hessian 1 --clip_threshold 0.05 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/wideresnet/adahessian.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \ 3 | --optimizer adahessian \ 4 | --lr 0.3 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/wideresnet/adamw.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \ 3 | --optimizer adamw \ 4 | --lr 1e-3 --wd 5e-4 --seed 0, 1, 2 -------------------------------------------------------------------------------- /image_classification/config/cifar100/wideresnet/msassha.sh: 
-------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \ 3 | --optimizer msassha \ 4 | --lr 0.15 --wd 1e-3 --rho 0.25 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/wideresnet/sam_sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \ 3 | --optimizer samsgd \ 4 | --lr 0.1 --wd 1e-3 --rho 0.15 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/wideresnet/sassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \ 3 | --optimizer sassha \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.15 --min_lr 0.0015 --wd 0.0015 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 -------------------------------------------------------------------------------- /image_classification/config/cifar100/wideresnet/sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \ 3 | --optimizer sgd \ 4 | --lr 0.1 --wd 5e-4 --seed 0, 1, 2 -------------------------------------------------------------------------------- /image_classification/config/cifar100/wideresnet/shampoo.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \ 3 | --optimizer shampoo \ 4 | --lr 1 --wd 5e-3 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/cifar100/wideresnet/sophiah.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \ 3 | --optimizer sophiah \ 4 | --lr 1e-3 --wd 1e-3 --lazy_hessian 1 --clip_threshold 0.01 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /image_classification/config/imagenet/resnet50/adahessian.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \ 2 | --LRScheduler plateau \ 3 | --optimizer adahessian \ 4 | --lr 0.15 --wd 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/resnet50/adamw.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \ 2 | 
--LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \ 3 | --optimizer adamw \ 4 | --lr 1e-3 --wd 1e-4 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/resnet50/msassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \ 3 | --optimizer msassha \ 4 | --lr 0.15 --wd 1e-4 --rho 0.1 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/resnet50/sam_sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \ 3 | --optimizer samsgd \ 4 | --lr 0.1 --wd 1e-4 --rho 0.1 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/resnet50/sassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \ 3 | --optimizer sassha \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.3 --min_lr 0.003 --wd 1e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \ 6 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/resnet50/sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \ 3 | --optimizer sgd \ 4 | --lr 0.1 --wd 1e-4 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/resnet50/sophiah.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \ 2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \ 3 | --optimizer sophiah \ 4 | --lr 1e-2 --wd 1e-4 --lazy_hessian 1 --clip_threshold 0.1 --eps 1e-4 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/vit_small/adahessian.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \ 2 | --LRScheduler cosine --warmup_epochs 8 \ 3 | --optimizer adahessian \ 4 | --lr 0.15 --wd 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/vit_small/adamw.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \ 2 | --LRScheduler cosine --warmup_epochs 8 \ 3 | --optimizer adamw \ 4 | --lr 
1e-3 --wd 1e-4 --grad_clip 1 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/vit_small/msassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \ 2 | --LRScheduler cosine --warmup_epochs 8 \ 3 | --optimizer msassha \ 4 | --lr 0.6 --wd 4e-4 --rho 0.4 --lazy_hessian 10 --eps 1e-6 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/vit_small/sam_adamw.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \ 2 | --LRScheduler cosine --warmup_epochs 8 \ 3 | --optimizer samadamw \ 4 | --lr 1e-3 --wd 1e-4 --rho 0.2 --grad_clip 1 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/vit_small/sassha.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \ 2 | --LRScheduler cosine --warmup_epochs 8 \ 3 | --optimizer sassha \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.6 --min_lr 0.0011981 --wd 2e-4 --rho 0.25 --lazy_hessian 10 --eps 1e-6 --seed 0, 1, 2 \ 6 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/vit_small/sgd.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \ 2 | --LRScheduler cosine --warmup_epochs 8 \ 3 | --optimizer sgd \ 4 | --lr 0.1 --wd 1e-4 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/config/imagenet/vit_small/sophiah.sh: -------------------------------------------------------------------------------- 1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \ 2 | --LRScheduler cosine --warmup_epochs 8 \ 3 | --optimizer sophiah \ 4 | --lr 1e-3 --wd 1e-4 --lazy_hessian 1 --clip_threshold 0.01 --eps 1e-4 --seed 0, 1, 2 \ 5 | /home/shared/dataset/imagenet -------------------------------------------------------------------------------- /image_classification/models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.simple_vit import SimpleViT 2 | from models.resnet import * 3 | from models.wide_resnet import * 4 | 5 | import re 6 | from torchvision import models 7 | 8 | def create_vit_model(arch): 9 | """Create a Vision Transformer (ViT) model.""" 10 | patch_size = int(arch.split('_')[2]) 11 | if arch.startswith('vit_s_'): 12 | return SimpleViT( 13 | image_size=224, 14 | patch_size=patch_size, 15 | num_classes=1000, 16 | dim=384, 17 | depth=12, 18 | heads=6, 19 | mlp_dim=1536 20 | ) 21 | elif arch.startswith('vit_b_'): 22 | return SimpleViT( 23 | image_size=224, 24 | patch_size=patch_size, 25 | num_classes=1000, 26 | dim=768, 27 | depth=12, 28 | heads=12, 29 | mlp_dim=3072 30 | ) 31 | elif arch.startswith('vit_l_'): 32 | return SimpleViT( 33 | image_size=224, 34 | patch_size=patch_size, 
35 | num_classes=1000, 36 | dim=1024, 37 | depth=24, 38 | heads=16, 39 | mlp_dim=4096 40 | ) 41 | else: 42 | raise ValueError(f"Unknown ViT architecture: {arch}") 43 | 44 | def create_resnet_model(arch, dataset): 45 | """Create a ResNet model.""" 46 | num_classes = 1000 if dataset == 'imagenet' else int(re.findall(r'\d+', dataset)[-1]) 47 | depth = int(re.findall(r'\d+', arch)[-1]) 48 | return resnet(num_classes=num_classes, depth=depth) 49 | 50 | def create_wideresnet_model(arch, dataset): 51 | """Create a Wide ResNet model.""" 52 | num_classes = 1000 if dataset == 'imagenet' else int(re.findall(r'\d+', dataset)[-1]) 53 | parts = arch.split('_') 54 | depth = int(parts[-2]) 55 | widen_factor = int(parts[-1]) 56 | return Wide_ResNet(depth=depth, widen_factor=widen_factor, dropout_rate=0.0, num_classes=num_classes) 57 | 58 | def get_model(args): 59 | """Main function to create a model based on the architecture.""" 60 | 61 | if args.arch.startswith('vit_'): 62 | return create_vit_model(args.arch) 63 | elif args.arch in ['resnet20', 'resnet32']: 64 | return create_resnet_model(args.arch, args.dataset) 65 | elif args.arch.startswith('wideresnet_'): 66 | return create_wideresnet_model(args.arch, args.dataset) 67 | else: 68 | # Default case: PyTorch model loaded directly 69 | return models.__dict__[args.arch]() 70 | -------------------------------------------------------------------------------- /image_classification/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import math 12 | from copy import deepcopy 13 | 14 | __all__ = ['resnet'] 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__( 27 | self, 28 | inplanes, 29 | planes, 30 | residual_not, 31 | batch_norm_not, 32 | stride=1, 33 | downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.residual_not = residual_not 36 | self.batch_norm_not = batch_norm_not 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | if self.batch_norm_not: 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | if self.batch_norm_not: 43 | self.bn2 = nn.BatchNorm2d(planes) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | if self.batch_norm_not: 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | if self.batch_norm_not: 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | if self.residual_not: 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__( 72 | self, 73 | inplanes, 74 | planes, 75 | residual_not, 76 | batch_norm_not, 77 | stride=1, 78 | downsample=None): 79 | super(Bottleneck, self).__init__() 80 | self.residual_not = residual_not 81 | self.batch_norm_not = batch_norm_not 82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, 
bias=False) 83 | if self.batch_norm_not: 84 | self.bn1 = nn.BatchNorm2d(planes) 85 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 86 | padding=1, bias=False) 87 | if self.batch_norm_not: 88 | self.bn2 = nn.BatchNorm2d(planes) 89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 90 | if self.batch_norm_not: 91 | self.bn3 = nn.BatchNorm2d(planes * 4) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.downsample = downsample 94 | self.stride = stride 95 | 96 | def forward(self, x): 97 | residual = x 98 | 99 | out = self.conv1(x) 100 | if self.batch_norm_not: 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | if self.batch_norm_not: 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | if self.batch_norm_not: 111 | out = self.bn3(out) 112 | 113 | if self.downsample is not None: 114 | residual = self.downsample(x) 115 | if self.residual_not: 116 | out += residual 117 | 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | 123 | ALPHA_ = 1 124 | 125 | 126 | class ResNet(nn.Module): 127 | 128 | def __init__( 129 | self, 130 | depth, 131 | residual_not=True, 132 | batch_norm_not=True, 133 | base_channel=16, 134 | num_classes=10): 135 | super(ResNet, self).__init__() 136 | # Model type specifies number of layers for CIFAR-10 model 137 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 138 | n = (depth - 2) // 6 139 | 140 | # block = Bottleneck if depth >=44 else BasicBlock 141 | block = BasicBlock 142 | 143 | self.base_channel = int(base_channel) 144 | self.residual_not = residual_not 145 | self.batch_norm_not = batch_norm_not 146 | self.inplanes = self.base_channel * ALPHA_ 147 | self.conv1 = nn.Conv2d( 148 | 3, 149 | self.base_channel * 150 | ALPHA_, 151 | kernel_size=3, 152 | padding=1, 153 | bias=False) 154 | if self.batch_norm_not: 155 | self.bn1 = nn.BatchNorm2d(self.base_channel * ALPHA_) 156 | self.relu = nn.ReLU(inplace=True) 157 | self.layer1 = self._make_layer( 158 | block, 159 | self.base_channel * 160 | ALPHA_, 161 | n, 162 | self.residual_not, 163 | self.batch_norm_not) 164 | self.layer2 = self._make_layer( 165 | block, 166 | self.base_channel * 167 | 2 * 168 | ALPHA_, 169 | n, 170 | self.residual_not, 171 | self.batch_norm_not, 172 | stride=2) 173 | self.layer3 = self._make_layer( 174 | block, 175 | self.base_channel * 176 | 4 * 177 | ALPHA_, 178 | n, 179 | self.residual_not, 180 | self.batch_norm_not, 181 | stride=2) 182 | self.avgpool = nn.AvgPool2d(8) 183 | self.fc = nn.Linear( 184 | self.base_channel * 185 | 4 * 186 | ALPHA_ * 187 | block.expansion, 188 | num_classes) 189 | 190 | for m in self.modules(): 191 | if isinstance(m, nn.Conv2d): 192 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 193 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 194 | elif isinstance(m, nn.BatchNorm2d): 195 | m.weight.data.fill_(1) 196 | m.bias.data.zero_() 197 | 198 | def _make_layer( 199 | self, 200 | block, 201 | planes, 202 | blocks, 203 | residual_not, 204 | batch_norm_not, 205 | stride=1): 206 | downsample = None 207 | if (stride != 1 or self.inplanes != planes * 208 | block.expansion) and (residual_not): 209 | if batch_norm_not: 210 | downsample = nn.Sequential( 211 | nn.Conv2d(self.inplanes, planes * block.expansion, 212 | kernel_size=1, stride=stride, bias=False), 213 | nn.BatchNorm2d(planes * block.expansion), 214 | ) 215 | else: 216 | downsample = nn.Sequential( 217 | nn.Conv2d(self.inplanes, planes * block.expansion, 218 | kernel_size=1, stride=stride, bias=False), 219 | ) 220 | 221 | layers = nn.ModuleList() 222 | layers.append( 223 | block( 224 | self.inplanes, 225 | planes, 226 | residual_not, 227 | batch_norm_not, 228 | stride, 229 | downsample)) 230 | self.inplanes = planes * block.expansion 231 | for i in range(1, blocks): 232 | layers.append( 233 | block( 234 | self.inplanes, 235 | planes, 236 | residual_not, 237 | batch_norm_not)) 238 | 239 | # return nn.Sequential(*layers) 240 | return layers 241 | 242 | def forward(self, x): 243 | output_list = [] 244 | x = self.conv1(x) 245 | if self.batch_norm_not: 246 | x = self.bn1(x) 247 | x = self.relu(x) # 32x32 248 | output_list.append(x.view(x.size(0), -1)) 249 | 250 | for layer in self.layer1: 251 | x = layer(x) # 32x32 252 | output_list.append(x.view(x.size(0), -1)) 253 | for layer in self.layer2: 254 | x = layer(x) # 16x16 255 | output_list.append(x.view(x.size(0), -1)) 256 | for layer in self.layer3: 257 | x = layer(x) # 8x8 258 | output_list.append(x.view(x.size(0), -1)) 259 | 260 | x = self.avgpool(x) 261 | x = x.view(x.size(0), -1) 262 | x = self.fc(x) 263 | output_list.append(x.view(x.size(0), -1)) 264 | 265 | # return output_list, x 266 | return x 267 | 268 | 269 | def resnet(**kwargs): 270 | """ 271 | Constructs a ResNet model. 
272 | """ 273 | return ResNet(**kwargs) 274 | -------------------------------------------------------------------------------- /image_classification/models/simple_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange 5 | from einops.layers.torch import Rearrange 6 | 7 | # helpers 8 | 9 | def pair(t): 10 | return t if isinstance(t, tuple) else (t, t) 11 | 12 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): # 7, 7, 384 13 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 14 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 15 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 16 | omega = 1.0 / (temperature ** omega) 17 | 18 | y = y.flatten()[:, None] * omega[None, :] 19 | x = x.flatten()[:, None] * omega[None, :] 20 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 21 | return pe.type(dtype) 22 | 23 | # classes 24 | 25 | class FeedForward(nn.Module): 26 | def __init__(self, dim, hidden_dim): 27 | super().__init__() 28 | self.net = nn.Sequential( 29 | nn.LayerNorm(dim), 30 | nn.Linear(dim, hidden_dim), 31 | nn.GELU(), 32 | nn.Linear(hidden_dim, dim), 33 | ) 34 | def forward(self, x): 35 | return self.net(x) 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, dim, heads = 8, dim_head = 64): 39 | super().__init__() 40 | inner_dim = dim_head * heads 41 | self.heads = heads 42 | self.scale = dim_head ** -0.5 43 | self.norm = nn.LayerNorm(dim) 44 | 45 | self.attend = nn.Softmax(dim = -1) 46 | 47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 48 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 49 | 50 | def forward(self, x): 51 | x = self.norm(x) 52 | 53 | qkv = self.to_qkv(x).chunk(3, dim = -1) 54 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 55 | 56 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 57 | 58 | attn = self.attend(dots) 59 | 60 | out = torch.matmul(attn, v) 61 | out = rearrange(out, 'b h n d -> b n (h d)') 62 | return self.to_out(out) 63 | 64 | class Transformer(nn.Module): 65 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 66 | super().__init__() 67 | self.norm = nn.LayerNorm(dim) # MSAM (mean -> norm -> linear) This (norm > mean > Linear) 68 | self.layers = nn.ModuleList([]) 69 | for _ in range(depth): 70 | self.layers.append(nn.ModuleList([ 71 | Attention(dim, heads = heads, dim_head = dim_head), 72 | FeedForward(dim, mlp_dim) 73 | ])) 74 | def forward(self, x): 75 | for attn, ff in self.layers: 76 | x = attn(x) + x 77 | x = ff(x) + x 78 | return self.norm(x) 79 | 80 | class SimpleViT(nn.Module): 81 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64): 82 | super().__init__() 83 | image_height, image_width = pair(image_size) 84 | patch_height, patch_width = pair(patch_size) 85 | 86 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
87 | 88 | patch_dim = channels * patch_height * patch_width # patch_dim = 3 * 32 * 32 = 3072 89 | 90 | self.to_patch_embedding = nn.Sequential( 91 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), # b 3 (7 32) (7 32) -> b 49 3072 92 | nn.LayerNorm(patch_dim), 93 | nn.Linear(patch_dim, dim), # dim = 384 (b 49 384) 94 | nn.LayerNorm(dim), 95 | ) 96 | 97 | self.pos_embedding = posemb_sincos_2d( 98 | h = image_height // patch_height, # 224 // 32 = 7 99 | w = image_width // patch_width, # 224 // 32 = 7 100 | dim = dim, # 384 101 | ) 102 | 103 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 104 | 105 | self.pool = "mean" 106 | self.to_latent = nn.Identity() 107 | 108 | self.linear_head = nn.Linear(dim, num_classes) 109 | 110 | def forward(self, img): 111 | device = img.device 112 | 113 | x = self.to_patch_embedding(img) # (b 49 384) 114 | x += self.pos_embedding.to(device, dtype=x.dtype) 115 | 116 | x = self.transformer(x) 117 | x = x.mean(dim = 1) 118 | 119 | x = self.to_latent(x) 120 | return self.linear_head(x) 121 | -------------------------------------------------------------------------------- /image_classification/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys 8 | import numpy as np 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 17 | init.constant_(m.bias, 0) 18 | elif classname.find('BatchNorm') != -1: 19 | init.constant_(m.weight, 1) 20 | init.constant_(m.bias, 0) 21 | 22 | class wide_basic(nn.Module): 23 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 24 | super(wide_basic, self).__init__() 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 27 | self.dropout = nn.Dropout(p=dropout_rate) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 35 | ) 36 | 37 | def forward(self, x): 38 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 39 | out = self.conv2(F.relu(self.bn2(out))) 40 | out += self.shortcut(x) 41 | 42 | return out 43 | 44 | class Wide_ResNet(nn.Module): 45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 46 | super(Wide_ResNet, self).__init__() 47 | self.in_planes = 16 48 | 49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 50 | n = (depth-4)/6 51 | k = widen_factor 52 | 53 | print('| Wide-Resnet %dx%d' %(depth, k)) 54 | nStages = [16, 16*k, 32*k, 64*k] 55 | 56 | self.conv1 = conv3x3(3,nStages[0]) 57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 61 | self.linear = 
nn.Linear(nStages[3], num_classes) 62 | 63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 64 | strides = [stride] + [1]*(int(num_blocks)-1) 65 | layers = [] 66 | 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 69 | self.in_planes = planes 70 | 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = F.relu(self.bn1(out)) 79 | out = F.avg_pool2d(out, 8) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | 83 | return out 84 | 85 | if __name__ == '__main__': 86 | net=Wide_ResNet(28, 10, 0.3, 10) 87 | y = net(Variable(torch.randn(1,3,32,32))) 88 | 89 | print(y.size()) 90 | -------------------------------------------------------------------------------- /image_classification/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torchvision import datasets, transforms 9 | from torch.utils.data import Dataset, DataLoader 10 | from torch.autograd import Variable 11 | 12 | 13 | def getData( 14 | name='cifar10', 15 | path='/home/dahunshin/imagenet', 16 | train_bs=256, 17 | test_bs=1000, 18 | num_workers=1, 19 | distributed=False): 20 | 21 | if name == 'mnist': 22 | 23 | train_loader = torch.utils.data.DataLoader( 24 | datasets.MNIST('../data', train=True, download=True, 25 | transform=transforms.Compose([ 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.1307,), (0.3081,)) 28 | ])), 29 | batch_size=train_bs, shuffle=True) 30 | test_loader = torch.utils.data.DataLoader( 31 | datasets.MNIST( 32 | '../data', 33 | train=False, 34 | transform=transforms.Compose( 35 | [ 36 | transforms.ToTensor(), 37 | transforms.Normalize( 38 | (0.1307, 39 | ), 40 | (0.3081, 41 | ))])), 42 | batch_size=test_bs, 43 | shuffle=False) 44 | 45 | 46 | if name == 'cifar10': 47 | transform_train = transforms.Compose([ 48 | transforms.RandomCrop(32, padding=4), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.ToTensor(), 51 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 52 | ]) 53 | 54 | transform_test = transforms.Compose([ 55 | transforms.ToTensor(), 56 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 57 | ]) 58 | 59 | trainset = datasets.CIFAR10( 60 | root='../data', 61 | train=True, 62 | download=True, 63 | transform=transform_train) 64 | 65 | testset = datasets.CIFAR10( 66 | root='../data', 67 | train=False, 68 | download=False, 69 | transform=transform_test) 70 | 71 | if distributed: 72 | train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True, drop_last=False) 73 | test_sampler = torch.utils.data.distributed.DistributedSampler(testset, shuffle=False, drop_last=True) 74 | else: 75 | train_sampler = None 76 | test_sampler = None 77 | 78 | train_loader = torch.utils.data.DataLoader( 79 | trainset, batch_size=train_bs, shuffle=(train_sampler is None), 80 | num_workers=num_workers, pin_memory=True, sampler=train_sampler) 81 | 82 | test_loader = torch.utils.data.DataLoader( 83 | testset, batch_size=test_bs, shuffle=False, 84 | num_workers=num_workers, pin_memory=True, sampler=test_sampler) 85 | 86 | 87 | if name == 'cifar100': 88 | 89 | transform_train = transforms.Compose([ 90 | 
transforms.RandomCrop(32, padding=4), 91 | transforms.RandomHorizontalFlip(), 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 94 | ]) 95 | 96 | transform_test = transforms.Compose([ 97 | transforms.ToTensor(), 98 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 99 | ]) 100 | 101 | trainset = datasets.CIFAR100( 102 | root='../data', 103 | train=True, 104 | download=True, 105 | transform=transform_train) 106 | 107 | testset = datasets.CIFAR100( 108 | root='../data', 109 | train=False, 110 | download=False, 111 | transform=transform_test) 112 | 113 | if distributed: 114 | train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True, drop_last=False) 115 | test_sampler = torch.utils.data.distributed.DistributedSampler(testset, shuffle=False, drop_last=True) 116 | else: 117 | train_sampler = None 118 | test_sampler = None 119 | 120 | train_loader = torch.utils.data.DataLoader( 121 | trainset, batch_size=train_bs, shuffle=(train_sampler is None), 122 | num_workers=num_workers, pin_memory=True, sampler=train_sampler) 123 | 124 | test_loader = torch.utils.data.DataLoader( 125 | testset, batch_size=test_bs, shuffle=False, 126 | num_workers=num_workers, pin_memory=True, sampler=test_sampler) 127 | 128 | 129 | if name == 'imagenet': 130 | traindir = os.path.join(path, 'train') 131 | valdir = os.path.join(path, 'val') 132 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 133 | std=[0.229, 0.224, 0.225]) 134 | 135 | train_dataset = datasets.ImageFolder( 136 | traindir, 137 | transforms.Compose([ 138 | transforms.RandomResizedCrop(224), 139 | transforms.RandomHorizontalFlip(), 140 | transforms.ToTensor(), 141 | normalize, 142 | ])) 143 | 144 | val_dataset = datasets.ImageFolder( 145 | valdir, 146 | transforms.Compose([ 147 | transforms.Resize(256), 148 | transforms.CenterCrop(224), 149 | transforms.ToTensor(), 150 | normalize, 151 | ])) 152 | 153 | if distributed: 154 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=False) 155 | test_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True) 156 | else: 157 | train_sampler = None 158 | test_sampler = None 159 | 160 | train_loader = torch.utils.data.DataLoader( 161 | train_dataset, batch_size=train_bs, shuffle=(train_sampler is None), 162 | num_workers=num_workers, pin_memory=True, sampler=train_sampler) 163 | 164 | test_loader = torch.utils.data.DataLoader( 165 | val_dataset, batch_size=test_bs, shuffle=False, 166 | num_workers=num_workers, pin_memory=True, sampler=test_sampler) 167 | 168 | 169 | return train_loader, test_loader, train_sampler, test_sampler 170 | -------------------------------------------------------------------------------- /label_noise/bypass_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.batchnorm import _BatchNorm 4 | 5 | 6 | def disable_running_stats(model): 7 | def _disable(module): 8 | if isinstance(module, _BatchNorm): 9 | module.backup_momentum = module.momentum 10 | module.momentum = 0 11 | 12 | model.apply(_disable) 13 | 14 | def enable_running_stats(model): 15 | def _enable(module): 16 | if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"): 17 | module.momentum = module.backup_momentum 18 | 19 | model.apply(_enable) 20 | 
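The `disable_running_stats` / `enable_running_stats` helpers in `bypass_bn.py` above back up each BatchNorm module's momentum and set it to 0, which freezes the running statistics; restoring the momentum re-enables updates. A minimal sketch of how they would typically wrap the two-pass SASSHA loop from the README is shown below. It mirrors the README's pseudocode: `YourModel`, `loss_function`, `data`, and the `...` constructor arguments are placeholders, and the import path assumes the script lives next to `bypass_bn.py`; this is not taken from the provided `train.py`.

```python
from bypass_bn import enable_running_stats, disable_running_stats
from optimizers import SASSHA

model = YourModel()                          # any model containing BatchNorm layers
optimizer = SASSHA(model.parameters(), ...)  # hyperparameters omitted, as in the README

for input, output in data:
    # first pass: update BN running stats as usual, then perturb the weights
    enable_running_stats(model)
    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.perturb_weights(zero_grad=True)

    # second pass: keep BN statistics frozen while evaluating the perturbed model
    disable_running_stats(model)
    loss_function(output, model(input)).backward(create_graph=True)
    optimizer.unperturb()
    optimizer.step()
    optimizer.zero_grad()
```

How the provided training scripts interleave these calls is not shown here; the sketch only illustrates the purpose of the momentum backup and restore.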
-------------------------------------------------------------------------------- /label_noise/config/cifar10/adahessian/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer adahessian \ 3 | --noise_level 0.2 \ 4 | --lr 0.1 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/adahessian/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer adahessian \ 3 | --noise_level 0.4 \ 4 | --lr 0.1 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/adahessian/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer adahessian \ 3 | --noise_level 0.6 \ 4 | --lr 0.05 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/msassha/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer msassha \ 3 | --noise_level 0.2 \ 4 | --lr 0.15 --wd 1e-3 --rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/msassha/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer msassha \ 3 | --noise_level 0.4 \ 4 | --lr 0.05 --wd 1e-3 --rho 0.25 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/msassha/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer msassha \ 3 | --noise_level 0.6 \ 4 | --lr 0.05 --wd 1e-3 --rho 0.25 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sam_sgd/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer samsgd \ 3 | --noise_level 0.2 \ 4 | --lr 0.1 --wd 1e-3 --rho 0.2 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sam_sgd/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer samsgd \ 3 | --noise_level 0.4 \ 4 | --lr 0.1 --wd 5e-4 --rho 0.2 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sam_sgd/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer samsgd \ 3 | --noise_level 
0.6 \ 4 | --lr 0.03 --wd 1e-3 --rho 0.1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sassha/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sassha \ 3 | --noise_level 0.2 \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.1 --min_lr 0.001 --wd 1e-3 --rho 0.1 --lazy_hessian 10 --seed 0, 1, 2 \ 6 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sassha/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sassha \ 3 | --noise_level 0.4 \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \ 6 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sassha/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sassha \ 3 | --noise_level 0.6 \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \ 6 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sgd/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sgd \ 3 | --noise_level 0.2 \ 4 | --lr 0.015 --wd 2e-3 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sgd/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sgd \ 3 | --noise_level 0.4 \ 4 | --lr 0.015 --wd 2e-3 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sgd/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sgd \ 3 | --noise_level 0.6 \ 4 | --lr 0.05 --wd 1e-3 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/shampoo/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer shampoo \ 3 | --noise_level 0.2 \ 4 | --lr 0.6 --wd 5e-3 --eps 1e-2 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/shampoo/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer shampoo \ 3 | --noise_level 0.4 \ 4 | --lr 0.15 --wd 5e-2 --eps 1e-2 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/shampoo/noise60.sh: 
-------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer shampoo \ 3 | --noise_level 0.6 \ 4 | --lr 0.15 --wd 5e-2 --eps 1e-2 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sophiah/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sophiah \ 3 | --noise_level 0.2 \ 4 | --lr 1e-3 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sophiah/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sophiah \ 3 | --noise_level 0.4 \ 4 | --lr 1e-3 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar10/sophiah/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sophiah \ 3 | --noise_level 0.6 \ 4 | --lr 1e-3 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/adahessian/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer adahessian \ 3 | --noise_level 0.2 \ 4 | --lr 0.0015 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/adahessian/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer adahessian \ 3 | --noise_level 0.4 \ 4 | --lr 1e-3 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/adahessian/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer adahessian \ 3 | --noise_level 0.6 \ 4 | --lr 1e-2 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/msassha/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer msassha \ 3 | --noise_level 0.2 \ 4 | --lr 0.15 --wd 1e-3 --rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/msassha/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer msassha \ 3 | --noise_level 0.4 \ 4 | --lr 0.15 --wd 1e-3 
--rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/msassha/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer msassha \ 3 | --noise_level 0.6 \ 4 | --lr 0.15 --wd 1e-3 --rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sam_sgd/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer samsgd \ 3 | --noise_level 0.2 \ 4 | --lr 0.1 --wd 1e-3 --rho 0.2 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sam_sgd/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer samsgd \ 3 | --noise_level 0.4 \ 4 | --lr 0.01 --wd 1e-3 --rho 0.1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sam_sgd/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer samsgd \ 3 | --noise_level 0.6 \ 4 | --lr 0.03 --wd 1e-3 --rho 0.1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sassha/noise20.sh: -------------------------------------------------------------------------------- 1 | # we use three seeds 0, 1, 2 2 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 3 | --optimizer sassha \ 4 | --noise_level 0.2 \ 5 | --hessian_power_scheduler constant \ 6 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.1 --lazy_hessian 10 --seed 0, 1, 2 7 | 8 | 9 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sassha/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sassha \ 3 | --noise_level 0.4 \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \ 6 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sassha/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sassha \ 3 | --noise_level 0.6 \ 4 | --hessian_power_scheduler constant \ 5 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \ 6 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sgd/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sgd \ 3 | --noise_level 0.2 \ 4 | --lr 5e-4 --wd 2e-3 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- 
/label_noise/config/cifar100/sgd/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sgd \ 3 | --noise_level 0.4 \ 4 | --lr 5e-4 --wd 2e-3 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sgd/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sgd \ 3 | --noise_level 0.6 \ 4 | --lr 5e-4 --wd 2e-3 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/shampoo/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer shampoo \ 3 | --noise_level 0.2 \ 4 | --lr 0.6 --wd 1e-2 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/shampoo/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer shampoo \ 3 | --noise_level 0.4 \ 4 | --lr 0.6 --wd 1e-2 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/shampoo/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer shampoo \ 3 | --noise_level 0.6 \ 4 | --lr 0.3 --wd 1e-2 --eps 1e-4 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sophiah/noise20.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sophiah \ 3 | --noise_level 0.2 \ 4 | --lr 1e-5 --wd 5e-4 --clip_threshold 0.05 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sophiah/noise40.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sophiah \ 3 | --noise_level 0.4 \ 4 | --lr 1e-5 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/config/cifar100/sophiah/noise60.sh: -------------------------------------------------------------------------------- 1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \ 2 | --optimizer sophiah \ 3 | --noise_level 0.6 \ 4 | --lr 1e-5 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \ 5 | -------------------------------------------------------------------------------- /label_noise/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import math 12 | from copy import deepcopy 13 | 14 | __all__ = ['resnet'] 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__( 27 | self, 28 | inplanes, 29 | planes, 30 | residual_not, 31 | batch_norm_not, 32 | stride=1, 33 | downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.residual_not = residual_not 36 | self.batch_norm_not = batch_norm_not 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | if self.batch_norm_not: 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | if self.batch_norm_not: 43 | self.bn2 = nn.BatchNorm2d(planes) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | if self.batch_norm_not: 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | if self.batch_norm_not: 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | if self.residual_not: 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__( 72 | self, 73 | inplanes, 74 | planes, 75 | residual_not, 76 | batch_norm_not, 77 | stride=1, 78 | downsample=None): 79 | super(Bottleneck, self).__init__() 80 | self.residual_not = residual_not 81 | self.batch_norm_not = batch_norm_not 82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 83 | if self.batch_norm_not: 84 | self.bn1 = nn.BatchNorm2d(planes) 85 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 86 | padding=1, bias=False) 87 | if self.batch_norm_not: 88 | self.bn2 = nn.BatchNorm2d(planes) 89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 90 | if self.batch_norm_not: 91 | self.bn3 = nn.BatchNorm2d(planes * 4) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.downsample = downsample 94 | self.stride = stride 95 | 96 | def forward(self, x): 97 | residual = x 98 | 99 | out = self.conv1(x) 100 | if self.batch_norm_not: 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | if self.batch_norm_not: 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | if self.batch_norm_not: 111 | out = self.bn3(out) 112 | 113 | if self.downsample is not None: 114 | residual = self.downsample(x) 115 | if self.residual_not: 116 | out += residual 117 | 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | 123 | ALPHA_ = 1 124 | 125 | 126 | class ResNet(nn.Module): 127 | 128 | def __init__( 129 | self, 130 | depth, 131 | residual_not=True, 132 | batch_norm_not=True, 133 | base_channel=16, 134 | num_classes=10): 135 | super(ResNet, self).__init__() 136 | # Model type specifies number of layers for CIFAR-10 model 137 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 138 | n = (depth - 2) // 6 139 | 140 | # block = Bottleneck if depth >=44 else BasicBlock 141 | block = BasicBlock 142 | 143 | self.base_channel = int(base_channel) 144 | 
self.residual_not = residual_not 145 | self.batch_norm_not = batch_norm_not 146 | self.inplanes = self.base_channel * ALPHA_ 147 | self.conv1 = nn.Conv2d( 148 | 3, 149 | self.base_channel * 150 | ALPHA_, 151 | kernel_size=3, 152 | padding=1, 153 | bias=False) 154 | if self.batch_norm_not: 155 | self.bn1 = nn.BatchNorm2d(self.base_channel * ALPHA_) 156 | self.relu = nn.ReLU(inplace=True) 157 | self.layer1 = self._make_layer( 158 | block, 159 | self.base_channel * 160 | ALPHA_, 161 | n, 162 | self.residual_not, 163 | self.batch_norm_not) 164 | self.layer2 = self._make_layer( 165 | block, 166 | self.base_channel * 167 | 2 * 168 | ALPHA_, 169 | n, 170 | self.residual_not, 171 | self.batch_norm_not, 172 | stride=2) 173 | self.layer3 = self._make_layer( 174 | block, 175 | self.base_channel * 176 | 4 * 177 | ALPHA_, 178 | n, 179 | self.residual_not, 180 | self.batch_norm_not, 181 | stride=2) 182 | self.avgpool = nn.AvgPool2d(8) 183 | self.fc = nn.Linear( 184 | self.base_channel * 185 | 4 * 186 | ALPHA_ * 187 | block.expansion, 188 | num_classes) 189 | 190 | for m in self.modules(): 191 | if isinstance(m, nn.Conv2d): 192 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 193 | m.weight.data.normal_(0, math.sqrt(2. / n)) 194 | elif isinstance(m, nn.BatchNorm2d): 195 | m.weight.data.fill_(1) 196 | m.bias.data.zero_() 197 | 198 | def _make_layer( 199 | self, 200 | block, 201 | planes, 202 | blocks, 203 | residual_not, 204 | batch_norm_not, 205 | stride=1): 206 | downsample = None 207 | if (stride != 1 or self.inplanes != planes * 208 | block.expansion) and (residual_not): 209 | if batch_norm_not: 210 | downsample = nn.Sequential( 211 | nn.Conv2d(self.inplanes, planes * block.expansion, 212 | kernel_size=1, stride=stride, bias=False), 213 | nn.BatchNorm2d(planes * block.expansion), 214 | ) 215 | else: 216 | downsample = nn.Sequential( 217 | nn.Conv2d(self.inplanes, planes * block.expansion, 218 | kernel_size=1, stride=stride, bias=False), 219 | ) 220 | 221 | layers = nn.ModuleList() 222 | layers.append( 223 | block( 224 | self.inplanes, 225 | planes, 226 | residual_not, 227 | batch_norm_not, 228 | stride, 229 | downsample)) 230 | self.inplanes = planes * block.expansion 231 | for i in range(1, blocks): 232 | layers.append( 233 | block( 234 | self.inplanes, 235 | planes, 236 | residual_not, 237 | batch_norm_not)) 238 | 239 | # return nn.Sequential(*layers) 240 | return layers 241 | 242 | def forward(self, x): 243 | output_list = [] 244 | x = self.conv1(x) 245 | if self.batch_norm_not: 246 | x = self.bn1(x) 247 | x = self.relu(x) # 32x32 248 | output_list.append(x.view(x.size(0), -1)) 249 | 250 | for layer in self.layer1: 251 | x = layer(x) # 32x32 252 | output_list.append(x.view(x.size(0), -1)) 253 | for layer in self.layer2: 254 | x = layer(x) # 16x16 255 | output_list.append(x.view(x.size(0), -1)) 256 | for layer in self.layer3: 257 | x = layer(x) # 8x8 258 | output_list.append(x.view(x.size(0), -1)) 259 | 260 | x = self.avgpool(x) 261 | x = x.view(x.size(0), -1) 262 | x = self.fc(x) 263 | output_list.append(x.view(x.size(0), -1)) 264 | 265 | # return output_list, x 266 | return x 267 | 268 | 269 | def resnet(**kwargs): 270 | """ 271 | Constructs a ResNet model. 
272 | """ 273 | return ResNet(**kwargs) 274 | -------------------------------------------------------------------------------- /label_noise/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import logging 3 | import os 4 | import sys 5 | import random 6 | 7 | import numpy as np 8 | import argparse 9 | from tqdm import tqdm, trange 10 | 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import torch.optim.lr_scheduler as lr_scheduler 17 | from torchvision import datasets, transforms 18 | from torch.autograd import Variable 19 | 20 | from utils import * 21 | from models.resnet import * 22 | from bypass_bn import enable_running_stats, disable_running_stats 23 | 24 | # load optimizers 25 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('optimizers')))) 26 | from optimizers import get_optimizer 27 | 28 | # load hessain power scheduler 29 | from optimizers.hessian_scheduler import ConstantScheduler, ProportionScheduler, LinearScheduler, CosineScheduler 30 | 31 | # Training settings 32 | parser = argparse.ArgumentParser(description='PyTorch Example') 33 | parser.add_argument('--batch-size', type=int, default=256, metavar='B', 34 | help='input batch size for training (default: 256)') 35 | parser.add_argument('--test-batch-size', type=int, default=256, metavar='TB', 36 | help='input batch size for testing (default: 256)') 37 | parser.add_argument('--epochs', type=int, default=160, metavar='E', 38 | help='number of epochs to train (default: 10)') 39 | parser.add_argument('--lr', type=float, default=0.15, metavar='LR', 40 | help='learning rate (default: 0.15)') 41 | parser.add_argument('--lr-decay', type=float, default=0.1, 42 | help='learning rate ratio') 43 | parser.add_argument('--lr-decay-epoch', type=int, nargs='+', default=[80, 120], 44 | help='decrease learning rate at these epochs.') 45 | parser.add_argument('--seed', type=int, default=0, metavar='S', 46 | help='random seed (default: 0)') 47 | parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float, 48 | metavar='W', help='weight decay (default: 5e-4)') 49 | parser.add_argument('--depth', type=int, default=20, 50 | help='choose the depth of resnet') 51 | parser.add_argument('--optimizer', type=str, default='sassha', 52 | help='choose optim') 53 | parser.add_argument('--data', type=str, default='cifar100', 54 | help='choose dataset cifar10/100') 55 | parser.add_argument('--noise_level', type=float, default=0.2, 56 | help='noise_level') 57 | parser.add_argument('--min_lr', type=float, default=0.0, help="the minimum value of learning rate") 58 | 59 | # Second-order optimization settings 60 | parser.add_argument("--n_samples", default=1, type=int, help="the number of sampling") 61 | parser.add_argument('--betas', type=float, nargs='*', default=[0.9, 0.999], help='betas') 62 | parser.add_argument("--eps", default=1e-4, type=float, help="add a small number for stability") 63 | parser.add_argument("--lazy_hessian", default=10, type=int, help="Delayed hessian update.") 64 | parser.add_argument("--clip_threshold", default=0.05, type=float, help="Clipping threshold.") 65 | 66 | # Hessian power scheduler 67 | parser.add_argument('--hessian_power_scheduler', type=str, default='constant', help="choose LRScheduler 1. 'constant', 2. 'proportion', 3. 'linear', 4. 
'cosine'") 68 | parser.add_argument('--max_hessian_power', type=float, default=1) 69 | parser.add_argument('--min_hessian_power', type=float, default=0.5) 70 | 71 | # Sharpness minimization settings 72 | parser.add_argument("--rho", default=0.05, type=float, help="Rho parameter for SAM.") 73 | parser.add_argument("--adaptive", default=False, type=bool, help="True if you want to use the Adaptive SAM.") 74 | parser.add_argument('--project_name', type=str, default='project_name', help="project_name") 75 | 76 | args = parser.parse_args() 77 | 78 | # wandb logging 79 | wandb_log = True 80 | wandb_project = args.project_name 81 | wandb_run_name = f'{args.optimizer}_lr_{args.lr}_wd_{args.weight_decay}_rho_{args.rho}' 82 | 83 | num_classes = ''.join([c for c in args.data if c.isdigit()]) 84 | num_classes = int(num_classes) 85 | 86 | # set random seed to reproduce the work 87 | random.seed(args.seed) 88 | np.random.seed(args.seed) 89 | torch.manual_seed(args.seed) 90 | torch.cuda.manual_seed(args.seed) 91 | cudnn.deterministic = True 92 | cudnn.benchmark = False 93 | 94 | for arg in vars(args): 95 | print(arg, getattr(args, arg)) 96 | if not os.path.isdir('checkpoint/'): 97 | os.makedirs('checkpoint/') 98 | 99 | # get a dataset (e.g,. cifar10, cifar100) 100 | train_loader, test_loader = getNoisyData(name=args.data, 101 | train_bs=args.batch_size, 102 | test_bs=args.test_batch_size, 103 | noise_level=args.noise_level) 104 | 105 | # get a model 106 | model = resnet(num_classes=num_classes, depth=args.depth).cuda() 107 | print(model) 108 | 109 | model = torch.nn.DataParallel(model) 110 | print(' Total params: %.2fM' % (sum(p.numel() 111 | for p in model.parameters()) / 1000000.0)) 112 | # define a loss 113 | criterion = nn.CrossEntropyLoss() 114 | 115 | # get an optimizer 116 | optimizer, create_graph, two_steps = get_optimizer(model, args) 117 | 118 | # learning rate schedule 119 | scheduler = lr_scheduler.MultiStepLR( 120 | optimizer, 121 | args.lr_decay_epoch, 122 | gamma=args.lr_decay, 123 | last_epoch=-1) 124 | 125 | # select a hessian power scheduler 126 | if args.hessian_power_scheduler == 'constant': 127 | hessian_power_scheduler = ConstantScheduler( 128 | T_max=args.epochs*len(train_loader), 129 | max_value=0.5, 130 | min_value=0.5) 131 | 132 | elif args.hessian_power_scheduler == 'proportion': 133 | hessian_power_scheduler = ProportionScheduler( 134 | pytorch_lr_scheduler=scheduler, 135 | max_lr=args.lr, 136 | min_lr=args.min_lr, 137 | max_value=args.max_hessian_power, 138 | min_value=args.min_hessian_power) 139 | 140 | elif args.hessian_power_scheduler == 'linear': 141 | hessian_power_scheduler = LinearScheduler( 142 | T_max=args.epochs*len(train_loader), 143 | max_value=args.max_hessian_power, 144 | min_value=args.min_hessian_power) 145 | 146 | elif args.hessian_power_scheduler == 'cosine': 147 | hessian_power_scheduler = CosineScheduler( 148 | T_max=args.epochs*len(train_loader), 149 | max_value=args.max_hessian_power, 150 | min_value=args.min_hessian_power) 151 | 152 | optimizer.hessian_power_scheduler = hessian_power_scheduler 153 | 154 | # import and init wandb 155 | if wandb_log: 156 | import wandb 157 | os.environ["WANDB__SERVICE_WAIT"] = "300" 158 | wandb.init(project=wandb_project, name=wandb_run_name) 159 | wandb.config.update(args) 160 | 161 | best_acc = 0.0 162 | iter_num = 0 163 | # training loop 164 | for epoch in range(1, args.epochs + 1): 165 | print('Current Epoch: ', epoch) 166 | train_loss = 0. 
167 | total_num = 0 168 | correct = 0 169 | 170 | scheduler.step() 171 | model.train() 172 | 173 | if args.optimizer == 'msassha': 174 | optimizer.move_up_to_momentumAscent() 175 | 176 | with tqdm(total=len(train_loader.dataset)) as progressbar: 177 | for batch_idx, (data, target) in enumerate(train_loader): 178 | data, target = data.cuda(), target.cuda() 179 | 180 | if two_steps: 181 | enable_running_stats(model) 182 | output = model(data) 183 | loss = criterion(output, target) 184 | loss.backward() 185 | 186 | if args.optimizer == 'sassha': 187 | optimizer.perturb_weights(zero_grad=True) 188 | 189 | elif args.optimizer in ['samsgd', 'samadamw']: 190 | optimizer.first_step(zero_grad=True) 191 | 192 | disable_running_stats(model) 193 | criterion(model(data), target).backward(create_graph=create_graph) 194 | 195 | if args.optimizer == 'sassha': 196 | optimizer.unperturb() 197 | optimizer.step() 198 | optimizer.zero_grad() 199 | 200 | elif args.optimizer in ['samsgd', 'samadamw']: 201 | optimizer.second_step(zero_grad=True) 202 | 203 | else: 204 | output = model(data) 205 | loss = criterion(output, target) 206 | loss.backward(create_graph=create_graph) 207 | optimizer.step() 208 | optimizer.zero_grad() 209 | 210 | # for records 211 | train_loss += loss.item() * target.size()[0] 212 | total_num += target.size()[0] 213 | _, predicted = output.max(1) 214 | correct += predicted.eq(target).sum().item() 215 | progressbar.update(target.size(0)) 216 | iter_num += 1 217 | 218 | if args.optimizer == 'msassha': 219 | optimizer.move_back_from_momentumAscent() 220 | 221 | acc, val_loss = test(model, test_loader, criterion) 222 | 223 | train_loss /= total_num 224 | train_acc = correct / total_num * 100 225 | 226 | if acc > best_acc: 227 | best_acc = acc 228 | 229 | if wandb_log: 230 | wandb.log({ 231 | "iter": iter_num, 232 | "train/loss": train_loss, 233 | "train/acc": train_acc, 234 | "val/acc": acc*100, 235 | "val/loss": val_loss, 236 | "lr": scheduler.get_last_lr()[-1], 237 | 'best_accuracy': best_acc, 238 | "hessian_power": optimizer.hessian_power_t if args.optimizer == 'sassha' else 0}, 239 | step=epoch) 240 | 241 | print(f"Training Loss of Epoch {epoch}: {np.around(train_loss, 2)}") 242 | print(f"Testing of Epoch {epoch}: {np.around(acc * 100, 2)} \n") 243 | 244 | print(f'Best Acc: {np.around(best_acc * 100, 2)}') 245 | -------------------------------------------------------------------------------- /label_noise/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.utils.data import Dataset, DataLoader 9 | from torch.autograd import Variable 10 | 11 | def introduce_label_noise(labels, num_classes, noise_level=0.1): 12 | n_samples = len(labels) 13 | n_noisy = int(n_samples * noise_level) 14 | noisy_indices = torch.randperm(n_samples)[:n_noisy] 15 | # Generate new random labels for the selected indices 16 | new_labels = torch.randint(0, num_classes, (n_noisy,)) 17 | labels[noisy_indices] = new_labels 18 | return labels 19 | 20 | def getNoisyData( 21 | name='cifar10', 22 | train_bs=128, 23 | test_bs=1000, 24 | noise_level=0.1, 25 | ): 26 | 27 | if name == 'mnist': 28 | 29 | train_loader = torch.utils.data.DataLoader( 30 | datasets.MNIST('../data', train=True, download=True, 31 | transform=transforms.Compose([ 32 | transforms.ToTensor(), 33 | 
transforms.Normalize((0.1307,), (0.3081,)) 34 | ])), 35 | batch_size=train_bs, shuffle=True) 36 | 37 | test_loader = torch.utils.data.DataLoader( 38 | datasets.MNIST( 39 | '../data', 40 | train=False, 41 | transform=transforms.Compose( 42 | [ 43 | transforms.ToTensor(), 44 | transforms.Normalize( 45 | (0.1307, 46 | ), 47 | (0.3081, 48 | ))])), 49 | batch_size=test_bs, 50 | shuffle=False) 51 | 52 | if name == 'cifar10': 53 | transform_train = transforms.Compose([ 54 | transforms.RandomCrop(32, padding=4), 55 | transforms.RandomHorizontalFlip(), 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 58 | ]) 59 | 60 | trainset = datasets.CIFAR10( 61 | root='../data', 62 | train=True, 63 | download=True, 64 | transform=transform_train) 65 | 66 | trainset.targets = introduce_label_noise(torch.tensor(trainset.targets), 67 | num_classes=10, 68 | noise_level=noise_level) 69 | 70 | train_loader = torch.utils.data.DataLoader( 71 | trainset, batch_size=train_bs, shuffle=True, drop_last=False) 72 | 73 | transform_test = transforms.Compose([ 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 76 | ]) 77 | 78 | testset = datasets.CIFAR10( 79 | root='../data', 80 | train=False, 81 | download=False, 82 | transform=transform_test) 83 | 84 | test_loader = torch.utils.data.DataLoader( 85 | testset, batch_size=test_bs, shuffle=False) 86 | 87 | if name == 'cifar100': 88 | 89 | transform_train = transforms.Compose([ 90 | transforms.RandomCrop(32, padding=4), 91 | transforms.RandomHorizontalFlip(), 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 94 | ]) 95 | 96 | trainset = datasets.CIFAR100( 97 | root='../data', 98 | train=True, 99 | download=True, 100 | transform=transform_train) 101 | 102 | trainset.targets = introduce_label_noise(torch.tensor(trainset.targets), 103 | num_classes=100, 104 | noise_level=noise_level) 105 | 106 | train_loader = torch.utils.data.DataLoader( 107 | trainset, batch_size=train_bs, shuffle=True, drop_last=False) 108 | 109 | transform_test = transforms.Compose([ 110 | transforms.ToTensor(), 111 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 112 | ]) 113 | 114 | testset = datasets.CIFAR100( 115 | root='../data', 116 | train=False, 117 | download=False, 118 | transform=transform_test) 119 | 120 | test_loader = torch.utils.data.DataLoader( 121 | testset, batch_size=test_bs, shuffle=False) 122 | 123 | return train_loader, test_loader 124 | 125 | 126 | def test(model, test_loader, criterion): 127 | # print('Testing') 128 | model.eval() 129 | correct = 0 130 | total_num = 0 131 | val_loss = 0.0 132 | with torch.no_grad(): 133 | for data, target in test_loader: 134 | data, target = data.cuda(), target.cuda() 135 | output = model(data) 136 | loss = criterion(output, target) 137 | # get the index of the max log-probability 138 | pred = output.data.max(1, keepdim=True)[1] 139 | correct += pred.eq(target.data.view_as(pred)).cpu().sum().item() 140 | val_loss += loss.item() * target.size()[0] 141 | total_num += len(data) 142 | # print('testing_correct: ', correct / total_num, '\n') 143 | return (correct / total_num, val_loss / total_num) 144 | 145 | 146 | def get_params_grad(model): 147 | """ 148 | get model parameters and corresponding gradients 149 | """ 150 | params = [] 151 | grads = [] 152 | for param in model.parameters(): 153 | if not param.requires_grad: 154 | continue 155 | params.append(param) 156 
| grads.append(0. if param.grad is None else param.grad + 0.) 157 | return params, grads 158 | -------------------------------------------------------------------------------- /language_tasks/README.md: -------------------------------------------------------------------------------- 1 | # We plan to release code for pretraining large language models that exceed the scale of the mini-gpt1 model. -------------------------------------------------------------------------------- /language_tasks/config/adahessian/mnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adahessian \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-3 \ 10 | --weight_decay 1e-6 \ 11 | --eps 1e-6 \ 12 | --lazy_hessian 1 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/adahessian/mrpc.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mrpc \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adahessian \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | --weight_decay 1e-6 \ 11 | --eps 1e-4 \ 12 | --lazy_hessian 1 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/adahessian/qnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adahessian \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-3 \ 10 | --weight_decay 1e-7 \ 11 | --eps 1e-4 \ 12 | --lazy_hessian 1 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/adahessian/qqp.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qqp \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adahessian \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | --weight_decay 1e-7 \ 11 | --eps 1e-4 \ 12 | --lazy_hessian 1 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/adahessian/rte.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name rte \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adahessian \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-3 \ 10 | --weight_decay 1e-6 \ 11 | --eps 1e-4 \ 12 | --lazy_hessian 1 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/adahessian/sst2.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path 
squeezebert/squeezebert-uncased \ 3 | --task_name sst2 \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adahessian \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-3 \ 10 | --weight_decay 1e-6 \ 11 | --eps 1e-8 \ 12 | --lazy_hessian 1 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/adahessian/stsb.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name stsb \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sassha \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | --weight_decay 1e-6 \ 11 | --eps 1e-4 \ 12 | --lazy_hessian 1 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/adamw/mnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-6 \ 11 | --eps 1e-6 \ 12 | --seed 0, 1, 2 13 | -------------------------------------------------------------------------------- /language_tasks/config/adamw/mrpc.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mrpc \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-4 \ 10 | --weight_decay 1e-7 \ 11 | --eps 1e-4 \ 12 | --seed 0, 1, 2 13 | -------------------------------------------------------------------------------- /language_tasks/config/adamw/qnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-7 \ 11 | --eps 1e-8 \ 12 | --seed 0, 1, 2 13 | -------------------------------------------------------------------------------- /language_tasks/config/adamw/qqp.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qqp \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-6 \ 11 | --eps 1e-6 \ 12 | --seed 0, 1, 2 13 | -------------------------------------------------------------------------------- /language_tasks/config/adamw/rte.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name rte \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | 
--weight_decay 1e-8 \ 11 | --eps 1e-8 \ 12 | --lazy_hessian 1 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/adamw/sst2.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name sst2 \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-8 \ 11 | --eps 1e-8 \ 12 | --seed 0, 1, 2 13 | -------------------------------------------------------------------------------- /language_tasks/config/adamw/stsb.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name stsb \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer adamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-5 \ 11 | --eps 1e-8 \ 12 | --seed 0, 1, 2 13 | -------------------------------------------------------------------------------- /language_tasks/config/msassha/mnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer msassha \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | --weight_decay 1e-6 \ 11 | --rho 1e-2 \ 12 | --eps 1e-2 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/msassha/mrpc.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mrpc \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer msassha \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | --weight_decay 1e-8 \ 11 | --rho 1e-5 \ 12 | --eps 1e-4 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/msassha/qnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer msassha \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | --weight_decay 1e-6 \ 11 | --rho 1e-5 \ 12 | --eps 1e-2 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/msassha/qqp.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qqp \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer msassha \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | --weight_decay 1e-6 \ 11 | --rho 1e-5 \ 12 | --eps 1e-4 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | 
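In these config scripts, `--seed 0, 1, 2` is shorthand for three independent runs with seeds 0, 1, and 2 (the same convention is spelled out in `label_noise/config/cifar100/sassha/noise20.sh`). Below is a minimal sketch of launching the MSASSHA QQP configuration above once per seed; the flag values are copied from that script, and the sequential `subprocess` loop is only illustrative.

```python
import subprocess

for seed in (0, 1, 2):
    subprocess.run(
        ["python", "finetune.py",
         "--model_name_or_path", "squeezebert/squeezebert-uncased",
         "--task_name", "qqp",
         "--max_length", "512",
         "--num_train_epochs", "5",
         "--per_device_train_batch_size", "16",
         "--optimizer", "msassha",
         "--lr_scheduler_type", "polynomial",
         "--lr", "1e-2",
         "--weight_decay", "1e-6",
         "--rho", "1e-5",
         "--eps", "1e-4",
         "--lazy_hessian", "1",
         "--seed", str(seed)],
        check=True)
```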
-------------------------------------------------------------------------------- /language_tasks/config/msassha/rte.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name rte \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer msassha \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | --weight_decay 1e-6 \ 11 | --rho 1e-3 \ 12 | --eps 1e-8 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/msassha/sst2.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name sst2 \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer msassha \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-2 \ 10 | --weight_decay 1e-6 \ 11 | --rho 1e-4 \ 12 | --eps 1e-8 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/msassha/stsb.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name stsb \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer msassha \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-3 \ 10 | --weight_decay 1e-5 \ 11 | --rho 1e-5 \ 12 | --eps 1e-4 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/sam_adamw/mnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer samadamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-6 \ 11 | --rho 1e-2 \ 12 | --eps 1e-6 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/sam_adamw/mrpc.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mrpc \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer samadamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-7 \ 11 | --rho 1e-4 \ 12 | --eps 1e-6 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/sam_adamw/qnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer samadamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-7 \ 11 | --rho 1e-2 \ 12 | --eps 1e-8 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- 
/language_tasks/config/sam_adamw/qqp.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qqp \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer samadamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-4 \ 10 | --weight_decay 1e-6 \ 11 | --rho 1e-2 \ 12 | --eps 1e-6 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/sam_adamw/rte.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name rte \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer samadamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-4 \ 10 | --weight_decay 1e-7 \ 11 | --rho 1e-3 \ 12 | --eps 1e-6 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/sam_adamw/sst2.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name sst2 \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer samadamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-5 \ 11 | --rho 1e-2 \ 12 | --eps 1e-8 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/sam_adamw/stsb.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name stsb \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer samadamw \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-4 \ 10 | --weight_decay 1e-7 \ 11 | --rho 1e-3 \ 12 | --eps 1e-6 \ 13 | --seed 0, 1, 2 14 | -------------------------------------------------------------------------------- /language_tasks/config/sassha/mnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sassha \ 8 | --lr_scheduler_type polynomial \ 9 | --hessian_power_scheduler constant \ 10 | --lr 1e-2 \ 11 | --weight_decay 1e-6 \ 12 | --rho 1e-2 \ 13 | --eps 1e-6 \ 14 | --lazy_hessian 1 \ 15 | --seed 0, 1, 2 16 | -------------------------------------------------------------------------------- /language_tasks/config/sassha/mrpc.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mrpc \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sassha \ 8 | --lr_scheduler_type polynomial \ 9 | --hessian_power_scheduler constant \ 10 | --lr 1e-2 \ 11 | --weight_decay 1e-8 \ 12 | --rho 1e-4 \ 13 | --eps 1e-4 \ 14 | --lazy_hessian 1 \ 15 | --seed 0, 1, 2 16 | -------------------------------------------------------------------------------- 
/language_tasks/config/sassha/qnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sassha \ 8 | --lr_scheduler_type polynomial \ 9 | --hessian_power_scheduler constant \ 10 | --lr 1e-2 \ 11 | --weight_decay 1e-7 \ 12 | --rho 1e-2 \ 13 | --eps 1e-6 \ 14 | --lazy_hessian 1 \ 15 | --seed 0, 1, 2 16 | -------------------------------------------------------------------------------- /language_tasks/config/sassha/qqp.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qqp \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sassha \ 8 | --lr_scheduler_type polynomial \ 9 | --hessian_power_scheduler constant \ 10 | --lr 1e-2 \ 11 | --weight_decay 1e-6 \ 12 | --rho 1e-2 \ 13 | --eps 1e-4 \ 14 | --lazy_hessian 1 \ 15 | --seed 0, 1, 2 16 | -------------------------------------------------------------------------------- /language_tasks/config/sassha/rte.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name rte \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sassha \ 8 | --lr_scheduler_type polynomial \ 9 | --hessian_power_scheduler constant \ 10 | --lr 1e-2 \ 11 | --weight_decay 1e-4 \ 12 | --rho 1e-2 \ 13 | --eps 1e-4 \ 14 | --lazy_hessian 1 \ 15 | --seed 0, 1, 2 16 | -------------------------------------------------------------------------------- /language_tasks/config/sassha/sst2.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name sst2 \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sassha \ 8 | --lr_scheduler_type polynomial \ 9 | --hessian_power_scheduler constant \ 10 | --lr 1e-2 \ 11 | --weight_decay 1e-8 \ 12 | --rho 1e-4 \ 13 | --eps 1e-6 \ 14 | --lazy_hessian 1 \ 15 | --seed 0, 1, 2 16 | -------------------------------------------------------------------------------- /language_tasks/config/sassha/stsb.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name stsb \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sassha \ 8 | --lr_scheduler_type polynomial \ 9 | --hessian_power_scheduler constant \ 10 | --lr 1e-2 \ 11 | --weight_decay 1e-5 \ 12 | --rho 1e-2 \ 13 | --eps 1e-8 \ 14 | --lazy_hessian 1 \ 15 | --seed 0, 1, 2 16 | -------------------------------------------------------------------------------- /language_tasks/config/sophiah/mnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sophiah \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | 
--weight_decay 0 \ 11 | --clip_threshold 1e-4 \ 12 | --eps 1e-6 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/sophiah/mrpc.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name mrpc \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sophiah \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-6 \ 11 | --clip_threshold 1e-2 \ 12 | --eps 1e-6 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/sophiah/qnli.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qnli \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sophiah \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-7 \ 11 | --clip_threshold 1e-2 \ 12 | --eps 1e-4 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/sophiah/qqp.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name qqp \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sophiah \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-6 \ 11 | --clip_threshold 0.1 \ 12 | --eps 1e-4 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/sophiah/rte.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name rte \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sophiah \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-4 \ 11 | --clip_threshold 1e-5 \ 12 | --eps 1e-4 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/sophiah/sst2.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name sst2 \ 4 | --max_length 512 \ 5 | --num_train_epochs 5 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sophiah \ 8 | --lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 0 \ 11 | --clip_threshold 1e-4 \ 12 | --eps 1e-6 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/sophiah/stsb.sh: -------------------------------------------------------------------------------- 1 | python finetune.py \ 2 | --model_name_or_path squeezebert/squeezebert-uncased \ 3 | --task_name stsb \ 4 | --max_length 512 \ 5 | --num_train_epochs 10 \ 6 | --per_device_train_batch_size 16 \ 7 | --optimizer sophiah \ 8 | 
--lr_scheduler_type polynomial \ 9 | --lr 1e-5 \ 10 | --weight_decay 1e-5 \ 11 | --clip_threshold 1e-4 \ 12 | --eps 1e-6 \ 13 | --lazy_hessian 1 \ 14 | --seed 0, 1, 2 15 | -------------------------------------------------------------------------------- /language_tasks/config/wandb/rte_sophia.sh: -------------------------------------------------------------------------------- 1 | project: squeezebert-rte # here 2 | program: finetune.py # here 3 | method: grid 4 | metric: 5 | name: accuracy.accuracy 6 | goal: maximize 7 | parameters: 8 | optimizer: 9 | values: ['sophiah'] # here 10 | task_name: 11 | values: ['rte'] # here 12 | learning_rate: 13 | values: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 14 | weight_decay: 15 | values: [0, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] 16 | update_each: 17 | values: [1, 2, 3, 4, 5, 10] 18 | clip_threshold: 19 | values: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 20 | eps: 21 | values: [1e-4, 1e-6, 1e-8] 22 | model_name_or_path: 23 | values: ['squeezebert/squeezebert-uncased'] 24 | max_length: 25 | values: [512] 26 | num_train_epochs: 27 | values: [10] 28 | lr_scheduler_type: 29 | values: ['polynomial'] 30 | per_device_train_batch_size: 31 | values: [16] 32 | -------------------------------------------------------------------------------- /language_tasks/finetune.py: -------------------------------------------------------------------------------- 1 | # Acknowledgement: this code is based on : https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue_no_trainer.py 2 | 3 | import sys 4 | import argparse 5 | import json 6 | import logging 7 | import math 8 | import os 9 | import random 10 | from pathlib import Path 11 | 12 | import datasets 13 | import evaluate 14 | import torch 15 | from transformers import set_seed 16 | 17 | from datasets import load_dataset 18 | from huggingface_hub import HfApi 19 | from torch.utils.data import DataLoader 20 | import torch.nn as nn 21 | 22 | from tqdm.auto import tqdm 23 | 24 | import transformers 25 | from transformers import ( 26 | AutoConfig, 27 | AutoModelForSequenceClassification, 28 | AutoTokenizer, 29 | DataCollatorWithPadding, 30 | PretrainedConfig, 31 | SchedulerType, 32 | default_data_collator, 33 | get_scheduler, 34 | ) 35 | from transformers.utils import check_min_version, send_example_telemetry 36 | from transformers.utils.versions import require_version 37 | 38 | import wandb 39 | os.environ["WANDB__SERVICE_WAIT"] = "300" 40 | 41 | # Load optimizers 42 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('optimizers')))) 43 | from optimizers import get_optimizer 44 | 45 | # load hessain power scheduler 46 | from optimizers.hessian_scheduler import ConstantScheduler, ProportionScheduler, LinearScheduler, CosineScheduler 47 | 48 | 49 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
50 | check_min_version("4.41.0.dev0") 51 | 52 | #logger = get_logger(__name__) 53 | 54 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 55 | 56 | task_to_keys = { 57 | "cola": ("sentence", None), 58 | "mnli": ("premise", "hypothesis"), 59 | "mrpc": ("sentence1", "sentence2"), 60 | "qnli": ("question", "sentence"), 61 | "qqp": ("question1", "question2"), 62 | "rte": ("sentence1", "sentence2"), 63 | "sst2": ("sentence", None), 64 | "stsb": ("sentence1", "sentence2"), 65 | "wnli": ("sentence1", "sentence2"), 66 | } 67 | 68 | 69 | def parse_args(): 70 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") 71 | parser.add_argument( 72 | "--task_name", 73 | type=str, 74 | default=None, 75 | help="The name of the glue task to train on.", 76 | choices=list(task_to_keys.keys()), 77 | ) 78 | parser.add_argument( 79 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data." 80 | ) 81 | parser.add_argument( 82 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 83 | ) 84 | parser.add_argument( 85 | "--max_length", 86 | type=int, 87 | default=128, 88 | help=( 89 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 90 | " sequences shorter will be padded if `--pad_to_max_length` is passed." 91 | ), 92 | ) 93 | parser.add_argument( 94 | "--pad_to_max_length", 95 | action="store_true", 96 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", 97 | ) 98 | parser.add_argument( 99 | "--model_name_or_path", 100 | type=str, 101 | help="Path to pretrained model or model identifier from huggingface.co/models.", 102 | required=True, 103 | ) 104 | parser.add_argument( 105 | "--use_slow_tokenizer", 106 | action="store_true", 107 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 108 | ) 109 | parser.add_argument( 110 | "--per_device_train_batch_size", 111 | type=int, 112 | default=8, 113 | help="Batch size (per device) for the training dataloader.", 114 | ) 115 | parser.add_argument( 116 | "--per_device_eval_batch_size", 117 | type=int, 118 | default=8, 119 | help="Batch size (per device) for the evaluation dataloader.", 120 | ) 121 | parser.add_argument( 122 | "--lr", 123 | type=float, 124 | default=5e-5, 125 | help="Initial learning rate (after the potential warmup period) to use.", 126 | ) 127 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 128 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 129 | parser.add_argument( 130 | "--max_train_steps", 131 | type=int, 132 | default=None, 133 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 134 | ) 135 | parser.add_argument( 136 | "--gradient_accumulation_steps", 137 | type=int, 138 | default=1, 139 | help="Number of updates steps to accumulate before performing a backward/update pass.", 140 | ) 141 | parser.add_argument( 142 | "--lr_scheduler_type", 143 | type=SchedulerType, 144 | default="linear", 145 | help="The scheduler type to use.", 146 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 147 | ) 148 | parser.add_argument( 149 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 150 | ) 151 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 152 | parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") 153 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 154 | parser.add_argument( 155 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." 156 | ) 157 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") 158 | parser.add_argument( 159 | "--trust_remote_code", 160 | type=bool, 161 | default=False, 162 | help=( 163 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option " 164 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will " 165 | "execute code present on the Hub on your local machine." 166 | ), 167 | ) 168 | parser.add_argument( 169 | "--checkpointing_steps", 170 | type=str, 171 | default=None, 172 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 173 | ) 174 | parser.add_argument( 175 | "--resume_from_checkpoint", 176 | type=str, 177 | default=None, 178 | help="If the training should continue from a checkpoint folder.", 179 | ) 180 | parser.add_argument( 181 | "--with_tracking", 182 | action="store_true", 183 | help="Whether to enable experiment trackers for logging.", 184 | ) 185 | parser.add_argument( 186 | "--report_to", 187 | type=str, 188 | default="all", 189 | help=( 190 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 191 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. ' 192 | "Only applicable when `--with_tracking` is passed." 
193 | ), 194 | ) 195 | 196 | parser.add_argument( 197 | "--grad_clip_norm", 198 | type=float, 199 | default=10, 200 | help="Gradient clipping norm applied before the optimizer step; set 0 to disable (default: %(default)s).", 201 | ) 202 | 203 | parser.add_argument( 204 | "--ignore_mismatched_sizes", 205 | action="store_true", 206 | help="Whether or not to allow loading a pretrained model whose head dimensions are different.", 207 | ) 208 | 209 | parser.add_argument('--optimizer', type=str, default='sassha', help='which optimizer to use') 210 | 211 | # Second-order optimization settings 212 | parser.add_argument("--n_samples", default=1, type=int, help="number of samples for the Hutchinson Hessian approximation") 213 | parser.add_argument('--betas', type=float, nargs='*', default=[0.9, 0.999], help='betas') 214 | parser.add_argument("--eps", default=1e-4, type=float, help="add a small number for stability") 215 | parser.add_argument("--lazy_hessian", default=10, type=int, help="number of steps between Hessian updates") 216 | parser.add_argument("--clip_threshold", default=0.01, type=float, help="clipping threshold for SophiaH") 217 | 218 | # Hessian power scheduler 219 | parser.add_argument('--hessian_power_scheduler', type=str, default='constant', help="choose Hessian power 1. 'constant', 2. 'proportion', 3. 'linear', 4. 'cosine'") 220 | parser.add_argument('--max_hessian_power', type=float, default=1) 221 | parser.add_argument('--min_hessian_power', type=float, default=0.5) 222 | parser.add_argument('--min_lr', type=float, default=0.0, help="the minimum value of the learning rate (used by the 'proportion' Hessian power scheduler)") 223 | 224 | # Sharpness minimization settings 225 | parser.add_argument("--rho", default=0.05, type=float, help="Rho parameter for sharpness minimization") 226 | parser.add_argument("--adaptive", default=False, type=bool, help="True if you want to use the Adaptive sharpness") 227 | parser.add_argument('--project_name', type=str, default='project_name', help="wandb project name") 228 | 229 | args = parser.parse_args() 230 | 231 | # Sanity checks 232 | if args.task_name is None and args.train_file is None and args.validation_file is None: 233 | raise ValueError("Need either a task name or a training/validation file.") 234 | else: 235 | if args.train_file is not None: 236 | extension = args.train_file.split(".")[-1] 237 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 238 | if args.validation_file is not None: 239 | extension = args.validation_file.split(".")[-1] 240 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 241 | 242 | if args.push_to_hub: 243 | assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." 244 | 245 | return args 246 | 247 | 248 | def main(): 249 | args = parse_args() 250 | 251 | # Control randomness 252 | random.seed(args.seed) 253 | torch.manual_seed(args.seed) 254 | torch.cuda.manual_seed_all(args.seed) 255 | set_seed(args.seed) 256 | torch.backends.cudnn.benchmark = False 257 | torch.backends.cudnn.deterministic = True 258 | 259 | # Load Dataset 260 | if args.task_name is not None: 261 | # Downloading and loading a dataset from the hub. 262 | raw_datasets = load_dataset("nyu-mll/glue", args.task_name) 263 | else: 264 | # Loading the dataset from local csv or json file.
265 | data_files = {} 266 | if args.train_file is not None: 267 | data_files["train"] = args.train_file 268 | if args.validation_file is not None: 269 | data_files["validation"] = args.validation_file 270 | extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1] 271 | raw_datasets = load_dataset(extension, data_files=data_files) 272 | # See more about loading any type of standard or custom dataset at 273 | # https://huggingface.co/docs/datasets/loading_datasets. 274 | 275 | # Labels 276 | if args.task_name is not None: 277 | is_regression = args.task_name == "stsb" 278 | if not is_regression: 279 | label_list = raw_datasets["train"].features["label"].names 280 | num_labels = len(label_list) 281 | else: 282 | num_labels = 1 283 | else: 284 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 285 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 286 | if is_regression: 287 | num_labels = 1 288 | else: 289 | # A useful fast method: 290 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 291 | label_list = raw_datasets["train"].unique("label") 292 | label_list.sort() # Let's sort it for determinism 293 | num_labels = len(label_list) 294 | 295 | # Load pretrained model and tokenizer 296 | 297 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 298 | # download model & vocab. 299 | 300 | # Get config 301 | config = AutoConfig.from_pretrained( 302 | args.model_name_or_path, 303 | num_labels=num_labels, 304 | finetuning_task=args.task_name, 305 | trust_remote_code=args.trust_remote_code, 306 | ) 307 | 308 | # Get tokenizer 309 | tokenizer = AutoTokenizer.from_pretrained( 310 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code 311 | ) 312 | if tokenizer.pad_token is None: 313 | tokenizer.pad_token = tokenizer.eos_token 314 | config.pad_token_id = tokenizer.pad_token_id 315 | 316 | # Get model 317 | model = AutoModelForSequenceClassification.from_pretrained( 318 | args.model_name_or_path, 319 | from_tf=bool(".ckpt" in args.model_name_or_path), 320 | config=config, 321 | ignore_mismatched_sizes=args.ignore_mismatched_sizes, 322 | trust_remote_code=args.trust_remote_code, 323 | ) 324 | 325 | # 326 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 327 | model.to(device) 328 | 329 | # Preprocessing the datasets 330 | if args.task_name is not None: 331 | sentence1_key, sentence2_key = task_to_keys[args.task_name] 332 | else: 333 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 334 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 335 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 336 | sentence1_key, sentence2_key = "sentence1", "sentence2" 337 | else: 338 | if len(non_label_column_names) >= 2: 339 | sentence1_key, sentence2_key = non_label_column_names[:2] 340 | else: 341 | sentence1_key, sentence2_key = non_label_column_names[0], None 342 | 343 | # Some models have set the order of the labels to use, so let's make sure we do use it. 
344 | label_to_id = None 345 | if ( 346 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 347 | and args.task_name is not None 348 | and not is_regression 349 | ): 350 | # Some have all caps in their config, some don't. 351 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 352 | if sorted(label_name_to_id.keys()) == sorted(label_list): 353 | label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)} 354 | 355 | else: 356 | pass 357 | 358 | elif args.task_name is None and not is_regression: 359 | label_to_id = {v: i for i, v in enumerate(label_list)} 360 | 361 | if label_to_id is not None: 362 | model.config.label2id = label_to_id 363 | model.config.id2label = {id: label for label, id in config.label2id.items()} 364 | elif args.task_name is not None and not is_regression: 365 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 366 | model.config.id2label = {id: label for label, id in config.label2id.items()} 367 | 368 | # Set target padding 369 | padding = "max_length" if args.pad_to_max_length else False 370 | 371 | def preprocess_function(examples): 372 | # Tokenize the texts 373 | texts = ( 374 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 375 | ) 376 | result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True) 377 | 378 | if "label" in examples: 379 | if label_to_id is not None: 380 | # Map labels to IDs (not necessary for GLUE tasks) 381 | result["labels"] = [label_to_id[l] for l in examples["label"]] 382 | else: 383 | # In all cases, rename the column to labels because the model will expect that. 384 | result["labels"] = examples["label"] 385 | return result 386 | 387 | processed_datasets = raw_datasets.map( 388 | preprocess_function, 389 | batched=True, 390 | remove_columns=raw_datasets["train"].column_names, 391 | desc="Running tokenizer on dataset", 392 | ) 393 | 394 | train_dataset = processed_datasets["train"] 395 | eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] 396 | 397 | # DataLoaders creation: 398 | if args.pad_to_max_length: 399 | # If padding was already done ot max length, we use the default data collator that will just convert everything 400 | # to tensors. 401 | data_collator = default_data_collator 402 | else: 403 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of 404 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple 405 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). 406 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=None) 407 | 408 | train_dataloader = DataLoader( 409 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 410 | ) 411 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) 412 | 413 | wandb_name = f"{args.task_name}-{args.optimizer}-{args.lr}-{args.weight_decay}-{args.eps}" 414 | 415 | # get an optimizer 416 | optimizer, create_graph, two_steps = get_optimizer(model, args) 417 | 418 | # Scheduler and math around the number of training steps. 
419 | overrode_max_train_steps = False 420 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 421 | if args.max_train_steps is None: 422 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 423 | overrode_max_train_steps = True 424 | 425 | lr_scheduler = get_scheduler( 426 | name=args.lr_scheduler_type, 427 | optimizer=optimizer, 428 | num_warmup_steps=args.num_warmup_steps, 429 | num_training_steps=args.max_train_steps, 430 | ) 431 | 432 | # select a hessian power scheduler 433 | if args.optimizer == 'sassha': 434 | if args.hessian_power_scheduler == 'constant': 435 | hessian_power_scheduler = ConstantScheduler( 436 | T_max=args.num_train_epochs*len(train_dataloader), 437 | max_value=0.5, 438 | min_value=0.5) 439 | 440 | elif args.hessian_power_scheduler == 'proportion': 441 | hessian_power_scheduler = ProportionScheduler( 442 | pytorch_lr_scheduler=lr_scheduler, # track the LR scheduler created above 443 | max_lr=args.lr, 444 | min_lr=args.min_lr, 445 | max_value=args.max_hessian_power, 446 | min_value=args.min_hessian_power) 447 | 448 | elif args.hessian_power_scheduler == 'linear': 449 | hessian_power_scheduler = LinearScheduler( 450 | T_max=args.num_train_epochs*len(train_dataloader), 451 | max_value=args.max_hessian_power, 452 | min_value=args.min_hessian_power) 453 | 454 | elif args.hessian_power_scheduler == 'cosine': 455 | hessian_power_scheduler = CosineScheduler( 456 | T_max=args.num_train_epochs*len(train_dataloader), 457 | max_value=args.max_hessian_power, 458 | min_value=args.min_hessian_power) 459 | 460 | optimizer.hessian_power_scheduler = hessian_power_scheduler 461 | 462 | # We need to recalculate our total training steps as the size of the training dataloader may have changed 463 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 464 | if overrode_max_train_steps: 465 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 466 | # Afterwards we recalculate our number of training epochs 467 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 468 | 469 | # Get the metric function 470 | if args.task_name is not None: 471 | metric = evaluate.load("glue", args.task_name, experiment_id=wandb_name) 472 | else: 473 | metric = evaluate.load("accuracy", experiment_id=wandb_name) 474 | 475 | def compute_metric(eval_pred): 476 | predictions, labels = eval_pred 477 | if args.task_name != "stsb": 478 | predictions = torch.argmax(predictions, axis=1) 479 | else: 480 | predictions = predictions[:, 0] 481 | return metric.compute(predictions=predictions, references=labels) 482 | 483 | # Train! 484 | total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps 485 | 486 | # Only show the progress bar once on each machine.
487 | progress_bar = tqdm(range(args.max_train_steps)) 488 | completed_steps = 0 489 | starting_epoch = 0 490 | 491 | # Potentially load in the weights and states from a previous save 492 | if args.resume_from_checkpoint: 493 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 494 | checkpoint_path = args.resume_from_checkpoint 495 | path = os.path.basename(args.resume_from_checkpoint) 496 | else: 497 | # Get the most recent checkpoint 498 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 499 | dirs.sort(key=os.path.getctime) 500 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 501 | checkpoint_path = path 502 | path = os.path.basename(checkpoint_path) 503 | 504 | # Extract `epoch_{i}` or `step_{i}` 505 | training_difference = os.path.splitext(path)[0] 506 | 507 | if "epoch" in training_difference: 508 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 509 | resume_step = None 510 | completed_steps = starting_epoch * num_update_steps_per_epoch 511 | else: 512 | # need to multiply `gradient_accumulation_steps` to reflect real steps 513 | resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps 514 | starting_epoch = resume_step // len(train_dataloader) 515 | completed_steps = resume_step // args.gradient_accumulation_steps 516 | resume_step -= starting_epoch * len(train_dataloader) 517 | 518 | # wandb 519 | wandb_project = args.project_name 520 | wandb.init(project=wandb_project, name=wandb_name) 521 | wandb.config.update(args) 522 | 523 | # update the progress_bar if load from checkpoint 524 | progress_bar.update(completed_steps) 525 | 526 | for epoch in range(starting_epoch, args.num_train_epochs): 527 | model.train() 528 | total_loss = 0 529 | 530 | if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: 531 | # We skip the first `n` batches in the dataloader when resuming from a checkpoint 532 | active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) 533 | else: 534 | active_dataloader = train_dataloader 535 | 536 | for step, batch in enumerate(active_dataloader): 537 | batch = {k: v.to(device) for k, v in batch.items()} 538 | 539 | if two_steps: 540 | outputs = model(**batch) 541 | loss = outputs.loss 542 | # We keep track of the loss at each epoch 543 | total_loss += loss.item() 544 | loss = loss / args.gradient_accumulation_steps 545 | loss.backward() 546 | 547 | if args.optimizer == 'sassha': 548 | optimizer.perturb_weights(zero_grad=True) 549 | 550 | elif args.optimizer in ['samsgd', 'samadamw']: 551 | optimizer.first_step(zero_grad=True) 552 | 553 | model(**batch).loss.backward(create_graph=create_graph) 554 | 555 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 556 | if args.optimizer == 'sassha': 557 | optimizer.unperturb() 558 | optimizer.step() 559 | optimizer.zero_grad() 560 | 561 | elif args.optimizer in ['samsgd', 'samadamw']: 562 | if args.grad_clip_norm != 0: 563 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm) 564 | optimizer.second_step(zero_grad=True) 565 | 566 | lr_scheduler.step() 567 | progress_bar.update(1) 568 | completed_steps += 1 569 | 570 | else: 571 | outputs = model(**batch) 572 | loss = outputs.loss 573 | # We keep track of the loss at each epoch 574 | total_loss += loss.item() 575 | loss = loss / args.gradient_accumulation_steps 576 | loss.backward(create_graph=create_graph) 577 | 578 | if step % 
args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 579 | 580 | if args.grad_clip_norm != 0: 581 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm) 582 | 583 | optimizer.step() 584 | optimizer.zero_grad() 585 | lr_scheduler.step() 586 | 587 | progress_bar.update(1) 588 | completed_steps += 1 589 | 590 | if completed_steps >= args.max_train_steps: 591 | break 592 | 593 | model.eval() 594 | all_labels, all_logits = [], [] 595 | val_loss = 0 596 | for step, batch in enumerate(eval_dataloader): 597 | batch = {k: v.to(device) for k, v in batch.items()} 598 | with torch.no_grad(): 599 | outputs = model(**batch) 600 | val_loss += outputs.loss.item() 601 | all_labels.append(batch['labels']) 602 | all_logits.append(outputs.logits) 603 | 604 | all_labels = torch.cat(all_labels, dim=0) 605 | all_logits = torch.cat(all_logits, dim=0) 606 | eval_metric = compute_metric((all_logits, all_labels)) 607 | 608 | wandb.log({ 609 | "accuracy" if args.task_name is not None else "glue": eval_metric, 610 | "train_loss": total_loss / len(train_dataloader), 611 | "val_loss": val_loss / len(eval_dataloader), 612 | "epoch": epoch, 613 | "step": completed_steps, 614 | "lr": optimizer.param_groups[0]['lr'], 615 | "hessian_power": optimizer.hessian_power_t if args.optimizer == 'sassha' else 0, 616 | }) 617 | 618 | if args.task_name == "mnli": 619 | # Final evaluation on mismatched validation set 620 | eval_dataset = processed_datasets["validation_mismatched"] 621 | eval_dataloader = DataLoader( 622 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size 623 | ) 624 | 625 | model.eval() 626 | all_labels, all_logits = [], [] 627 | for step, batch in enumerate(eval_dataloader): 628 | batch = {k: v.to(device) for k, v in batch.items()} 629 | with torch.no_grad(): 630 | outputs = model(**batch) 631 | all_labels.append(batch['labels']) 632 | all_logits.append(outputs.logits) 633 | 634 | all_labels = torch.cat(all_labels, dim=0) 635 | all_logits = torch.cat(all_logits, dim=0) 636 | eval_metric = compute_metric((all_logits, all_labels)) 637 | 638 | wandb.log({ 639 | "accuracy-mm" if args.task_name is not None else "glue": eval_metric, 640 | }) 641 | 642 | if args.output_dir is not None: 643 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()} 644 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: 645 | json.dump(all_results, f) 646 | 647 | 648 | if __name__ == "__main__": 649 | main() 650 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from optimizers.sam import SAM 2 | from optimizers.msassha import MSASSHA 3 | from optimizers.adahessian import Adahessian 4 | from optimizers.sophiaH import SophiaH 5 | from optimizers.shampoo import Shampoo 6 | from optimizers.sassha import SASSHA 7 | import torch.optim as optim 8 | 9 | def configure_sassha(model, args): 10 | create_graph = True 11 | two_steps = True 12 | 13 | optimizer = SASSHA( 14 | model.parameters(), 15 | betas=tuple(args.betas), 16 | lr=args.lr, 17 | weight_decay=args.weight_decay/args.lr, 18 | rho=args.rho, 19 | lazy_hessian=args.lazy_hessian, 20 | eps=args.eps, 21 | seed=args.seed) 22 | 23 | return optimizer, create_graph, two_steps 24 | 25 | 26 | def configure_samsgd(model, args): 27 | create_graph = False 28 | two_steps = True 29 | 30 | optimizer = SAM( 31 | model.parameters(), optim.SGD, rho=args.rho, 
adaptive=args.adaptive, 32 | momentum=0.9, 33 | lr=args.lr, 34 | weight_decay=args.weight_decay) 35 | 36 | return optimizer, create_graph, two_steps 37 | 38 | 39 | def configure_samadamw(model, args): 40 | create_graph = False 41 | two_steps = True 42 | 43 | optimizer = SAM( 44 | model.parameters(), optim.AdamW, rho=args.rho, adaptive=args.adaptive, 45 | betas=tuple(args.betas), 46 | lr=args.lr, 47 | weight_decay=args.weight_decay/args.lr) 48 | 49 | return optimizer, create_graph, two_steps 50 | 51 | 52 | def configure_sgd(model, args): 53 | create_graph = False 54 | two_steps = False 55 | 56 | optimizer = optim.SGD( 57 | model.parameters(), 58 | lr=args.lr, 59 | momentum=0.9, 60 | weight_decay=args.weight_decay) 61 | 62 | return optimizer, create_graph, two_steps 63 | 64 | 65 | def configure_adamw(model, args): 66 | create_graph = False 67 | two_steps = False 68 | 69 | optimizer = optim.AdamW( 70 | model.parameters(), 71 | betas=tuple(args.betas), 72 | lr=args.lr, 73 | weight_decay=args.weight_decay/args.lr) 74 | 75 | return optimizer, create_graph, two_steps 76 | 77 | 78 | def configure_adahessian(model, args): 79 | create_graph = True 80 | two_steps = False 81 | 82 | optimizer = Adahessian( 83 | model.parameters(), 84 | betas=tuple(args.betas), 85 | lr=args.lr, 86 | weight_decay=args.weight_decay/args.lr, 87 | lazy_hessian=args.lazy_hessian, 88 | eps=args.eps, 89 | seed=args.seed) 90 | 91 | return optimizer, create_graph, two_steps 92 | 93 | 94 | def configure_sophiah(model, args): 95 | create_graph = True 96 | two_steps = False 97 | 98 | optimizer = SophiaH( 99 | model.parameters(), 100 | betas=tuple(args.betas), 101 | lr=args.lr, 102 | weight_decay=args.weight_decay / args.lr, 103 | clip_threshold=args.clip_threshold, 104 | eps=args.eps, 105 | lazy_hessian=args.lazy_hessian, 106 | seed=args.seed) 107 | 108 | return optimizer, create_graph, two_steps 109 | 110 | 111 | def configure_msassha(model, args): 112 | create_graph = True 113 | two_steps = False 114 | 115 | optimizer = MSASSHA( 116 | model.parameters(), 117 | lr=args.lr, 118 | rho=args.rho, 119 | weight_decay=args.weight_decay / args.lr, 120 | lazy_hessian=args.lazy_hessian, 121 | eps=args.eps, 122 | seed=args.seed) 123 | 124 | return optimizer, create_graph, two_steps 125 | 126 | 127 | def configure_shampoo(model, args): 128 | create_graph = False 129 | two_steps = False 130 | 131 | optimizer = Shampoo( 132 | params=model.parameters(), 133 | lr=args.lr, 134 | momentum=0.9, 135 | weight_decay=args.weight_decay, 136 | epsilon=args.eps, 137 | update_freq=1) 138 | 139 | return optimizer, create_graph, two_steps 140 | 141 | 142 | def get_optimizer(model, args): 143 | optimizer_map = { 144 | 'sassha': configure_sassha, 145 | 'samsgd': configure_samsgd, 146 | 'samadamw': configure_samadamw, 147 | 'adahessian': configure_adahessian, 148 | 'adamw': configure_adamw, 149 | 'sgd': configure_sgd, 150 | 'sophiah': configure_sophiah, 151 | 'msassha': configure_msassha, 152 | 'shampoo': configure_shampoo, 153 | } 154 | 155 | if args.optimizer not in optimizer_map: 156 | raise ValueError(f"Unsupported optimizer: {args.optimizer}") 157 | 158 | return optimizer_map[args.optimizer](model, args) 159 | 160 | -------------------------------------------------------------------------------- /optimizers/adahessian.py: -------------------------------------------------------------------------------- 1 | # Acknowledgement: This code is based on https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py 2 | 3 | import torch 4 | import math 5 | 6 | 
class Adahessian(torch.optim.Optimizer): 7 | """ 8 | Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning" 9 | In earlier experiments, results varied depending on how 'denom' was defined, 10 | so this implementation follows the official PyTorch Adam code. 11 | 12 | Arguments: 13 | params (iterable) -- iterable of parameters to optimize or dicts defining parameter groups 14 | lr (float, optional) -- learning rate (default: 0.15) 15 | betas ((float, float), optional) -- coefficients used for computing running averages of gradient and the squared hessian trace (default: (0.9, 0.999)) 16 | eps (float, optional) -- term added to the denominator to improve numerical stability (default: 1e-4) 17 | weight_decay (float, optional) -- weight decay (L2 penalty) (default: 0.0) 18 | hessian_power (float, optional) -- exponent of the hessian trace (default: 1.0) 19 | lazy_hessian (int, optional) -- compute the hessian trace approximation only once every *this* many steps (to save time) (default: 1) 20 | n_samples (int, optional) -- how many times to sample `z` for the approximation of the hessian trace (default: 1) 21 | """ 22 | 23 | def __init__(self, 24 | params, 25 | lr=0.15, 26 | betas=(0.9, 0.999), 27 | eps=1e-4, 28 | weight_decay=0.0, 29 | hessian_power=1, 30 | lazy_hessian=1, 31 | n_samples=1, 32 | seed=0): 33 | 34 | if not 0.0 <= lr: 35 | raise ValueError(f"Invalid learning rate: {lr}") 36 | if not 0.0 <= eps: 37 | raise ValueError(f"Invalid epsilon value: {eps}") 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") 42 | if not 0.0 <= hessian_power <= 1.0: 43 | raise ValueError(f"Invalid Hessian power value: {hessian_power}") 44 | 45 | self.n_samples = n_samples 46 | self.lazy_hessian = lazy_hessian 47 | self.seed = seed 48 | 49 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training 50 | self.generator = torch.Generator().manual_seed(self.seed) 51 | 52 | defaults = dict(lr=lr, 53 | betas=betas, 54 | eps=eps, 55 | weight_decay=weight_decay, 56 | hessian_power=hessian_power) 57 | 58 | super(Adahessian, self).__init__(params, defaults) 59 | 60 | for p in self.get_params(): 61 | p.hess = 0.0 62 | self.state[p]["hessian step"] = 0 63 | 64 | def get_params(self): 65 | """ 66 | Gets all parameters in all param_groups with gradients 67 | """ 68 | 69 | return (p for group in self.param_groups for p in group['params'] if p.requires_grad) 70 | 71 | def zero_hessian(self): 72 | """ 73 | Zeros out the accumulated hessian traces. 74 | """ 75 | 76 | for p in self.get_params(): 77 | if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.lazy_hessian == 0: 78 | p.hess.zero_() 79 | 80 | @torch.no_grad() 81 | def set_hessian(self): 82 | """ 83 | Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
84 | """ 85 | params = [] 86 | for p in filter(lambda p: p.grad is not None, self.get_params()): 87 | if self.state[p]["hessian step"] % self.lazy_hessian == 0: # compute a new Hessian per `lazy_hessian` step 88 | params.append(p) 89 | self.state[p]["hessian step"] += 1 90 | 91 | if len(params) == 0: 92 | return 93 | 94 | if self.generator.device != params[0].device and self.seed is not None: # hackish way of casting the generator to the right device 95 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed) 96 | 97 | grads = [p.grad for p in params] 98 | 99 | last_sample = self.n_samples - 1 100 | for i in range(self.n_samples): 101 | # Rademacher distribution {-1.0, 1.0} 102 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] 103 | h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < last_sample) 104 | for h_z, z, p in zip(h_zs, zs, params): 105 | p.hess += h_z * z / self.n_samples 106 | 107 | @torch.no_grad() 108 | def step(self, closure=None): 109 | """ 110 | Performs a single optimization step. 111 | Arguments: 112 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) 113 | """ 114 | 115 | loss = None 116 | if closure is not None: 117 | loss = closure() 118 | 119 | self.zero_hessian() 120 | self.set_hessian() 121 | 122 | for group in self.param_groups: 123 | for p in group['params']: 124 | if p.grad is None or p.hess is None: 125 | continue 126 | 127 | if p.dim() <= 2: 128 | p.hess = p.hess.abs().clone() 129 | 130 | if p.dim() == 4: 131 | p.hess = torch.mean(p.hess.abs(), dim=[2, 3], keepdim=True).expand_as(p.hess).clone() 132 | 133 | # Perform correct stepweight decay as in AdamW 134 | p.mul_(1 - group['lr'] * group['weight_decay']) 135 | 136 | state = self.state[p] 137 | 138 | # State initialization 139 | if len(state) == 1: 140 | state['step'] = 0 141 | state['exp_avg'] = torch.zeros_like(p.data) # Exponential moving average of gradient values 142 | state['exp_hessian_diag_sq'] = torch.zeros_like(p.data) # Exponential moving average of Hessian diagonal square values 143 | 144 | exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq'] 145 | beta1, beta2 = group['betas'] 146 | state['step'] += 1 147 | 148 | # Decay the first and second moment running average coefficient 149 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) 150 | exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2) 151 | 152 | bias_correction1 = 1 - beta1 ** state['step'] 153 | bias_correction2 = 1 - beta2 ** state['step'] 154 | 155 | step_size = group['lr'] / bias_correction1 156 | step_size_neg = -step_size 157 | 158 | k = group['hessian_power'] 159 | denom = (exp_hessian_diag_sq / bias_correction2).pow_(k/2).add_(group['eps']) 160 | 161 | # make update 162 | p.addcdiv_(exp_avg, denom, value=step_size_neg) 163 | 164 | return loss 165 | -------------------------------------------------------------------------------- /optimizers/hessian_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | class ProportionScheduler: 5 | def __init__(self, pytorch_lr_scheduler, max_lr, min_lr, max_value, min_value): 6 | """ 7 | This scheduler outputs a value that evolves proportional to pytorch_lr_scheduler, e.g. 
8 | (value - min_value) / (max_value - min_value) = (lr - min_lr) / (max_lr - min_lr) 9 | """ 10 | self.t = 0 11 | self.pytorch_lr_scheduler = pytorch_lr_scheduler 12 | self.max_lr = max_lr 13 | self.min_lr = min_lr 14 | self.max_value = max_value 15 | self.min_value = min_value 16 | 17 | assert (max_lr > min_lr) or ((max_lr==min_lr) and (max_value==min_value)), "Current scheduler for `value` is scheduled to evolve proportionally to `lr`," \ 18 | "e.g. `(lr - min_lr) / (max_lr - min_lr) = (value - min_value) / (max_value - min_value)`. Please check `max_lr >= min_lr` and `max_value >= min_value`;" \ 19 | "if `max_lr==min_lr` hence `lr` is constant with step, please set 'max_value == min_value' so 'value' is constant with step." 20 | 21 | assert max_value >= min_value 22 | 23 | self.step() # take 1 step during initialization to get self._last_lr 24 | 25 | def lr(self): 26 | return self._last_lr[0] 27 | 28 | def step(self): 29 | self.t += 1 30 | if hasattr(self.pytorch_lr_scheduler, "_last_lr"): 31 | lr = self.pytorch_lr_scheduler._last_lr[0] 32 | else: 33 | lr = self.pytorch_lr_scheduler.optimizer.param_groups[0]['lr'] 34 | 35 | if self.max_lr > self.min_lr: 36 | value = self.max_value - (self.max_value - self.min_value) * (lr - self.min_lr) / (self.max_lr - self.min_lr) 37 | else: 38 | value = self.max_value 39 | 40 | self._last_lr = [value] 41 | return value 42 | 43 | class SchedulerBase: 44 | def __init__(self, T_max, max_value, min_value=0.0, init_value=0.0, warmup_steps=0, optimizer=None): 45 | super(SchedulerBase, self).__init__() 46 | self.t = 0 47 | self.min_value = min_value 48 | self.max_value = max_value 49 | self.init_value = init_value 50 | self.warmup_steps = warmup_steps 51 | self.total_steps = T_max 52 | 53 | # record current value in self._last_lr to match API from torch.optim.lr_scheduler 54 | self._last_lr = [init_value] 55 | 56 | # If optimizer is not None, will set learning rate to all trainable parameters in optimizer. 57 | # If optimizer is None, only output the value of lr. 58 | self.optimizer = optimizer 59 | 60 | def step(self): 61 | if self.t < self.warmup_steps: 62 | value = self.init_value + (self.max_value - self.init_value) * self.t / self.warmup_steps 63 | elif self.t == self.warmup_steps: 64 | value = self.min_value 65 | else: 66 | value = self.step_func() 67 | self.t += 1 68 | 69 | # apply the lr to optimizer if it's provided 70 | if self.optimizer is not None: 71 | for param_group in self.optimizer.param_groups: 72 | param_group['lr'] = value 73 | 74 | self._last_lr = [value] 75 | return value 76 | 77 | def step_func(self): 78 | pass 79 | 80 | def lr(self): 81 | return self._last_lr[0] 82 | 83 | class LinearScheduler(SchedulerBase): 84 | def step_func(self): 85 | value = self.min_value + (self.max_value - self.min_value) * (self.t - self.warmup_steps) / ( 86 | self.total_steps - self.warmup_steps) 87 | return value 88 | 89 | class CosineScheduler(SchedulerBase): 90 | def step_func(self): 91 | phase = (self.t-self.warmup_steps) / (self.total_steps-self.warmup_steps) * math.pi 92 | value = self.max_value - (self.max_value-self.min_value) * (np.cos(phase) + 1.) 
/ 2.0 93 | return value 94 | 95 | class ConstantScheduler(SchedulerBase): 96 | def step_func(self): 97 | value = self.min_value 98 | return value 99 | 100 | -------------------------------------------------------------------------------- /optimizers/msassha.py: -------------------------------------------------------------------------------- 1 | # Acknowledgement: This code is based on https://github.com/MarlonBecker/MSAM/blob/main/optimizer/adamW_msam.py 2 | 3 | from typing import List 4 | from torch import Tensor 5 | import torch 6 | import math 7 | 8 | # cf. https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py 9 | class MSASSHA(torch.optim.Optimizer): 10 | def __init__( 11 | self, 12 | params, 13 | lr: float = 0.15, 14 | betas: tuple = (0.9, 0.999), 15 | weight_decay: float = 1e-2, 16 | lazy_hessian: int = 10, 17 | rho: float = 0.3, 18 | n_samples: int = 1, 19 | eps: float = 1e-4, 20 | hessian_power: float = 1, 21 | seed: int = 0, 22 | maximize: bool = False 23 | ): 24 | 25 | defaults = dict( 26 | lr=lr, 27 | betas=betas, 28 | weight_decay=weight_decay, 29 | rho=rho, 30 | eps=eps, 31 | maximize=maximize 32 | ) 33 | 34 | self.lazy_hessian = lazy_hessian 35 | self.n_samples = n_samples 36 | self.seed = seed 37 | self.hessian_power = hessian_power 38 | 39 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training 40 | self.generator = torch.Generator().manual_seed(self.seed) 41 | 42 | super(MSASSHA, self).__init__(params, defaults) 43 | 44 | # init momentum buffer to zeros 45 | # needed to make implementation of first ascent step cleaner (before SGD.step() was ever called) 46 | 47 | for p in self.get_params(): 48 | p.hess = 0.0 49 | if getattr(self, "track_hessian", False): # optional debug flag; off unless set on the instance 50 | p.real_hess = 0.0 51 | 52 | state = self.state[p] 53 | state["hessian_step"] = 0 54 | 55 | for group in self.param_groups: 56 | group["norm_factor"] = [0,] 57 | 58 | def get_params(self): 59 | """ 60 | Gets all parameters in all param_groups with gradients 61 | """ 62 | 63 | return (p for group in self.param_groups for p in group['params'] if p.requires_grad) 64 | 65 | 66 | def zero_hessian(self): 67 | """ 68 | Zeros out the accumulated hessian traces. 69 | """ 70 | 71 | for p in self.get_params(): 72 | if not isinstance(p.hess, float) and self.state[p]["hessian_step"] % self.lazy_hessian == 0: 73 | p.hess.zero_() 74 | 75 | 76 | @torch.no_grad() 77 | def set_hessian(self): 78 | """ 79 | Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
80 | """ 81 | 82 | params = [] 83 | for p in filter(lambda p: p.grad is not None, self.get_params()): 84 | if self.state[p]["hessian_step"] % self.lazy_hessian == 0: # compute a new Hessian per `lazy_hessian` step 85 | params.append(p) 86 | self.state[p]["hessian_step"] += 1 87 | 88 | if len(params) == 0: 89 | return 90 | 91 | if self.generator.device != params[0].device: # hackish way of casting the generator to the right device 92 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed) 93 | 94 | grads = [p.grad for p in params] 95 | 96 | last_sample = self.n_samples - 1 97 | for i in range(self.n_samples): 98 | # Rademacher distribution {-1.0, 1.0} 99 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] 100 | h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < last_sample) 101 | for h_z, z, p in zip(h_zs, zs, params): 102 | p.hess += h_z * z / self.n_samples 103 | 104 | @torch.no_grad() 105 | def step(self, closure=None): 106 | """ Performs a single optimization step. 107 | 108 | Args: 109 | closure (callable, optional): A closure that reevaluates the model 110 | and returns the loss. 111 | """ 112 | loss = None 113 | if closure is not None: 114 | with torch.enable_grad(): 115 | loss = closure() 116 | if len(self.param_groups) > 1: 117 | raise RuntimeError("only one parameter group supported atm for SAMANDA_MSAM") 118 | 119 | group = self.param_groups[0] 120 | params_with_grad = [] 121 | grads = [] 122 | hesses = [] 123 | grad_momentums = [] 124 | hess_momentums = [] 125 | state_steps = [] 126 | 127 | beta1, beta2 = group['betas'] 128 | 129 | self.zero_hessian() 130 | self.set_hessian() 131 | 132 | for p in group['params']: 133 | if p.grad is None: 134 | continue 135 | 136 | params_with_grad.append(p) 137 | if p.grad.is_sparse: 138 | raise RuntimeError('Msassha does not support sparse gradients') 139 | 140 | grads.append(p.grad) 141 | hesses.append(p.hess) 142 | 143 | state = self.state[p] 144 | # State initialization 145 | if len(state) == 1: 146 | state['step'] = 0 147 | state['grad_momentum'] = torch.zeros_like(p, memory_format=torch.preserve_format) 148 | state['hess_momentum'] = torch.zeros_like(p, memory_format=torch.preserve_format) 149 | 150 | grad_momentums.append(state['grad_momentum']) 151 | hess_momentums.append(state['hess_momentum']) 152 | 153 | # update the steps for each param group update 154 | state['step'] += 1 155 | # record the step after step update 156 | state_steps.append(state['step']) 157 | 158 | samanda_msam(params_with_grad, 159 | grads, 160 | hesses, 161 | grad_momentums, 162 | hess_momentums, 163 | state_steps, 164 | beta1=beta1, 165 | beta2=beta2, 166 | lr=group['lr'], 167 | weight_decay=group['weight_decay'], 168 | lazy_hessian=self.lazy_hessian, 169 | rho=group['rho'], 170 | norm_factor=group['norm_factor'], 171 | eps=group['eps'] 172 | ) 173 | 174 | return loss 175 | 176 | 177 | @torch.no_grad() 178 | def move_up_to_momentumAscent(self): 179 | for group in self.param_groups: 180 | for p in group['params']: 181 | if "grad_momentum" in self.state[p]: 182 | p.sub_(self.state[p]["grad_momentum"], alpha=group["norm_factor"][0]) 183 | 184 | 185 | @torch.no_grad() 186 | def move_back_from_momentumAscent(self): 187 | for group in self.param_groups: 188 | for p in group['params']: 189 | if "grad_momentum" in self.state[p]: 190 | p.add_(self.state[p]["grad_momentum"], alpha=group["norm_factor"][0]) 191 | 192 | bias_correction2 = 0 193 | def 
samanda_msam(params: List[Tensor], 194 | grads: List[Tensor], 195 | hesses: List[Tensor], 196 | grad_momentums: List[Tensor], 197 | hess_momentums: List[Tensor], 198 | state_steps: List[int], 199 | *, 200 | beta1: float, 201 | beta2: float, 202 | lr: float, 203 | weight_decay: float, 204 | lazy_hessian: int, 205 | rho: float, 206 | norm_factor: list, 207 | eps: float, hessian_power: float = 1 # matches MSASSHA's default when not passed explicitly 208 | ): 209 | r"""Functional API that performs a single MSASSHA update. 210 | 211 | The loop structure follows :class:`~torch.optim.AdamW`, combined with an MSAM-style momentum-ascent perturbation and a Hessian-based denominator. 212 | """ 213 | 214 | for i, param in enumerate(params): 215 | grad = grads[i] 216 | hess = hesses[i] 217 | grad_momentum = grad_momentums[i] 218 | hess_momentum = hess_momentums[i] 219 | step = state_steps[i] 220 | 221 | # remove last perturbation (descent) w_t <- \tilde{w_t} + rho*m_t/||m_t|| 222 | param.add_(grad_momentum, alpha=norm_factor[0]) 223 | 224 | # weight decay 225 | param.mul_(1 - lr * weight_decay) 226 | 227 | # Decay the first and second moment running average coefficient 228 | grad_momentum.mul_(beta1).add_(grad, alpha=1-beta1) 229 | bias_correction1 = 1 - beta1 ** step 230 | 231 | if (step-1) % lazy_hessian == 0: 232 | hess_momentum.mul_(beta2).add_(hess.abs(), alpha=1-beta2) 233 | global bias_correction2 234 | bias_correction2 = 1 - beta2 ** step 235 | 236 | denom = ((hess_momentum ** hessian_power) / (bias_correction2 ** hessian_power)).add_(eps) 237 | 238 | step_size = lr / bias_correction1 239 | 240 | # make update 241 | param.addcdiv_(grad_momentum, denom, value=-step_size) 242 | 243 | # calculate ascent step norm 244 | ascent_norm = torch.norm( 245 | torch.stack([ 246 | grad_momentum.norm(p=2) 247 | for grad_momentum in grad_momentums 248 | ]), 249 | p=2 250 | ) 251 | norm_factor[0] = 1/(ascent_norm+1e-12) * rho 252 | 253 | # perturb for next iteration (ascent) 254 | for i, param in enumerate(params): 255 | param.sub_(grad_momentums[i], alpha=norm_factor[0]) 256 | 257 | -------------------------------------------------------------------------------- /optimizers/sam.py: -------------------------------------------------------------------------------- 1 | # Acknowledgement: This code is based on https://github.com/davda54/sam/blob/main/sam.py 2 | import torch 3 | 4 | 5 | class SAM(torch.optim.Optimizer): 6 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 7 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 8 | 9 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 10 | super(SAM, self).__init__(params, defaults) 11 | 12 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 13 | self.param_groups = self.base_optimizer.param_groups 14 | self.defaults.update(self.base_optimizer.defaults) 15 | 16 | @torch.no_grad() 17 | def first_step(self, zero_grad=False): 18 | grad_norm = self._grad_norm() 19 | for group in self.param_groups: 20 | scale = group["rho"] / (grad_norm + 1e-12) 21 | 22 | for p in group["params"]: 23 | if p.grad is None: continue 24 | self.state[p]["old_p"] = p.data.clone() 25 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 26 | p.add_(e_w) # climb to the local maximum "w + e(w)" 27 | 28 | if zero_grad: self.zero_grad() 29 | 30 | @torch.no_grad() 31 | def second_step(self, zero_grad=False): 32 | for group in self.param_groups: 33 | for p in group["params"]: 34 | if p.grad is None: continue 35 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 36 | 37 | self.base_optimizer.step() # do the actual "sharpness-aware" update 38 | 39 | if zero_grad:
self.zero_grad() 40 | 41 | @torch.no_grad() 42 | def step(self, closure=None): 43 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 44 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 45 | 46 | self.first_step(zero_grad=True) 47 | closure() 48 | self.second_step() 49 | 50 | def _grad_norm(self): 51 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 52 | norm = torch.norm( 53 | torch.stack([ 54 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 55 | for group in self.param_groups for p in group["params"] 56 | if p.grad is not None 57 | ]), 58 | p=2 59 | ) 60 | return norm 61 | 62 | def load_state_dict(self, state_dict): 63 | super().load_state_dict(state_dict) 64 | self.base_optimizer.param_groups = self.param_groups -------------------------------------------------------------------------------- /optimizers/sassha.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import contextlib 4 | from torch.distributed import ReduceOp 5 | 6 | 7 | class SASSHA(torch.optim.Optimizer): 8 | """Implements the Sharpness-Aware Second-Order optimization with Stable Hessian Approximation (SASSHA) algorithm. 9 | 10 | Args: 11 | params (iterable): Iterable of parameters to optimize or dicts defining parameter groups. 12 | hessian_power_scheduler (None): Update the Hessian power at every training step. Initially, set it to None, and later you can replace it. 13 | lr (float, optional): Learning rate. 14 | betas (Tuple[float, float], optional): Coefficients for computing moving averages of gradient and Hessian. 15 | weight_decay (float, optional): Weight decay (L2 penalty). 16 | rho (float, optional): Size of the neighborhood for computing the max loss 17 | hessian_power (float, optional): Exponent of the Hessian in the update rule. 18 | lazy_hessian (int, optional): Number of optimization steps to perform before updating the Hessian. 19 | n_samples (int, optional): Number of samples to draw for the Hutchinson approximation. 20 | perturb_eps (float, optional): Small value for perturbations in Hessian trace computation. 21 | eps (float, optional): Term added to the denominator to improve numerical stability. 22 | adaptive (bool, optional): set this argument to True if you want to use an experimental implementation of element-wise Adaptive SAM. Default is False. 23 | grad_reduce (str, optional): Reduction method for gradients ('mean' or 'sum'). Default is 'mean'. 24 | seed (int, optional): Random seed for reproducibility. Default is 0. 25 | **kwargs: Additional keyword arguments for compatibility with other optimizers. 
26 |     """
27 | 
28 |     def __init__(self, params,
29 |                  hessian_power_scheduler=None,
30 |                  lr=0.15,
31 |                  betas=(0.9, 0.999),
32 |                  weight_decay=0.0,
33 |                  rho=0.0,
34 |                  lazy_hessian=10,
35 |                  n_samples=1,
36 |                  perturb_eps=1e-12,
37 |                  eps=1e-4,
38 |                  adaptive=False,
39 |                  grad_reduce='mean',
40 |                  seed=0,
41 |                  **kwargs):
42 | 
43 |         if not 0.0 <= lr:
44 |             raise ValueError(f"Invalid learning rate: {lr}")
45 |         if not 0.0 <= eps:
46 |             raise ValueError(f"Invalid epsilon value: {eps}")
47 |         if not 0.0 <= betas[0] < 1.0:
48 |             raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
49 |         if not 0.0 <= betas[1] < 1.0:
50 |             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
51 | 
52 |         self.hessian_power_scheduler = hessian_power_scheduler
53 |         self.lazy_hessian = lazy_hessian
54 |         self.n_samples = n_samples
55 |         self.adaptive = adaptive
56 |         self.seed = seed
57 | 
58 |         defaults = dict(lr=lr,
59 |                         betas=betas,
60 |                         weight_decay=weight_decay,
61 |                         rho=rho,
62 |                         perturb_eps=perturb_eps,
63 |                         eps=eps)
64 | 
65 |         super(SASSHA, self).__init__(params, defaults)
66 | 
67 |         for p in self.get_params():
68 |             p.hess = 0.0
69 |             self.state[p]["hessian step"] = 0
70 | 
71 |         # set up reduction for gradient across workers
72 |         if grad_reduce.lower() == 'mean':
73 |             if hasattr(ReduceOp, 'AVG'):
74 |                 self.grad_reduce = ReduceOp.AVG
75 |                 self.manual_average = False
76 |             else:  # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
77 |                 self.grad_reduce = ReduceOp.SUM
78 |                 self.manual_average = True
79 |         elif grad_reduce.lower() == 'sum':
80 |             self.grad_reduce = ReduceOp.SUM
81 |             self.manual_average = False
82 |         else:
83 |             raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')
84 | 
85 |         # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
86 |         self.generator = torch.Generator().manual_seed(self.seed)
87 | 
88 | 
89 |     def get_params(self):
90 |         """
91 |         Gets all parameters in all param_groups with gradients.
92 |         """
93 | 
94 |         return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
95 | 
96 |     def zero_hessian(self):
97 |         """
98 |         Zeros out the accumulated hessian traces.
99 |         """
100 | 
101 |         for p in self.get_params():
102 |             if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.lazy_hessian == 0:
103 |                 p.hess.zero_()
104 | 
105 | 
106 |     @torch.no_grad()
107 |     def update_hessian_power(self):
108 |         """
109 |         Update the Hessian power at every training step.
110 |         """
111 |         if self.hessian_power_scheduler is not None:
112 |             self.hessian_power_t = self.hessian_power_scheduler.step()
113 |         else:
114 |             self.hessian_power_t = None
115 |         return self.hessian_power_t
116 | 
117 | 
118 |     @torch.no_grad()
119 |     def set_hessian(self):
120 |         """
121 |         Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
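        For a Rademacher probe z, the estimator is diag(H) ≈ E[z ⊙ (Hz)]: the Hessian-vector
        product Hz is obtained by differentiating the gradients once more, and the result is
        averaged over `n_samples` probes.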
122 |         """
123 |         params = []
124 |         for p in filter(lambda p: p.grad is not None, self.get_params()):
125 |             if self.state[p]["hessian step"] % self.lazy_hessian == 0:  # compute a new Hessian once every `lazy_hessian` steps
126 |                 params.append(p)
127 |             self.state[p]["hessian step"] += 1
128 | 
129 |         if len(params) == 0:
130 |             return
131 | 
132 |         if self.generator.device != params[0].device:  # hackish way of casting the generator to the right device
133 |             self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
134 | 
135 |         grads = [p.grad for p in params]
136 | 
137 |         last_sample = self.n_samples - 1
138 |         for i in range(self.n_samples):
139 |             # Rademacher distribution {-1.0, 1.0}
140 |             zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
141 |             h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < last_sample)
142 |             for h_z, z, p in zip(h_zs, zs, params):
143 |                 p.hess += h_z * z / self.n_samples
144 | 
145 |     @torch.no_grad()
146 |     def perturb_weights(self, zero_grad=True):
147 |         grad_norm = self._grad_norm(weight_adaptive=self.adaptive)
148 |         for group in self.param_groups:
149 |             scale = group["rho"] / (grad_norm + group["perturb_eps"])
150 | 
151 |             for p in group["params"]:
152 |                 if p.grad is None: continue
153 |                 e_w = p.grad * scale.to(p)
154 |                 if self.adaptive:
155 |                     e_w *= torch.pow(p, 2)
156 |                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
157 |                 self.state[p]['e_w'] = e_w
158 | 
159 |         if zero_grad: self.zero_grad()
160 | 
161 |     @torch.no_grad()
162 |     def unperturb(self):
163 |         for group in self.param_groups:
164 |             for p in group['params']:
165 |                 if 'e_w' in self.state[p].keys():
166 |                     p.data.sub_(self.state[p]['e_w'])
167 | 
168 |     @torch.no_grad()
169 |     def _grad_norm(self, by=None, weight_adaptive=False):
170 |         if not by:
171 |             norm = torch.norm(
172 |                 torch.stack([
173 |                     ((torch.abs(p.data) if weight_adaptive else 1.0) * p.grad).norm(p=2)
174 |                     for group in self.param_groups for p in group["params"]
175 |                     if p.grad is not None
176 |                 ]),
177 |                 p=2
178 |             )
179 |         else:
180 |             norm = torch.norm(
181 |                 torch.stack([
182 |                     ((torch.abs(p.data) if weight_adaptive else 1.0) * self.state[p][by]).norm(p=2)
183 |                     for group in self.param_groups for p in group["params"]
184 |                     if p.grad is not None
185 |                 ]),
186 |                 p=2
187 |             )
188 |         return norm
189 | 
190 |     @torch.no_grad()
191 |     def _sync_gradients(self):
192 |         for group in self.param_groups:
193 |             for p in group['params']:
194 |                 if p.grad is None: continue
195 |                 if torch.distributed.is_initialized():  # synchronize final gradients
196 |                     if self.manual_average:
197 |                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
198 |                         world_size = torch.distributed.get_world_size()
199 |                         p.grad.div_(float(world_size))
200 |                     else:
201 |                         torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
202 |         return
203 | 
204 |     @torch.no_grad()
205 |     def _sync_hessians(self):
206 |         for group in self.param_groups:
207 |             for p in group['params']:
208 |                 if p.hess is None: continue
209 |                 if torch.distributed.is_initialized():  # synchronize final Hessians
210 |                     if not p.hess.is_contiguous():
211 |                         p.hess = p.hess.contiguous()
212 | 
213 |                     if self.manual_average:
214 |                         torch.distributed.all_reduce(p.hess, op=self.grad_reduce)
215 |                         world_size = torch.distributed.get_world_size()
216 |                         p.hess.div_(float(world_size))
217 |                     else:
218 |                         torch.distributed.all_reduce(p.hess, op=self.grad_reduce)
219 |         return
220 | 
221 |     @torch.no_grad()
222 |     def step(self, closure=None):
223
| """ 224 | Performs a single optimization step. 225 | Arguments: 226 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) 227 | """ 228 | 229 | self.update_hessian_power() 230 | 231 | loss = None 232 | if closure is not None: 233 | loss = closure() 234 | 235 | self.zero_hessian() 236 | self.set_hessian() 237 | self._sync_gradients() 238 | self._sync_hessians() 239 | 240 | # prepare to update parameters 241 | for group in self.param_groups: 242 | for p in group['params']: 243 | if p.grad is None or p.hess is None: 244 | continue 245 | 246 | p.hess = p.hess.abs().clone() 247 | 248 | # Perform correct stepweight decay as in AdamW 249 | p.mul_(1 - group['lr'] * group['weight_decay']) 250 | 251 | state = self.state[p] 252 | # State initialization 253 | if len(state) == 2: 254 | state['step'] = 0 255 | state['exp_avg'] = torch.zeros_like(p.data) 256 | state['exp_hessian_diag'] = torch.zeros_like(p.data) 257 | state['bias_correction2'] = 0 258 | 259 | exp_avg, exp_hessian_diag = state['exp_avg'], state['exp_hessian_diag'] 260 | beta1, beta2 = group['betas'] 261 | state['step'] += 1 262 | 263 | # Decay the first and second moment running average coefficient 264 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) 265 | bias_correction1 = 1 - beta1 ** state['step'] 266 | 267 | if (state['hessian step']-1) % self.lazy_hessian == 0: 268 | exp_hessian_diag.mul_(beta2).add_(p.hess, alpha=1 - beta2) 269 | bias_correction2 = 1 - beta2 ** state['step'] 270 | state['bias_correction2'] = bias_correction2 ** self.hessian_power_t 271 | 272 | step_size = group['lr'] / bias_correction1 273 | step_size_neg = -step_size 274 | 275 | denom = ((exp_hessian_diag**self.hessian_power_t) / state['bias_correction2']).add_(group['eps']) 276 | 277 | # make update 278 | p.addcdiv_(exp_avg, denom, value=step_size_neg) 279 | 280 | return loss 281 | -------------------------------------------------------------------------------- /optimizers/shampoo.py: -------------------------------------------------------------------------------- 1 | # Acknowledgement: This code is based on https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/shampoo.py 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | def _matrix_power(matrix: torch.Tensor, power: float) -> torch.Tensor: 7 | # use CPU for svd for speed up 8 | device = matrix.device 9 | matrix = matrix.cpu() 10 | u, s, v = torch.svd(matrix) 11 | return (u @ s.pow_(power).diag() @ v.t()).to(device) 12 | 13 | 14 | class Shampoo(Optimizer): 15 | r"""Implements Shampoo Optimizer Algorithm. 16 | 17 | It has been proposed in `Shampoo: Preconditioned Stochastic Tensor 18 | Optimization`__. 
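    For an order-k gradient tensor, the implementation below accumulates one
    preconditioning matrix per tensor mode and preconditions the gradient with the
    matrix power -1/k of each, recomputing the inverse roots every `update_freq` steps.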
19 | 
20 |     Arguments:
21 |         params: iterable of parameters to optimize or dicts defining
22 |             parameter groups
23 |         lr: learning rate (default: 1e-1)
24 |         momentum: momentum factor (default: 0)
25 |         weight_decay: weight decay (L2 penalty) (default: 0)
26 |         epsilon: epsilon added to each mat_gbar_j for numerical stability
27 |             (default: 1e-4)
28 |         update_freq: update frequency to compute inverse (default: 1)
29 | 
30 |     Example:
31 |         >>> import torch_optimizer as optim
32 |         >>> optimizer = optim.Shampoo(model.parameters(), lr=0.01)
33 |         >>> optimizer.zero_grad()
34 |         >>> loss_fn(model(input), target).backward()
35 |         >>> optimizer.step()
36 | 
37 |     __ https://arxiv.org/abs/1802.09568
38 | 
39 |     Note:
40 |         Reference code: https://github.com/moskomule/shampoo.pytorch
41 |     """
42 | 
43 |     def __init__(
44 |         self,
45 |         params,
46 |         lr: float = 1e-1,
47 |         momentum: float = 0.0,
48 |         weight_decay: float = 0.0,
49 |         epsilon: float = 1e-4,
50 |         update_freq: int = 1,
51 |     ):
52 |         if lr <= 0.0:
53 |             raise ValueError("Invalid learning rate: {}".format(lr))
54 |         if momentum < 0.0:
55 |             raise ValueError("Invalid momentum value: {}".format(momentum))
56 |         if weight_decay < 0.0:
57 |             raise ValueError(
58 |                 "Invalid weight_decay value: {}".format(weight_decay)
59 |             )
60 |         if epsilon < 0.0:
61 |             raise ValueError("Invalid epsilon value: {}".format(epsilon))
62 |         if update_freq < 1:
63 |             raise ValueError("Invalid update_freq value: {}".format(update_freq))
64 | 
65 |         defaults = dict(
66 |             lr=lr,
67 |             momentum=momentum,
68 |             weight_decay=weight_decay,
69 |             epsilon=epsilon,
70 |             update_freq=update_freq,
71 |         )
72 |         super(Shampoo, self).__init__(params, defaults)
73 | 
74 |     def step(self, closure=None):
75 |         """Performs a single optimization step.
76 | 
77 |         Arguments:
78 |             closure: A closure that reevaluates the model and returns the loss.
79 | """ 80 | loss = None 81 | if closure is not None: 82 | loss = closure() 83 | 84 | for group in self.param_groups: 85 | for p in group["params"]: 86 | if p.grad is None: 87 | continue 88 | grad = p.grad.data 89 | order = grad.ndimension() 90 | original_size = grad.size() 91 | state = self.state[p] 92 | momentum = group["momentum"] 93 | weight_decay = group["weight_decay"] 94 | if len(state) == 0: 95 | state["step"] = 0 96 | if momentum > 0: 97 | state["momentum_buffer"] = grad.clone() 98 | for dim_id, dim in enumerate(grad.size()): 99 | # precondition matrices 100 | state["precond_{}".format(dim_id)] = group[ 101 | "epsilon" 102 | ] * torch.eye(dim, out=grad.new(dim, dim)) 103 | state[ 104 | "inv_precond_{dim_id}".format(dim_id=dim_id) 105 | ] = grad.new(dim, dim).zero_() 106 | 107 | if momentum > 0: 108 | grad.mul_(1 - momentum).add_( 109 | state["momentum_buffer"], alpha=momentum 110 | ) 111 | 112 | if weight_decay > 0: 113 | grad.add_(p.data, alpha=group["weight_decay"]) 114 | 115 | # See Algorithm 2 for detail 116 | for dim_id, dim in enumerate(grad.size()): 117 | precond = state["precond_{}".format(dim_id)] 118 | inv_precond = state["inv_precond_{}".format(dim_id)] 119 | 120 | # mat_{dim_id}(grad) 121 | grad = grad.transpose_(0, dim_id).contiguous() 122 | transposed_size = grad.size() 123 | grad = grad.view(dim, -1) 124 | 125 | grad_t = grad.t() 126 | precond.add_(grad @ grad_t) 127 | if state["step"] % group["update_freq"] == 0: 128 | inv_precond.copy_(_matrix_power(precond, -1 / order)) 129 | 130 | if dim_id == order - 1: 131 | # finally 132 | grad = grad_t @ inv_precond 133 | # grad: (-1, last_dim) 134 | grad = grad.view(original_size) 135 | else: 136 | # if not final 137 | grad = inv_precond @ grad 138 | # grad (dim, -1) 139 | grad = grad.view(transposed_size) 140 | 141 | state["step"] += 1 142 | state["momentum_buffer"] = grad 143 | p.data.add_(grad, alpha=-group["lr"]) 144 | 145 | return loss 146 | -------------------------------------------------------------------------------- /optimizers/sophiaH.py: -------------------------------------------------------------------------------- 1 | # Acknowledgement: This code is based on https://github.com/Liuhong99/Sophia/blob/main/sophia.py 2 | 3 | import torch 4 | import math 5 | 6 | class SophiaH(torch.optim.Optimizer): 7 | def __init__(self, 8 | params, 9 | lr=0.15, 10 | betas=(0.965, 0.99), 11 | eps=1e-15, 12 | weight_decay=1e-1, 13 | lazy_hessian=10, 14 | n_samples=1, 15 | clip_threshold=0.04, 16 | seed=0): 17 | 18 | if not 0.0 <= lr: 19 | raise ValueError(f"Invalid learning rate: {lr}") 20 | if not 0.0 <= eps: 21 | raise ValueError(f"Invalid epsilon value: {eps}") 22 | if not 0.0 <= betas[0] < 1.0: 23 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") 24 | if not 0.0 <= betas[1] < 1.0: 25 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") 26 | if not 0.0 <= clip_threshold: 27 | raise ValueError(f"Invalid threshold parameter: {clip_threshold}") 28 | 29 | self.n_samples = n_samples 30 | self.lazy_hessian = lazy_hessian 31 | self.seed = seed 32 | 33 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training 34 | self.generator = torch.Generator().manual_seed(self.seed) 35 | 36 | defaults = dict(lr=lr, 37 | betas=betas, 38 | eps=eps, 39 | weight_decay=weight_decay, 40 | clip_threshold=clip_threshold) 41 | 42 | super(SophiaH, self).__init__(params, defaults) 43 | 44 | for p in self.get_params(): 45 | p.hess = 0.0 46 | 
self.state[p]["hessian step"] = 0
47 | 
48 |     def get_params(self):
49 |         """
50 |         Gets all parameters in all param_groups with gradients.
51 |         """
52 | 
53 |         return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
54 | 
55 |     def zero_hessian(self):
56 |         """
57 |         Zeros out the accumulated hessian traces.
58 |         """
59 | 
60 |         for p in self.get_params():
61 |             if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.lazy_hessian == 0:
62 |                 p.hess.zero_()
63 | 
64 |     @torch.no_grad()
65 |     def set_hessian(self):
66 |         """
67 |         Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
68 |         """
69 | 
70 |         params = []
71 |         for p in filter(lambda p: p.grad is not None, self.get_params()):
72 |             if self.state[p]["hessian step"] % self.lazy_hessian == 0:  # compute a new Hessian once every `lazy_hessian` steps
73 |                 params.append(p)
74 |             self.state[p]["hessian step"] += 1
75 | 
76 |         if len(params) == 0:
77 |             return
78 | 
79 |         if self.generator.device != params[0].device:  # hackish way of casting the generator to the right device
80 |             self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
81 | 
82 |         grads = [p.grad for p in params]
83 | 
84 |         last_sample = self.n_samples - 1
85 |         for i in range(self.n_samples):
86 |             # Rademacher distribution {-1.0, 1.0}
87 |             zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
88 |             h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < last_sample)
89 |             for h_z, z, p in zip(h_zs, zs, params):
90 |                 p.hess += h_z * z / self.n_samples
91 | 
92 |     @torch.no_grad()
93 |     def step(self, closure=None):
94 |         """
95 |         Performs a single optimization step.
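        The update applied below is p <- p - lr * sign(m) * min(|m| / (clip_threshold * max(h, 0) + eps), 1),
        where m and h are the exponential moving averages of the gradient and of the Hutchinson
        Hessian-diagonal estimate.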
96 | Arguments: 97 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) 98 | """ 99 | 100 | loss = None 101 | if closure is not None: 102 | loss = closure() 103 | 104 | self.zero_hessian() 105 | self.set_hessian() 106 | 107 | for group in self.param_groups: 108 | for p in group['params']: 109 | if p.grad is None or p.hess is None: 110 | continue 111 | 112 | # Perform correct stepweight decay as in AdamW 113 | p.mul_(1 - group['lr'] * group['weight_decay']) 114 | 115 | state = self.state[p] 116 | 117 | # State initialization 118 | if len(state) == 1: 119 | state['step'] = 0 120 | state['exp_avg'] = torch.zeros_like(p.data) 121 | state['exp_hessian_diag'] = torch.zeros_like(p.data) 122 | 123 | exp_avg, exp_hessian_diag = state['exp_avg'], state['exp_hessian_diag'] 124 | beta1, beta2 = group['betas'] 125 | state['step'] += 1 126 | 127 | # Decay the first and second moment running average coefficient 128 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) 129 | 130 | if (state['hessian step']-1) % self.lazy_hessian == 0: 131 | exp_hessian_diag.mul_(beta2).add_(p.hess, alpha=1 - beta2) 132 | 133 | step_size = group['lr'] 134 | step_size_neg = -step_size 135 | 136 | denom = group['clip_threshold'] * exp_hessian_diag.clamp(0, None) + group['eps'] 137 | ratio = (exp_avg.abs() / denom).clamp(None, 1) 138 | 139 | # make update 140 | p.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) 141 | 142 | return loss 143 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.19.0 2 | einops==0.7.0 3 | evaluate==0.4.2 4 | huggingface_hub 5 | numpy==1.26.2 6 | torch==2.1.1 7 | torchvision==0.16.1 8 | tqdm 9 | transformers==4.48.0 10 | wandb --------------------------------------------------------------------------------