├── README.md
├── image_classification
│   ├── LRScheduler.py
│   ├── bypass_bn.py
│   ├── config
│   │   ├── cifar10
│   │   │   ├── resnet20
│   │   │   │   ├── adahessian.sh
│   │   │   │   ├── adamw.sh
│   │   │   │   ├── msassha.sh
│   │   │   │   ├── sam_sgd.sh
│   │   │   │   ├── sassha.sh
│   │   │   │   ├── sgd.sh
│   │   │   │   ├── shampoo.sh
│   │   │   │   └── sophiah.sh
│   │   │   └── resnet32
│   │   │       ├── adahessian.sh
│   │   │       ├── adamw.sh
│   │   │       ├── msassha.sh
│   │   │       ├── sam_sgd.sh
│   │   │       ├── sassha.sh
│   │   │       ├── sgd.sh
│   │   │       ├── shampoo.sh
│   │   │       └── sophiah.sh
│   │   ├── cifar100
│   │   │   ├── resnet32
│   │   │   │   ├── adahessian.sh
│   │   │   │   ├── adamw.sh
│   │   │   │   ├── msassha.sh
│   │   │   │   ├── sam_sgd.sh
│   │   │   │   ├── sassha.sh
│   │   │   │   ├── sgd.sh
│   │   │   │   ├── shampoo.sh
│   │   │   │   └── sophiah.sh
│   │   │   └── wideresnet
│   │   │       ├── adahessian.sh
│   │   │       ├── adamw.sh
│   │   │       ├── msassha.sh
│   │   │       ├── sam_sgd.sh
│   │   │       ├── sassha.sh
│   │   │       ├── sgd.sh
│   │   │       ├── shampoo.sh
│   │   │       └── sophiah.sh
│   │   └── imagenet
│   │       ├── resnet50
│   │       │   ├── adahessian.sh
│   │       │   ├── adamw.sh
│   │       │   ├── msassha.sh
│   │       │   ├── sam_sgd.sh
│   │       │   ├── sassha.sh
│   │       │   ├── sgd.sh
│   │       │   └── sophiah.sh
│   │       └── vit_small
│   │           ├── adahessian.sh
│   │           ├── adamw.sh
│   │           ├── msassha.sh
│   │           ├── sam_adamw.sh
│   │           ├── sassha.sh
│   │           ├── sgd.sh
│   │           └── sophiah.sh
│   ├── models
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   ├── simple_vit.py
│   │   └── wide_resnet.py
│   ├── train.py
│   └── utils.py
├── label_noise
│   ├── bypass_bn.py
│   ├── config
│   │   ├── cifar10
│   │   │   ├── adahessian
│   │   │   │   ├── noise20.sh
│   │   │   │   ├── noise40.sh
│   │   │   │   └── noise60.sh
│   │   │   ├── msassha
│   │   │   │   ├── noise20.sh
│   │   │   │   ├── noise40.sh
│   │   │   │   └── noise60.sh
│   │   │   ├── sam_sgd
│   │   │   │   ├── noise20.sh
│   │   │   │   ├── noise40.sh
│   │   │   │   └── noise60.sh
│   │   │   ├── sassha
│   │   │   │   ├── noise20.sh
│   │   │   │   ├── noise40.sh
│   │   │   │   └── noise60.sh
│   │   │   ├── sgd
│   │   │   │   ├── noise20.sh
│   │   │   │   ├── noise40.sh
│   │   │   │   └── noise60.sh
│   │   │   ├── shampoo
│   │   │   │   ├── noise20.sh
│   │   │   │   ├── noise40.sh
│   │   │   │   └── noise60.sh
│   │   │   └── sophiah
│   │   │       ├── noise20.sh
│   │   │       ├── noise40.sh
│   │   │       └── noise60.sh
│   │   └── cifar100
│   │       ├── adahessian
│   │       │   ├── noise20.sh
│   │       │   ├── noise40.sh
│   │       │   └── noise60.sh
│   │       ├── msassha
│   │       │   ├── noise20.sh
│   │       │   ├── noise40.sh
│   │       │   └── noise60.sh
│   │       ├── sam_sgd
│   │       │   ├── noise20.sh
│   │       │   ├── noise40.sh
│   │       │   └── noise60.sh
│   │       ├── sassha
│   │       │   ├── noise20.sh
│   │       │   ├── noise40.sh
│   │       │   └── noise60.sh
│   │       ├── sgd
│   │       │   ├── noise20.sh
│   │       │   ├── noise40.sh
│   │       │   └── noise60.sh
│   │       ├── shampoo
│   │       │   ├── noise20.sh
│   │       │   ├── noise40.sh
│   │       │   └── noise60.sh
│   │       └── sophiah
│   │           ├── noise20.sh
│   │           ├── noise40.sh
│   │           └── noise60.sh
│   ├── models
│   │   └── resnet.py
│   ├── train.py
│   └── utils.py
├── language_tasks
│   ├── README.md
│   ├── config
│   │   ├── adahessian
│   │   │   ├── mnli.sh
│   │   │   ├── mrpc.sh
│   │   │   ├── qnli.sh
│   │   │   ├── qqp.sh
│   │   │   ├── rte.sh
│   │   │   ├── sst2.sh
│   │   │   └── stsb.sh
│   │   ├── adamw
│   │   │   ├── mnli.sh
│   │   │   ├── mrpc.sh
│   │   │   ├── qnli.sh
│   │   │   ├── qqp.sh
│   │   │   ├── rte.sh
│   │   │   ├── sst2.sh
│   │   │   └── stsb.sh
│   │   ├── msassha
│   │   │   ├── mnli.sh
│   │   │   ├── mrpc.sh
│   │   │   ├── qnli.sh
│   │   │   ├── qqp.sh
│   │   │   ├── rte.sh
│   │   │   ├── sst2.sh
│   │   │   └── stsb.sh
│   │   ├── sam_adamw
│   │   │   ├── mnli.sh
│   │   │   ├── mrpc.sh
│   │   │   ├── qnli.sh
│   │   │   ├── qqp.sh
│   │   │   ├── rte.sh
│   │   │   ├── sst2.sh
│   │   │   └── stsb.sh
│   │   ├── sassha
│   │   │   ├── mnli.sh
│   │   │   ├── mrpc.sh
│   │   │   ├── qnli.sh
│   │   │   ├── qqp.sh
│   │   │   ├── rte.sh
│   │   │   ├── sst2.sh
│   │   │   └── stsb.sh
│   │   ├── sophiah
│   │   │   ├── mnli.sh
│   │   │   ├── mrpc.sh
│   │   │   ├── qnli.sh
│   │   │   ├── qqp.sh
│   │   │   ├── rte.sh
│   │   │   ├── sst2.sh
│   │   │   └── stsb.sh
│   │   └── wandb
│   │       └── rte_sophia.sh
│   └── finetune.py
├── optimizers
│   ├── __init__.py
│   ├── adahessian.py
│   ├── hessian_scheduler.py
│   ├── msassha.py
│   ├── sam.py
│   ├── sassha.py
│   ├── shampoo.py
│   └── sophiaH.py
└── requirements.txt
/README.md:
--------------------------------------------------------------------------------
1 | # SASSHA: Sharpness-aware Adaptive Second-order Optimization with Stable Hessian Approximation
2 |
3 | This repository contains the PyTorch source code for the arXiv paper [SASSHA: Sharpness-aware Adaptive Second-order Optimization With Stable Hessian Approximation](https://arxiv.org/abs/2502.18153) by Dahun Shin*, [Dongyeop Lee](https://edong6768.github.io/)*, Jinseok Chung, and [Namhoon Lee](https://namhoonlee.github.io/).
4 |
5 | ## Introduction
6 |
7 | SASSHA is a novel second-order method designed to enhance generalization by explicitly reducing the sharpness of the solution while stabilizing the computation of approximate Hessians along the optimization trajectory.
8 |
9 | This PyTorch implementation supports a variety of tasks, including image classification, language fine-tuning, and label-noise experiments.
10 |
11 | For a detailed explanation of the SASSHA algorithm, please refer to [our paper](https://arxiv.org/pdf/2502.18153).
12 |
13 |
14 | ## Getting Started
15 |
16 | First, clone our repository to your local system:
17 | ```bash
18 | git clone https://github.com/LOG-postech/Sassha.git
19 | cd Sassha
20 | ```
21 |
22 | We recommend using Anaconda to set up the environment and install all necessary dependencies:
23 |
24 | ```bash
25 | conda create -n "sassha" python=3.9
26 | conda activate sassha
27 | pip install -r requirements.txt
28 | ```
29 |
30 | Ensure you are using Python 3.9 or later.
31 |
32 | Navigate to the example folder of your choice. For instance, to run an image classification experiment:
33 |
34 | ```bash
35 | cd image_classification
36 | ```
37 |
38 | Now, train the model with the following command:
39 | ```bash
40 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \
41 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \
42 | --optimizer sassha \
43 | --lr 0.3 --wd 1e-4 --rho 0.2 --lazy_hessian 10 --seed 0 \
44 | --project_name sassha \
45 | {enter/your/imagenet-folder/with/train_and_val_data}
46 | ```
47 |
48 | Here, replace `{enter/your/imagenet-folder/with/train_and_val_data}` with the path to your ImageNet directory, which should contain `train` and `val` subdirectories.
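
For reference, here is a minimal sketch of how that path is consumed: the data loaders are built by `getData` in `image_classification/utils.py`, which reads the two splits with `torchvision.datasets.ImageFolder` (the `/data/imagenet` path below is only a placeholder):

```python
from utils import getData  # image_classification/utils.py

# builds ImageFolder datasets from <path>/train and <path>/val
train_loader, test_loader, train_sampler, test_sampler = getData(
    name='imagenet',
    path='/data/imagenet',  # placeholder: your ImageNet root
    train_bs=256,
    test_bs=1000,
    num_workers=4,
)
```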
49 |
50 | ### Distributed Training (Single node, multiple GPUs)
51 | SASSHA is fully compatible with multi-GPU environments for distributed training. Use the following command to train a model across multiple GPUs on a single node:
52 | ```bash
53 | python train.py --dist-url 'tcp://127.0.0.1:23456' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 \
54 | --workers 4 --dataset imagenet -a vit_b_32 --epochs 90 -b 1024 \
55 | --LRScheduler cosine --warmup_epochs 8 \
56 | --optimizer sassha \
57 | --lr 0.6 --wd 2e-4 --rho 0.25 --lazy_hessian 10 --eps 1e-6 --seed 0 \
58 | --project_name sassha \
59 | {enter/your/imagenet-folder/with/train_and_val_data}
60 | ```
61 | Ensure that NCCL is properly configured on your system and that your GPUs are available before running the script.
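
If you need to pin the job to specific GPUs, the standard `CUDA_VISIBLE_DEVICES` environment variable can be set before launching the command above (the device indices here are only an example):

```bash
# expose only the first four GPUs of the node to the training script
export CUDA_VISIBLE_DEVICES=0,1,2,3
```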
62 |
63 | ### Reproducing Paper Results
64 | Configurations used in [our paper](https://arxiv.org/pdf/2502.18153) are provided as shell scripts in each example folder.
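
For example, to reproduce the CIFAR-10 / ResNet-20 SASSHA run, execute the corresponding script from inside its example folder (each script simply invokes `train.py` with the hyperparameters used in the paper):

```bash
cd image_classification
bash config/cifar10/resnet20/sassha.sh
```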
65 |
66 | ### Environments
67 | - cuda 11.6.2
68 | - python 3.9
69 |
70 | ## General Usage
71 |
72 | SASSHA can be imported and used as follows:
73 |
74 | ```python
75 | from optimizers import SASSHA
76 |
77 | ...
78 |
79 | # Initialize your model and optimizer
80 | model = YourModel()
81 | optimizer = SASSHA(model.parameters(), ...)
82 |
83 | ...
84 |
85 | # training loop
86 | for input, output in data:
87 |
88 | # first forward-backward pass
89 | loss = loss_function(output, model(input))
90 | loss.backward()
91 | optimizer.perturb_weights(zero_grad=True)
92 |
93 | # second forward-backward pass
94 | loss_function(output, model(input)).backward(create_graph=True)
95 | optimizer.unperturb()
96 | optimizer.step()
97 | optimizer.zero_grad()
98 |
99 | ...
100 | ```
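
If your model contains BatchNorm layers, the example folders also ship a small helper, `bypass_bn.py` (e.g. `image_classification/bypass_bn.py`), for controlling the running statistics. A minimal sketch of how it slots into the loop above, assuming the usual SAM-style convention of updating BatchNorm statistics only on the first pass:

```python
from bypass_bn import enable_running_stats, disable_running_stats

for input, output in data:

    # first forward-backward pass: BatchNorm updates its running statistics
    enable_running_stats(model)
    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.perturb_weights(zero_grad=True)

    # second forward-backward pass: running statistics stay frozen at the perturbed point;
    # create_graph=True keeps the graph so second-order (Hessian) information can be computed
    disable_running_stats(model)
    loss_function(output, model(input)).backward(create_graph=True)
    optimizer.unperturb()
    optimizer.step()
    optimizer.zero_grad()
```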
101 |
102 | ## Citation
103 | ```bibtex
104 | @article{shin2025sassha,
105 | title={SASSHA: Sharpness-aware Adaptive Second-order Optimization With Stable Hessian Approximation},
106 | author={Shin, Dahun and Lee, Dongyeop and Chung, Jinseok and Lee, Namhoon},
107 | journal={arXiv preprint arXiv:2502.18153},
108 | year={2025}
109 | }
110 | ```
--------------------------------------------------------------------------------
/image_classification/LRScheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 | class CosineAnnealingWarmupRestarts(_LRScheduler):
6 | """
7 | optimizer (Optimizer): Wrapped optimizer.
8 | first_cycle_steps (int): First cycle step size.
9 |     cycle_mult(float): Cycle steps magnification. Default: 1.
10 | max_lr(float): First cycle's max learning rate. Default: 0.1.
11 | min_lr(float): Min learning rate. Default: 0.001.
12 | warmup_steps(int): Linear warmup step size. Default: 0.
13 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
14 | last_epoch (int): The index of last epoch. Default: -1.
15 | """
16 |
17 | def __init__(self,
18 | optimizer : torch.optim.Optimizer,
19 | first_cycle_steps : int,
20 | cycle_mult : float = 1.,
21 | max_lr : float = 0.1,
22 | min_lr : float = 0.001,
23 | warmup_steps : int = 0,
24 | gamma : float = 1.,
25 | last_epoch : int = -1
26 | ):
27 | assert warmup_steps < first_cycle_steps
28 |
29 | self.first_cycle_steps = first_cycle_steps # first cycle step size
30 | self.cycle_mult = cycle_mult # cycle steps magnification
31 | self.base_max_lr = max_lr # first max learning rate
32 | self.max_lr = max_lr # max learning rate in the current cycle
33 | self.min_lr = min_lr # min learning rate
34 | self.warmup_steps = warmup_steps # warmup step size
35 | self.gamma = gamma # decrease rate of max learning rate by cycle
36 |
37 | self.cur_cycle_steps = first_cycle_steps # first cycle step size
38 | self.cycle = 0 # cycle count
39 |         self.step_in_cycle = last_epoch # current step index within the cycle
40 |
41 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
42 |
43 | # set learning rate min_lr
44 | self.init_lr()
45 |
46 | def init_lr(self):
47 | self.base_lrs = []
48 | for param_group in self.optimizer.param_groups:
49 | param_group['lr'] = self.min_lr
50 | self.base_lrs.append(self.min_lr)
51 |
52 | def get_lr(self):
53 | if self.step_in_cycle == -1:
54 | return self.base_lrs
55 | elif self.step_in_cycle < self.warmup_steps:
56 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
57 | else:
58 | return [base_lr + (self.max_lr - base_lr) \
59 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
60 | / (self.cur_cycle_steps - self.warmup_steps))) / 2
61 | for base_lr in self.base_lrs]
62 |
63 | def step(self, epoch=None):
64 | if epoch is None:
65 | epoch = self.last_epoch + 1
66 | self.step_in_cycle = self.step_in_cycle + 1
67 | if self.step_in_cycle >= self.cur_cycle_steps:
68 | self.cycle += 1
69 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
70 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
71 | else:
72 | if epoch >= self.first_cycle_steps:
73 | if self.cycle_mult == 1.:
74 | self.step_in_cycle = epoch % self.first_cycle_steps
75 | self.cycle = epoch // self.first_cycle_steps
76 | else:
77 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
78 | self.cycle = n
79 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
80 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
81 | else:
82 | self.cur_cycle_steps = self.first_cycle_steps
83 | self.step_in_cycle = epoch
84 |
85 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
86 | self.last_epoch = math.floor(epoch)
87 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
88 | param_group['lr'] = lr
--------------------------------------------------------------------------------
/image_classification/bypass_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.modules.batchnorm import _BatchNorm
4 |
5 |
6 | def disable_running_stats(model):
7 | def _disable(module):
8 | if isinstance(module, _BatchNorm):
9 | module.backup_momentum = module.momentum
10 | module.momentum = 0
11 |
12 | model.apply(_disable)
13 |
14 | def enable_running_stats(model):
15 | def _enable(module):
16 | if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
17 | module.momentum = module.backup_momentum
18 |
19 | model.apply(_enable)
20 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet20/adahessian.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer adahessian \
4 | --lr 0.15 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet20/adamw.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer adamw \
4 | --lr 0.01 --wd 5e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet20/msassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer msassha \
4 | --lr 0.3 --wd 1e-3 --rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet20/sam_sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer samsgd \
4 | --lr 0.1 --wd 5e-4 --rho 0.15 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet20/sassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer sassha \
4 | --hessian_power_scheduler constant \
5 | --lr 0.15 --min_lr 0.0015 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --eps 1e-6 --seed 0, 1, 2 \
6 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet20/sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer sgd \
4 | --lr 0.1 --wd 5e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet20/shampoo.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer shampoo \
4 | --lr 0.8 --wd 5e-3 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet20/sophiah.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet20 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer sophiah \
4 | --lr 1e-3 --wd 1e-3 --lazy_hessian 1 --clip_threshold 0.01 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet32/adahessian.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer adahessian \
4 | --lr 0.15 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet32/adamw.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer adamw \
4 | --lr 0.01 --wd 5e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet32/msassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer msassha \
4 | --lr 0.15 --wd 1e-3 --rho 0.6 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet32/sam_sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer samsgd \
4 | --lr 0.1 --wd 5e-4 --rho 0.15 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet32/sassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer sassha \
4 | --hessian_power_scheduler constant \
5 | --lr 0.15 --min_lr 0.0015 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2
6 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet32/sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer sgd \
4 | --lr 0.1 --wd 5e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet32/shampoo.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer shampoo \
4 | --lr 0.4 --wd 1e-2 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar10/resnet32/sophiah.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar10 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer sophiah \
4 | --lr 1e-3 --wd 2e-4 --lazy_hessian 1 --clip_threshold 0.1 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/resnet32/adahessian.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer adahessian \
4 | --lr 0.15 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/resnet32/adamw.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer adamw \
4 | --lr 0.01 --wd 5e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/resnet32/msassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer msassha \
4 | --lr 0.3 --wd 1e-3 --rho 0.6 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/resnet32/sam_sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer samsgd \
4 | --lr 0.1 --wd 5e-4 --rho 0.2 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/resnet32/sassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer sassha \
4 | --hessian_power_scheduler constant \
5 | --lr 0.3 --min_lr 0.003 --wd 1e-3 --rho 0.25 --lazy_hessian 10 --eps 1e-6 --seed 0, 1, 2 \
6 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/resnet32/sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer sgd \
4 | --lr 0.1 --wd 5e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/resnet32/shampoo.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer shampoo \
4 | --lr 1 --wd 5e-3 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/resnet32/sophiah.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a resnet32 --epochs 160 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 80 120 --lr-decay 0.1 \
3 | --optimizer sophiah \
4 | --lr 1e-3 --wd 2e-4 --lazy_hessian 1 --clip_threshold 0.05 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/wideresnet/adahessian.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \
3 | --optimizer adahessian \
4 | --lr 0.3 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/wideresnet/adamw.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \
3 | --optimizer adamw \
4 | --lr 1e-3 --wd 5e-4 --seed 0, 1, 2
--------------------------------------------------------------------------------
/image_classification/config/cifar100/wideresnet/msassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \
3 | --optimizer msassha \
4 | --lr 0.15 --wd 1e-3 --rho 0.25 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/wideresnet/sam_sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \
3 | --optimizer samsgd \
4 | --lr 0.1 --wd 1e-3 --rho 0.15 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/wideresnet/sassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \
3 | --optimizer sassha \
4 | --hessian_power_scheduler constant \
5 | --lr 0.15 --min_lr 0.0015 --wd 0.0015 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2
--------------------------------------------------------------------------------
/image_classification/config/cifar100/wideresnet/sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \
3 | --optimizer sgd \
4 | --lr 0.1 --wd 5e-4 --seed 0, 1, 2
--------------------------------------------------------------------------------
/image_classification/config/cifar100/wideresnet/shampoo.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \
3 | --optimizer shampoo \
4 | --lr 1 --wd 5e-3 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/cifar100/wideresnet/sophiah.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 2 --dataset cifar100 -a wideresnet_28_10 --epochs 200 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 60 120 160 --lr-decay 0.2 \
3 | --optimizer sophiah \
4 | --lr 1e-3 --wd 1e-3 --lazy_hessian 1 --clip_threshold 0.01 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/image_classification/config/imagenet/resnet50/adahessian.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \
2 | --LRScheduler plateau \
3 | --optimizer adahessian \
4 | --lr 0.15 --wd 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/resnet50/adamw.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \
3 | --optimizer adamw \
4 | --lr 1e-3 --wd 1e-4 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/resnet50/msassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \
3 | --optimizer msassha \
4 | --lr 0.15 --wd 1e-4 --rho 0.1 --lazy_hessian 10 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/resnet50/sam_sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \
3 | --optimizer samsgd \
4 | --lr 0.1 --wd 1e-4 --rho 0.1 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/resnet50/sassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \
3 | --optimizer sassha \
4 | --hessian_power_scheduler constant \
5 | --lr 0.3 --min_lr 0.003 --wd 1e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \
6 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/resnet50/sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \
3 | --optimizer sgd \
4 | --lr 0.1 --wd 1e-4 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/resnet50/sophiah.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \
2 | --LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \
3 | --optimizer sophiah \
4 | --lr 1e-2 --wd 1e-4 --lazy_hessian 1 --clip_threshold 0.1 --eps 1e-4 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/vit_small/adahessian.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \
2 | --LRScheduler cosine --warmup_epochs 8 \
3 | --optimizer adahessian \
4 | --lr 0.15 --wd 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/vit_small/adamw.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \
2 | --LRScheduler cosine --warmup_epochs 8 \
3 | --optimizer adamw \
4 | --lr 1e-3 --wd 1e-4 --grad_clip 1 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/vit_small/msassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \
2 | --LRScheduler cosine --warmup_epochs 8 \
3 | --optimizer msassha \
4 | --lr 0.6 --wd 4e-4 --rho 0.4 --lazy_hessian 10 --eps 1e-6 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/vit_small/sam_adamw.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \
2 | --LRScheduler cosine --warmup_epochs 8 \
3 | --optimizer samadamw \
4 | --lr 1e-3 --wd 1e-4 --rho 0.2 --grad_clip 1 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/vit_small/sassha.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \
2 | --LRScheduler cosine --warmup_epochs 8 \
3 | --optimizer sassha \
4 | --hessian_power_scheduler constant \
5 | --lr 0.6 --min_lr 0.0011981 --wd 2e-4 --rho 0.25 --lazy_hessian 10 --eps 1e-6 --seed 0, 1, 2 \
6 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/vit_small/sgd.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \
2 | --LRScheduler cosine --warmup_epochs 8 \
3 | --optimizer sgd \
4 | --lr 0.1 --wd 1e-4 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/config/imagenet/vit_small/sophiah.sh:
--------------------------------------------------------------------------------
1 | python train.py --workers 4 --dataset imagenet -a vit_s_32 --epochs 90 -b 1024 \
2 | --LRScheduler cosine --warmup_epochs 8 \
3 | --optimizer sophiah \
4 | --lr 1e-3 --wd 1e-4 --lazy_hessian 1 --clip_threshold 0.01 --eps 1e-4 --seed 0, 1, 2 \
5 | /home/shared/dataset/imagenet
--------------------------------------------------------------------------------
/image_classification/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.simple_vit import SimpleViT
2 | from models.resnet import *
3 | from models.wide_resnet import *
4 |
5 | import re
6 | from torchvision import models
7 |
8 | def create_vit_model(arch):
9 | """Create a Vision Transformer (ViT) model."""
10 | patch_size = int(arch.split('_')[2])
11 | if arch.startswith('vit_s_'):
12 | return SimpleViT(
13 | image_size=224,
14 | patch_size=patch_size,
15 | num_classes=1000,
16 | dim=384,
17 | depth=12,
18 | heads=6,
19 | mlp_dim=1536
20 | )
21 | elif arch.startswith('vit_b_'):
22 | return SimpleViT(
23 | image_size=224,
24 | patch_size=patch_size,
25 | num_classes=1000,
26 | dim=768,
27 | depth=12,
28 | heads=12,
29 | mlp_dim=3072
30 | )
31 | elif arch.startswith('vit_l_'):
32 | return SimpleViT(
33 | image_size=224,
34 | patch_size=patch_size,
35 | num_classes=1000,
36 | dim=1024,
37 | depth=24,
38 | heads=16,
39 | mlp_dim=4096
40 | )
41 | else:
42 | raise ValueError(f"Unknown ViT architecture: {arch}")
43 |
44 | def create_resnet_model(arch, dataset):
45 | """Create a ResNet model."""
46 | num_classes = 1000 if dataset == 'imagenet' else int(re.findall(r'\d+', dataset)[-1])
47 | depth = int(re.findall(r'\d+', arch)[-1])
48 | return resnet(num_classes=num_classes, depth=depth)
49 |
50 | def create_wideresnet_model(arch, dataset):
51 | """Create a Wide ResNet model."""
52 | num_classes = 1000 if dataset == 'imagenet' else int(re.findall(r'\d+', dataset)[-1])
53 | parts = arch.split('_')
54 | depth = int(parts[-2])
55 | widen_factor = int(parts[-1])
56 | return Wide_ResNet(depth=depth, widen_factor=widen_factor, dropout_rate=0.0, num_classes=num_classes)
57 |
58 | def get_model(args):
59 | """Main function to create a model based on the architecture."""
60 |
61 | if args.arch.startswith('vit_'):
62 | return create_vit_model(args.arch)
63 | elif args.arch in ['resnet20', 'resnet32']:
64 | return create_resnet_model(args.arch, args.dataset)
65 | elif args.arch.startswith('wideresnet_'):
66 | return create_wideresnet_model(args.arch, args.dataset)
67 | else:
68 | # Default case: PyTorch model loaded directly
69 | return models.__dict__[args.arch]()
70 |
--------------------------------------------------------------------------------
/image_classification/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Resnet for cifar dataset.
4 | Ported form
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import math
12 | from copy import deepcopy
13 |
14 | __all__ = ['resnet']
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | "3x3 convolution with padding"
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=1, bias=False)
21 |
22 |
23 | class BasicBlock(nn.Module):
24 | expansion = 1
25 |
26 | def __init__(
27 | self,
28 | inplanes,
29 | planes,
30 | residual_not,
31 | batch_norm_not,
32 | stride=1,
33 | downsample=None):
34 | super(BasicBlock, self).__init__()
35 | self.residual_not = residual_not
36 | self.batch_norm_not = batch_norm_not
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | if self.batch_norm_not:
39 | self.bn1 = nn.BatchNorm2d(planes)
40 | self.relu = nn.ReLU(inplace=True)
41 | self.conv2 = conv3x3(planes, planes)
42 | if self.batch_norm_not:
43 | self.bn2 = nn.BatchNorm2d(planes)
44 | self.downsample = downsample
45 | self.stride = stride
46 |
47 | def forward(self, x):
48 | residual = x
49 |
50 | out = self.conv1(x)
51 | if self.batch_norm_not:
52 | out = self.bn1(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv2(out)
56 | if self.batch_norm_not:
57 | out = self.bn2(out)
58 |
59 | if self.downsample is not None:
60 | residual = self.downsample(x)
61 | if self.residual_not:
62 | out += residual
63 | out = self.relu(out)
64 |
65 | return out
66 |
67 |
68 | class Bottleneck(nn.Module):
69 | expansion = 4
70 |
71 | def __init__(
72 | self,
73 | inplanes,
74 | planes,
75 | residual_not,
76 | batch_norm_not,
77 | stride=1,
78 | downsample=None):
79 | super(Bottleneck, self).__init__()
80 | self.residual_not = residual_not
81 | self.batch_norm_not = batch_norm_not
82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
83 | if self.batch_norm_not:
84 | self.bn1 = nn.BatchNorm2d(planes)
85 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
86 | padding=1, bias=False)
87 | if self.batch_norm_not:
88 | self.bn2 = nn.BatchNorm2d(planes)
89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
90 | if self.batch_norm_not:
91 | self.bn3 = nn.BatchNorm2d(planes * 4)
92 | self.relu = nn.ReLU(inplace=True)
93 | self.downsample = downsample
94 | self.stride = stride
95 |
96 | def forward(self, x):
97 | residual = x
98 |
99 | out = self.conv1(x)
100 | if self.batch_norm_not:
101 | out = self.bn1(out)
102 | out = self.relu(out)
103 |
104 | out = self.conv2(out)
105 | if self.batch_norm_not:
106 | out = self.bn2(out)
107 | out = self.relu(out)
108 |
109 | out = self.conv3(out)
110 | if self.batch_norm_not:
111 | out = self.bn3(out)
112 |
113 | if self.downsample is not None:
114 | residual = self.downsample(x)
115 | if self.residual_not:
116 | out += residual
117 |
118 | out = self.relu(out)
119 |
120 | return out
121 |
122 |
123 | ALPHA_ = 1
124 |
125 |
126 | class ResNet(nn.Module):
127 |
128 | def __init__(
129 | self,
130 | depth,
131 | residual_not=True,
132 | batch_norm_not=True,
133 | base_channel=16,
134 | num_classes=10):
135 | super(ResNet, self).__init__()
136 | # Model type specifies number of layers for CIFAR-10 model
137 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
138 | n = (depth - 2) // 6
139 |
140 | # block = Bottleneck if depth >=44 else BasicBlock
141 | block = BasicBlock
142 |
143 | self.base_channel = int(base_channel)
144 | self.residual_not = residual_not
145 | self.batch_norm_not = batch_norm_not
146 | self.inplanes = self.base_channel * ALPHA_
147 | self.conv1 = nn.Conv2d(
148 | 3,
149 | self.base_channel *
150 | ALPHA_,
151 | kernel_size=3,
152 | padding=1,
153 | bias=False)
154 | if self.batch_norm_not:
155 | self.bn1 = nn.BatchNorm2d(self.base_channel * ALPHA_)
156 | self.relu = nn.ReLU(inplace=True)
157 | self.layer1 = self._make_layer(
158 | block,
159 | self.base_channel *
160 | ALPHA_,
161 | n,
162 | self.residual_not,
163 | self.batch_norm_not)
164 | self.layer2 = self._make_layer(
165 | block,
166 | self.base_channel *
167 | 2 *
168 | ALPHA_,
169 | n,
170 | self.residual_not,
171 | self.batch_norm_not,
172 | stride=2)
173 | self.layer3 = self._make_layer(
174 | block,
175 | self.base_channel *
176 | 4 *
177 | ALPHA_,
178 | n,
179 | self.residual_not,
180 | self.batch_norm_not,
181 | stride=2)
182 | self.avgpool = nn.AvgPool2d(8)
183 | self.fc = nn.Linear(
184 | self.base_channel *
185 | 4 *
186 | ALPHA_ *
187 | block.expansion,
188 | num_classes)
189 |
190 | for m in self.modules():
191 | if isinstance(m, nn.Conv2d):
192 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
193 | m.weight.data.normal_(0, math.sqrt(2. / n))
194 | elif isinstance(m, nn.BatchNorm2d):
195 | m.weight.data.fill_(1)
196 | m.bias.data.zero_()
197 |
198 | def _make_layer(
199 | self,
200 | block,
201 | planes,
202 | blocks,
203 | residual_not,
204 | batch_norm_not,
205 | stride=1):
206 | downsample = None
207 | if (stride != 1 or self.inplanes != planes *
208 | block.expansion) and (residual_not):
209 | if batch_norm_not:
210 | downsample = nn.Sequential(
211 | nn.Conv2d(self.inplanes, planes * block.expansion,
212 | kernel_size=1, stride=stride, bias=False),
213 | nn.BatchNorm2d(planes * block.expansion),
214 | )
215 | else:
216 | downsample = nn.Sequential(
217 | nn.Conv2d(self.inplanes, planes * block.expansion,
218 | kernel_size=1, stride=stride, bias=False),
219 | )
220 |
221 | layers = nn.ModuleList()
222 | layers.append(
223 | block(
224 | self.inplanes,
225 | planes,
226 | residual_not,
227 | batch_norm_not,
228 | stride,
229 | downsample))
230 | self.inplanes = planes * block.expansion
231 | for i in range(1, blocks):
232 | layers.append(
233 | block(
234 | self.inplanes,
235 | planes,
236 | residual_not,
237 | batch_norm_not))
238 |
239 | # return nn.Sequential(*layers)
240 | return layers
241 |
242 | def forward(self, x):
243 | output_list = []
244 | x = self.conv1(x)
245 | if self.batch_norm_not:
246 | x = self.bn1(x)
247 | x = self.relu(x) # 32x32
248 | output_list.append(x.view(x.size(0), -1))
249 |
250 | for layer in self.layer1:
251 | x = layer(x) # 32x32
252 | output_list.append(x.view(x.size(0), -1))
253 | for layer in self.layer2:
254 | x = layer(x) # 16x16
255 | output_list.append(x.view(x.size(0), -1))
256 | for layer in self.layer3:
257 | x = layer(x) # 8x8
258 | output_list.append(x.view(x.size(0), -1))
259 |
260 | x = self.avgpool(x)
261 | x = x.view(x.size(0), -1)
262 | x = self.fc(x)
263 | output_list.append(x.view(x.size(0), -1))
264 |
265 | # return output_list, x
266 | return x
267 |
268 |
269 | def resnet(**kwargs):
270 | """
271 | Constructs a ResNet model.
272 | """
273 | return ResNet(**kwargs)
274 |
--------------------------------------------------------------------------------
/image_classification/models/simple_vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from einops import rearrange
5 | from einops.layers.torch import Rearrange
6 |
7 | # helpers
8 |
9 | def pair(t):
10 | return t if isinstance(t, tuple) else (t, t)
11 |
12 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): # 7, 7, 384
13 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
14 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
15 | omega = torch.arange(dim // 4) / (dim // 4 - 1)
16 | omega = 1.0 / (temperature ** omega)
17 |
18 | y = y.flatten()[:, None] * omega[None, :]
19 | x = x.flatten()[:, None] * omega[None, :]
20 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
21 | return pe.type(dtype)
22 |
23 | # classes
24 |
25 | class FeedForward(nn.Module):
26 | def __init__(self, dim, hidden_dim):
27 | super().__init__()
28 | self.net = nn.Sequential(
29 | nn.LayerNorm(dim),
30 | nn.Linear(dim, hidden_dim),
31 | nn.GELU(),
32 | nn.Linear(hidden_dim, dim),
33 | )
34 | def forward(self, x):
35 | return self.net(x)
36 |
37 | class Attention(nn.Module):
38 | def __init__(self, dim, heads = 8, dim_head = 64):
39 | super().__init__()
40 | inner_dim = dim_head * heads
41 | self.heads = heads
42 | self.scale = dim_head ** -0.5
43 | self.norm = nn.LayerNorm(dim)
44 |
45 | self.attend = nn.Softmax(dim = -1)
46 |
47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
48 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
49 |
50 | def forward(self, x):
51 | x = self.norm(x)
52 |
53 | qkv = self.to_qkv(x).chunk(3, dim = -1)
54 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
55 |
56 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
57 |
58 | attn = self.attend(dots)
59 |
60 | out = torch.matmul(attn, v)
61 | out = rearrange(out, 'b h n d -> b n (h d)')
62 | return self.to_out(out)
63 |
64 | class Transformer(nn.Module):
65 | def __init__(self, dim, depth, heads, dim_head, mlp_dim):
66 | super().__init__()
67 |         self.norm = nn.LayerNorm(dim) # note: MSAM pools then normalizes (mean -> norm -> linear); here we normalize after the blocks, then pool (norm -> mean -> linear)
68 | self.layers = nn.ModuleList([])
69 | for _ in range(depth):
70 | self.layers.append(nn.ModuleList([
71 | Attention(dim, heads = heads, dim_head = dim_head),
72 | FeedForward(dim, mlp_dim)
73 | ]))
74 | def forward(self, x):
75 | for attn, ff in self.layers:
76 | x = attn(x) + x
77 | x = ff(x) + x
78 | return self.norm(x)
79 |
80 | class SimpleViT(nn.Module):
81 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
82 | super().__init__()
83 | image_height, image_width = pair(image_size)
84 | patch_height, patch_width = pair(patch_size)
85 |
86 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
87 |
88 | patch_dim = channels * patch_height * patch_width # patch_dim = 3 * 32 * 32 = 3072
89 |
90 | self.to_patch_embedding = nn.Sequential(
91 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), # b 3 (7 32) (7 32) -> b 49 3072
92 | nn.LayerNorm(patch_dim),
93 | nn.Linear(patch_dim, dim), # dim = 384 (b 49 384)
94 | nn.LayerNorm(dim),
95 | )
96 |
97 | self.pos_embedding = posemb_sincos_2d(
98 | h = image_height // patch_height, # 224 // 32 = 7
99 | w = image_width // patch_width, # 224 // 32 = 7
100 | dim = dim, # 384
101 | )
102 |
103 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
104 |
105 | self.pool = "mean"
106 | self.to_latent = nn.Identity()
107 |
108 | self.linear_head = nn.Linear(dim, num_classes)
109 |
110 | def forward(self, img):
111 | device = img.device
112 |
113 | x = self.to_patch_embedding(img) # (b 49 384)
114 | x += self.pos_embedding.to(device, dtype=x.dtype)
115 |
116 | x = self.transformer(x)
117 | x = x.mean(dim = 1)
118 |
119 | x = self.to_latent(x)
120 | return self.linear_head(x)
121 |
--------------------------------------------------------------------------------
/image_classification/models/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | import sys
8 | import numpy as np
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
12 |
13 | def conv_init(m):
14 | classname = m.__class__.__name__
15 | if classname.find('Conv') != -1:
16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2))
17 | init.constant_(m.bias, 0)
18 | elif classname.find('BatchNorm') != -1:
19 | init.constant_(m.weight, 1)
20 | init.constant_(m.bias, 0)
21 |
22 | class wide_basic(nn.Module):
23 | def __init__(self, in_planes, planes, dropout_rate, stride=1):
24 | super(wide_basic, self).__init__()
25 | self.bn1 = nn.BatchNorm2d(in_planes)
26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
27 | self.dropout = nn.Dropout(p=dropout_rate)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
30 |
31 | self.shortcut = nn.Sequential()
32 | if stride != 1 or in_planes != planes:
33 | self.shortcut = nn.Sequential(
34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
35 | )
36 |
37 | def forward(self, x):
38 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
39 | out = self.conv2(F.relu(self.bn2(out)))
40 | out += self.shortcut(x)
41 |
42 | return out
43 |
44 | class Wide_ResNet(nn.Module):
45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes):
46 | super(Wide_ResNet, self).__init__()
47 | self.in_planes = 16
48 |
49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'
50 | n = (depth-4)/6
51 | k = widen_factor
52 |
53 | print('| Wide-Resnet %dx%d' %(depth, k))
54 | nStages = [16, 16*k, 32*k, 64*k]
55 |
56 | self.conv1 = conv3x3(3,nStages[0])
57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
61 | self.linear = nn.Linear(nStages[3], num_classes)
62 |
63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
64 | strides = [stride] + [1]*(int(num_blocks)-1)
65 | layers = []
66 |
67 | for stride in strides:
68 | layers.append(block(self.in_planes, planes, dropout_rate, stride))
69 | self.in_planes = planes
70 |
71 | return nn.Sequential(*layers)
72 |
73 | def forward(self, x):
74 | out = self.conv1(x)
75 | out = self.layer1(out)
76 | out = self.layer2(out)
77 | out = self.layer3(out)
78 | out = F.relu(self.bn1(out))
79 | out = F.avg_pool2d(out, 8)
80 | out = out.view(out.size(0), -1)
81 | out = self.linear(out)
82 |
83 | return out
84 |
85 | if __name__ == '__main__':
86 | net=Wide_ResNet(28, 10, 0.3, 10)
87 | y = net(Variable(torch.randn(1,3,32,32)))
88 |
89 | print(y.size())
90 |
--------------------------------------------------------------------------------
/image_classification/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | from torchvision import datasets, transforms
9 | from torch.utils.data import Dataset, DataLoader
10 | from torch.autograd import Variable
11 |
12 |
13 | def getData(
14 | name='cifar10',
15 | path='/home/dahunshin/imagenet',
16 | train_bs=256,
17 | test_bs=1000,
18 | num_workers=1,
19 | distributed=False):
20 |
21 | if name == 'mnist':
22 |
23 | train_loader = torch.utils.data.DataLoader(
24 | datasets.MNIST('../data', train=True, download=True,
25 | transform=transforms.Compose([
26 | transforms.ToTensor(),
27 | transforms.Normalize((0.1307,), (0.3081,))
28 | ])),
29 | batch_size=train_bs, shuffle=True)
30 | test_loader = torch.utils.data.DataLoader(
31 | datasets.MNIST(
32 | '../data',
33 | train=False,
34 | transform=transforms.Compose(
35 | [
36 | transforms.ToTensor(),
37 | transforms.Normalize(
38 | (0.1307,
39 | ),
40 | (0.3081,
41 | ))])),
42 | batch_size=test_bs,
43 | shuffle=False)
44 |         train_sampler, test_sampler = None, None  # MNIST has no DistributedSampler; keeps the shared return at the end of getData valid
45 |
46 | if name == 'cifar10':
47 | transform_train = transforms.Compose([
48 | transforms.RandomCrop(32, padding=4),
49 | transforms.RandomHorizontalFlip(),
50 | transforms.ToTensor(),
51 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
52 | ])
53 |
54 | transform_test = transforms.Compose([
55 | transforms.ToTensor(),
56 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
57 | ])
58 |
59 | trainset = datasets.CIFAR10(
60 | root='../data',
61 | train=True,
62 | download=True,
63 | transform=transform_train)
64 |
65 | testset = datasets.CIFAR10(
66 | root='../data',
67 | train=False,
68 | download=False,
69 | transform=transform_test)
70 |
71 | if distributed:
72 | train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True, drop_last=False)
73 | test_sampler = torch.utils.data.distributed.DistributedSampler(testset, shuffle=False, drop_last=True)
74 | else:
75 | train_sampler = None
76 | test_sampler = None
77 |
78 | train_loader = torch.utils.data.DataLoader(
79 | trainset, batch_size=train_bs, shuffle=(train_sampler is None),
80 | num_workers=num_workers, pin_memory=True, sampler=train_sampler)
81 |
82 | test_loader = torch.utils.data.DataLoader(
83 | testset, batch_size=test_bs, shuffle=False,
84 | num_workers=num_workers, pin_memory=True, sampler=test_sampler)
85 |
86 |
87 | if name == 'cifar100':
88 |
89 | transform_train = transforms.Compose([
90 | transforms.RandomCrop(32, padding=4),
91 | transforms.RandomHorizontalFlip(),
92 | transforms.ToTensor(),
93 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
94 | ])
95 |
96 | transform_test = transforms.Compose([
97 | transforms.ToTensor(),
98 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
99 | ])
100 |
101 | trainset = datasets.CIFAR100(
102 | root='../data',
103 | train=True,
104 | download=True,
105 | transform=transform_train)
106 |
107 | testset = datasets.CIFAR100(
108 | root='../data',
109 | train=False,
110 | download=False,
111 | transform=transform_test)
112 |
113 | if distributed:
114 | train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True, drop_last=False)
115 | test_sampler = torch.utils.data.distributed.DistributedSampler(testset, shuffle=False, drop_last=True)
116 | else:
117 | train_sampler = None
118 | test_sampler = None
119 |
120 | train_loader = torch.utils.data.DataLoader(
121 | trainset, batch_size=train_bs, shuffle=(train_sampler is None),
122 | num_workers=num_workers, pin_memory=True, sampler=train_sampler)
123 |
124 | test_loader = torch.utils.data.DataLoader(
125 | testset, batch_size=test_bs, shuffle=False,
126 | num_workers=num_workers, pin_memory=True, sampler=test_sampler)
127 |
128 |
129 | if name == 'imagenet':
130 | traindir = os.path.join(path, 'train')
131 | valdir = os.path.join(path, 'val')
132 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
133 | std=[0.229, 0.224, 0.225])
134 |
135 | train_dataset = datasets.ImageFolder(
136 | traindir,
137 | transforms.Compose([
138 | transforms.RandomResizedCrop(224),
139 | transforms.RandomHorizontalFlip(),
140 | transforms.ToTensor(),
141 | normalize,
142 | ]))
143 |
144 | val_dataset = datasets.ImageFolder(
145 | valdir,
146 | transforms.Compose([
147 | transforms.Resize(256),
148 | transforms.CenterCrop(224),
149 | transforms.ToTensor(),
150 | normalize,
151 | ]))
152 |
153 | if distributed:
154 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=False)
155 | test_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
156 | else:
157 | train_sampler = None
158 | test_sampler = None
159 |
160 | train_loader = torch.utils.data.DataLoader(
161 | train_dataset, batch_size=train_bs, shuffle=(train_sampler is None),
162 | num_workers=num_workers, pin_memory=True, sampler=train_sampler)
163 |
164 | test_loader = torch.utils.data.DataLoader(
165 | val_dataset, batch_size=test_bs, shuffle=False,
166 | num_workers=num_workers, pin_memory=True, sampler=test_sampler)
167 |
168 |
169 | return train_loader, test_loader, train_sampler, test_sampler
170 |
--------------------------------------------------------------------------------
/label_noise/bypass_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.modules.batchnorm import _BatchNorm
4 |
5 |
6 | def disable_running_stats(model):
7 | def _disable(module):
8 | if isinstance(module, _BatchNorm):
9 | module.backup_momentum = module.momentum
10 | module.momentum = 0
11 |
12 | model.apply(_disable)
13 |
14 | def enable_running_stats(model):
15 | def _enable(module):
16 | if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
17 | module.momentum = module.backup_momentum
18 |
19 | model.apply(_enable)
20 |
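21 | # Added usage note (a sketch, not part of the original module): these helpers are
22 | # intended to bracket the two forward/backward passes of a SAM-style update, as done
23 | # in label_noise/train.py. Batch-norm running statistics are updated only on the
24 | # first (unperturbed) pass and frozen for the second (perturbed) pass, e.g.:
25 | #
26 | #   enable_running_stats(model)                  # 1st pass: BN stats update
27 | #   criterion(model(data), target).backward()
28 | #   optimizer.perturb_weights(zero_grad=True)    # or optimizer.first_step(zero_grad=True)
29 | #   disable_running_stats(model)                 # 2nd pass: BN stats frozen
30 | #   criterion(model(data), target).backward(create_graph=create_graph)
31 | #   optimizer.unperturb(); optimizer.step(); optimizer.zero_grad()   # or optimizer.second_step(zero_grad=True)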
--------------------------------------------------------------------------------
/label_noise/config/cifar10/adahessian/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer adahessian \
3 | --noise_level 0.2 \
4 | --lr 0.1 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/adahessian/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer adahessian \
3 | --noise_level 0.4 \
4 | --lr 0.1 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/adahessian/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer adahessian \
3 | --noise_level 0.6 \
4 | --lr 0.05 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/msassha/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer msassha \
3 | --noise_level 0.2 \
4 | --lr 0.15 --wd 1e-3 --rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/msassha/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer msassha \
3 | --noise_level 0.4 \
4 | --lr 0.05 --wd 1e-3 --rho 0.25 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/msassha/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer msassha \
3 | --noise_level 0.6 \
4 | --lr 0.05 --wd 1e-3 --rho 0.25 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sam_sgd/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer samsgd \
3 | --noise_level 0.2 \
4 | --lr 0.1 --wd 1e-3 --rho 0.2 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sam_sgd/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer samsgd \
3 | --noise_level 0.4 \
4 | --lr 0.1 --wd 5e-4 --rho 0.2 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sam_sgd/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer samsgd \
3 | --noise_level 0.6 \
4 | --lr 0.03 --wd 1e-3 --rho 0.1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sassha/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sassha \
3 | --noise_level 0.2 \
4 | --hessian_power_scheduler constant \
5 | --lr 0.1 --min_lr 0.001 --wd 1e-3 --rho 0.1 --lazy_hessian 10 --seed 0, 1, 2 \
6 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sassha/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sassha \
3 | --noise_level 0.4 \
4 | --hessian_power_scheduler constant \
5 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \
6 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sassha/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sassha \
3 | --noise_level 0.6 \
4 | --hessian_power_scheduler constant \
5 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \
6 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sgd/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sgd \
3 | --noise_level 0.2 \
4 | --lr 0.015 --wd 2e-3 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sgd/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sgd \
3 | --noise_level 0.4 \
4 | --lr 0.015 --wd 2e-3 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sgd/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sgd \
3 | --noise_level 0.6 \
4 | --lr 0.05 --wd 1e-3 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/shampoo/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer shampoo \
3 | --noise_level 0.2 \
4 | --lr 0.6 --wd 5e-3 --eps 1e-2 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/shampoo/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer shampoo \
3 | --noise_level 0.4 \
4 | --lr 0.15 --wd 5e-2 --eps 1e-2 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/shampoo/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer shampoo \
3 | --noise_level 0.6 \
4 | --lr 0.15 --wd 5e-2 --eps 1e-2 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sophiah/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sophiah \
3 | --noise_level 0.2 \
4 | --lr 1e-3 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sophiah/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sophiah \
3 | --noise_level 0.4 \
4 | --lr 1e-3 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar10/sophiah/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar10 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sophiah \
3 | --noise_level 0.6 \
4 | --lr 1e-3 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/adahessian/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer adahessian \
3 | --noise_level 0.2 \
4 | --lr 0.0015 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/adahessian/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer adahessian \
3 | --noise_level 0.4 \
4 | --lr 1e-3 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/adahessian/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer adahessian \
3 | --noise_level 0.6 \
4 | --lr 1e-2 --wd 5e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/msassha/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer msassha \
3 | --noise_level 0.2 \
4 | --lr 0.15 --wd 1e-3 --rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/msassha/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer msassha \
3 | --noise_level 0.4 \
4 | --lr 0.15 --wd 1e-3 --rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/msassha/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer msassha \
3 | --noise_level 0.6 \
4 | --lr 0.15 --wd 1e-3 --rho 0.8 --lazy_hessian 10 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sam_sgd/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer samsgd \
3 | --noise_level 0.2 \
4 | --lr 0.1 --wd 1e-3 --rho 0.2 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sam_sgd/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer samsgd \
3 | --noise_level 0.4 \
4 | --lr 0.01 --wd 1e-3 --rho 0.1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sam_sgd/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer samsgd \
3 | --noise_level 0.6 \
4 | --lr 0.03 --wd 1e-3 --rho 0.1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sassha/noise20.sh:
--------------------------------------------------------------------------------
1 | # we use three seeds 0, 1, 2
2 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
3 | --optimizer sassha \
4 | --noise_level 0.2 \
5 | --hessian_power_scheduler constant \
6 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.1 --lazy_hessian 10 --seed 0, 1, 2
7 |
8 |
9 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sassha/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sassha \
3 | --noise_level 0.4 \
4 | --hessian_power_scheduler constant \
5 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \
6 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sassha/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sassha \
3 | --noise_level 0.6 \
4 | --hessian_power_scheduler constant \
5 | --lr 0.1 --min_lr 0.001 --wd 5e-4 --rho 0.2 --lazy_hessian 10 --seed 0, 1, 2 \
6 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sgd/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sgd \
3 | --noise_level 0.2 \
4 | --lr 5e-4 --wd 2e-3 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sgd/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sgd \
3 | --noise_level 0.4 \
4 | --lr 5e-4 --wd 2e-3 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sgd/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sgd \
3 | --noise_level 0.6 \
4 | --lr 5e-4 --wd 2e-3 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/shampoo/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer shampoo \
3 | --noise_level 0.2 \
4 | --lr 0.6 --wd 1e-2 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/shampoo/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer shampoo \
3 | --noise_level 0.4 \
4 | --lr 0.6 --wd 1e-2 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/shampoo/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer shampoo \
3 | --noise_level 0.6 \
4 | --lr 0.3 --wd 1e-2 --eps 1e-4 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sophiah/noise20.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sophiah \
3 | --noise_level 0.2 \
4 | --lr 1e-5 --wd 5e-4 --clip_threshold 0.05 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sophiah/noise40.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sophiah \
3 | --noise_level 0.4 \
4 | --lr 1e-5 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/config/cifar100/sophiah/noise60.sh:
--------------------------------------------------------------------------------
1 | python train.py --data cifar100 -depth 32 --epochs 160 --batch-size 256 \
2 | --optimizer sophiah \
3 | --noise_level 0.6 \
4 | --lr 1e-5 --wd 5e-4 --clip_threshold 0.1 --eps 1e-4 --lazy_hessian 1 --seed 0, 1, 2 \
5 |
--------------------------------------------------------------------------------
/label_noise/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''ResNet for CIFAR datasets.
4 | Ported from
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import math
12 | from copy import deepcopy
13 |
14 | __all__ = ['resnet']
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | "3x3 convolution with padding"
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=1, bias=False)
21 |
22 |
23 | class BasicBlock(nn.Module):
24 | expansion = 1
25 |
26 | def __init__(
27 | self,
28 | inplanes,
29 | planes,
30 | residual_not,
31 | batch_norm_not,
32 | stride=1,
33 | downsample=None):
34 | super(BasicBlock, self).__init__()
35 | self.residual_not = residual_not
36 | self.batch_norm_not = batch_norm_not
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | if self.batch_norm_not:
39 | self.bn1 = nn.BatchNorm2d(planes)
40 | self.relu = nn.ReLU(inplace=True)
41 | self.conv2 = conv3x3(planes, planes)
42 | if self.batch_norm_not:
43 | self.bn2 = nn.BatchNorm2d(planes)
44 | self.downsample = downsample
45 | self.stride = stride
46 |
47 | def forward(self, x):
48 | residual = x
49 |
50 | out = self.conv1(x)
51 | if self.batch_norm_not:
52 | out = self.bn1(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv2(out)
56 | if self.batch_norm_not:
57 | out = self.bn2(out)
58 |
59 | if self.downsample is not None:
60 | residual = self.downsample(x)
61 | if self.residual_not:
62 | out += residual
63 | out = self.relu(out)
64 |
65 | return out
66 |
67 |
68 | class Bottleneck(nn.Module):
69 | expansion = 4
70 |
71 | def __init__(
72 | self,
73 | inplanes,
74 | planes,
75 | residual_not,
76 | batch_norm_not,
77 | stride=1,
78 | downsample=None):
79 | super(Bottleneck, self).__init__()
80 | self.residual_not = residual_not
81 | self.batch_norm_not = batch_norm_not
82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
83 | if self.batch_norm_not:
84 | self.bn1 = nn.BatchNorm2d(planes)
85 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
86 | padding=1, bias=False)
87 | if self.batch_norm_not:
88 | self.bn2 = nn.BatchNorm2d(planes)
89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
90 | if self.batch_norm_not:
91 | self.bn3 = nn.BatchNorm2d(planes * 4)
92 | self.relu = nn.ReLU(inplace=True)
93 | self.downsample = downsample
94 | self.stride = stride
95 |
96 | def forward(self, x):
97 | residual = x
98 |
99 | out = self.conv1(x)
100 | if self.batch_norm_not:
101 | out = self.bn1(out)
102 | out = self.relu(out)
103 |
104 | out = self.conv2(out)
105 | if self.batch_norm_not:
106 | out = self.bn2(out)
107 | out = self.relu(out)
108 |
109 | out = self.conv3(out)
110 | if self.batch_norm_not:
111 | out = self.bn3(out)
112 |
113 | if self.downsample is not None:
114 | residual = self.downsample(x)
115 | if self.residual_not:
116 | out += residual
117 |
118 | out = self.relu(out)
119 |
120 | return out
121 |
122 |
123 | ALPHA_ = 1
124 |
125 |
126 | class ResNet(nn.Module):
127 |
128 | def __init__(
129 | self,
130 | depth,
131 | residual_not=True,
132 | batch_norm_not=True,
133 | base_channel=16,
134 | num_classes=10):
135 | super(ResNet, self).__init__()
136 | # Model type specifies number of layers for CIFAR-10 model
137 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
138 | n = (depth - 2) // 6
139 |
140 | # block = Bottleneck if depth >=44 else BasicBlock
141 | block = BasicBlock
142 |
143 | self.base_channel = int(base_channel)
144 | self.residual_not = residual_not
145 | self.batch_norm_not = batch_norm_not
146 | self.inplanes = self.base_channel * ALPHA_
147 | self.conv1 = nn.Conv2d(
148 | 3,
149 | self.base_channel *
150 | ALPHA_,
151 | kernel_size=3,
152 | padding=1,
153 | bias=False)
154 | if self.batch_norm_not:
155 | self.bn1 = nn.BatchNorm2d(self.base_channel * ALPHA_)
156 | self.relu = nn.ReLU(inplace=True)
157 | self.layer1 = self._make_layer(
158 | block,
159 | self.base_channel *
160 | ALPHA_,
161 | n,
162 | self.residual_not,
163 | self.batch_norm_not)
164 | self.layer2 = self._make_layer(
165 | block,
166 | self.base_channel *
167 | 2 *
168 | ALPHA_,
169 | n,
170 | self.residual_not,
171 | self.batch_norm_not,
172 | stride=2)
173 | self.layer3 = self._make_layer(
174 | block,
175 | self.base_channel *
176 | 4 *
177 | ALPHA_,
178 | n,
179 | self.residual_not,
180 | self.batch_norm_not,
181 | stride=2)
182 | self.avgpool = nn.AvgPool2d(8)
183 | self.fc = nn.Linear(
184 | self.base_channel *
185 | 4 *
186 | ALPHA_ *
187 | block.expansion,
188 | num_classes)
189 |
190 | for m in self.modules():
191 | if isinstance(m, nn.Conv2d):
192 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
193 | m.weight.data.normal_(0, math.sqrt(2. / n))
194 | elif isinstance(m, nn.BatchNorm2d):
195 | m.weight.data.fill_(1)
196 | m.bias.data.zero_()
197 |
198 | def _make_layer(
199 | self,
200 | block,
201 | planes,
202 | blocks,
203 | residual_not,
204 | batch_norm_not,
205 | stride=1):
206 | downsample = None
207 | if (stride != 1 or self.inplanes != planes *
208 | block.expansion) and (residual_not):
209 | if batch_norm_not:
210 | downsample = nn.Sequential(
211 | nn.Conv2d(self.inplanes, planes * block.expansion,
212 | kernel_size=1, stride=stride, bias=False),
213 | nn.BatchNorm2d(planes * block.expansion),
214 | )
215 | else:
216 | downsample = nn.Sequential(
217 | nn.Conv2d(self.inplanes, planes * block.expansion,
218 | kernel_size=1, stride=stride, bias=False),
219 | )
220 |
221 | layers = nn.ModuleList()
222 | layers.append(
223 | block(
224 | self.inplanes,
225 | planes,
226 | residual_not,
227 | batch_norm_not,
228 | stride,
229 | downsample))
230 | self.inplanes = planes * block.expansion
231 | for i in range(1, blocks):
232 | layers.append(
233 | block(
234 | self.inplanes,
235 | planes,
236 | residual_not,
237 | batch_norm_not))
238 |
239 | # return nn.Sequential(*layers)
240 | return layers
241 |
242 | def forward(self, x):
243 | output_list = []
244 | x = self.conv1(x)
245 | if self.batch_norm_not:
246 | x = self.bn1(x)
247 | x = self.relu(x) # 32x32
248 | output_list.append(x.view(x.size(0), -1))
249 |
250 | for layer in self.layer1:
251 | x = layer(x) # 32x32
252 | output_list.append(x.view(x.size(0), -1))
253 | for layer in self.layer2:
254 | x = layer(x) # 16x16
255 | output_list.append(x.view(x.size(0), -1))
256 | for layer in self.layer3:
257 | x = layer(x) # 8x8
258 | output_list.append(x.view(x.size(0), -1))
259 |
260 | x = self.avgpool(x)
261 | x = x.view(x.size(0), -1)
262 | x = self.fc(x)
263 | output_list.append(x.view(x.size(0), -1))
264 |
265 | # return output_list, x
266 | return x
267 |
268 |
269 | def resnet(**kwargs):
270 | """
271 | Constructs a ResNet model.
272 | """
273 | return ResNet(**kwargs)
274 |
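275 | if __name__ == "__main__":
276 |     # Hypothetical smoke test (added for illustration, not in the original file):
277 |     # build the CIFAR-style ResNet-32 used in the label-noise experiments and run
278 |     # a dummy forward pass on a CIFAR-sized input.
279 |     import torch
280 |     model = resnet(depth=32, num_classes=100)
281 |     x = torch.randn(2, 3, 32, 32)
282 |     logits = model(x)
283 |     print(logits.shape)  # expected: torch.Size([2, 100])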
--------------------------------------------------------------------------------
/label_noise/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import logging
3 | import os
4 | import sys
5 | import random
6 |
7 | import numpy as np
8 | import argparse
9 | from tqdm import tqdm, trange
10 |
11 | import torch
12 | import torch.backends.cudnn as cudnn
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | import torch.optim.lr_scheduler as lr_scheduler
17 | from torchvision import datasets, transforms
18 | from torch.autograd import Variable
19 |
20 | from utils import *
21 | from models.resnet import *
22 | from bypass_bn import enable_running_stats, disable_running_stats
23 |
24 | # add the repository root to sys.path (assuming train.py is run from label_noise/) so the shared optimizers package can be imported
25 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('optimizers'))))
26 | from optimizers import get_optimizer
27 |
28 | # load the Hessian power schedulers
29 | from optimizers.hessian_scheduler import ConstantScheduler, ProportionScheduler, LinearScheduler, CosineScheduler
30 |
31 | # Training settings
32 | parser = argparse.ArgumentParser(description='PyTorch Example')
33 | parser.add_argument('--batch-size', type=int, default=256, metavar='B',
34 | help='input batch size for training (default: 256)')
35 | parser.add_argument('--test-batch-size', type=int, default=256, metavar='TB',
36 | help='input batch size for testing (default: 256)')
37 | parser.add_argument('--epochs', type=int, default=160, metavar='E',
38 |                     help='number of epochs to train (default: 160)')
39 | parser.add_argument('--lr', type=float, default=0.15, metavar='LR',
40 | help='learning rate (default: 0.15)')
41 | parser.add_argument('--lr-decay', type=float, default=0.1,
42 | help='learning rate ratio')
43 | parser.add_argument('--lr-decay-epoch', type=int, nargs='+', default=[80, 120],
44 | help='decrease learning rate at these epochs.')
45 | parser.add_argument('--seed', type=int, default=0, metavar='S',
46 | help='random seed (default: 0)')
47 | parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float,
48 | metavar='W', help='weight decay (default: 5e-4)')
49 | parser.add_argument('--depth', type=int, default=20,
50 | help='choose the depth of resnet')
51 | parser.add_argument('--optimizer', type=str, default='sassha',
52 | help='choose optim')
53 | parser.add_argument('--data', type=str, default='cifar100',
54 | help='choose dataset cifar10/100')
55 | parser.add_argument('--noise_level', type=float, default=0.2,
56 | help='noise_level')
57 | parser.add_argument('--min_lr', type=float, default=0.0, help="the minimum value of learning rate")
58 |
59 | # Second-order optimization settings
60 | parser.add_argument("--n_samples", default=1, type=int, help="the number of sampling")
61 | parser.add_argument('--betas', type=float, nargs='*', default=[0.9, 0.999], help='betas')
62 | parser.add_argument("--eps", default=1e-4, type=float, help="add a small number for stability")
63 | parser.add_argument("--lazy_hessian", default=10, type=int, help="Delayed hessian update.")
64 | parser.add_argument("--clip_threshold", default=0.05, type=float, help="Clipping threshold.")
65 |
66 | # Hessian power scheduler
67 | parser.add_argument('--hessian_power_scheduler', type=str, default='constant', help="choose a Hessian power scheduler: 'constant', 'proportion', 'linear', or 'cosine'")
68 | parser.add_argument('--max_hessian_power', type=float, default=1)
69 | parser.add_argument('--min_hessian_power', type=float, default=0.5)
70 |
71 | # Sharpness minimization settings
72 | parser.add_argument("--rho", default=0.05, type=float, help="Rho parameter for SAM.")
73 | parser.add_argument("--adaptive", default=False, type=bool, help="True if you want to use the Adaptive SAM.")
74 | parser.add_argument('--project_name', type=str, default='project_name', help="project_name")
75 |
76 | args = parser.parse_args()
77 |
78 | # wandb logging
79 | wandb_log = True
80 | wandb_project = args.project_name
81 | wandb_run_name = f'{args.optimizer}_lr_{args.lr}_wd_{args.weight_decay}_rho_{args.rho}'
82 |
83 | num_classes = ''.join([c for c in args.data if c.isdigit()])
84 | num_classes = int(num_classes)
85 |
86 | # set random seed to reproduce the work
87 | random.seed(args.seed)
88 | np.random.seed(args.seed)
89 | torch.manual_seed(args.seed)
90 | torch.cuda.manual_seed(args.seed)
91 | cudnn.deterministic = True
92 | cudnn.benchmark = False
93 |
94 | for arg in vars(args):
95 | print(arg, getattr(args, arg))
96 | if not os.path.isdir('checkpoint/'):
97 | os.makedirs('checkpoint/')
98 |
99 | # get a dataset (e.g., cifar10, cifar100)
100 | train_loader, test_loader = getNoisyData(name=args.data,
101 | train_bs=args.batch_size,
102 | test_bs=args.test_batch_size,
103 | noise_level=args.noise_level)
104 |
105 | # get a model
106 | model = resnet(num_classes=num_classes, depth=args.depth).cuda()
107 | print(model)
108 |
109 | model = torch.nn.DataParallel(model)
110 | print(' Total params: %.2fM' % (sum(p.numel()
111 | for p in model.parameters()) / 1000000.0))
112 | # define a loss
113 | criterion = nn.CrossEntropyLoss()
114 |
115 | # get an optimizer
116 | optimizer, create_graph, two_steps = get_optimizer(model, args)
117 |
118 | # learning rate schedule
119 | scheduler = lr_scheduler.MultiStepLR(
120 | optimizer,
121 | args.lr_decay_epoch,
122 | gamma=args.lr_decay,
123 | last_epoch=-1)
124 |
125 | # select a hessian power scheduler
126 | if args.hessian_power_scheduler == 'constant':
127 | hessian_power_scheduler = ConstantScheduler(
128 | T_max=args.epochs*len(train_loader),
129 | max_value=0.5,
130 | min_value=0.5)
131 |
132 | elif args.hessian_power_scheduler == 'proportion':
133 | hessian_power_scheduler = ProportionScheduler(
134 | pytorch_lr_scheduler=scheduler,
135 | max_lr=args.lr,
136 | min_lr=args.min_lr,
137 | max_value=args.max_hessian_power,
138 | min_value=args.min_hessian_power)
139 |
140 | elif args.hessian_power_scheduler == 'linear':
141 | hessian_power_scheduler = LinearScheduler(
142 | T_max=args.epochs*len(train_loader),
143 | max_value=args.max_hessian_power,
144 | min_value=args.min_hessian_power)
145 |
146 | elif args.hessian_power_scheduler == 'cosine':
147 | hessian_power_scheduler = CosineScheduler(
148 | T_max=args.epochs*len(train_loader),
149 | max_value=args.max_hessian_power,
150 | min_value=args.min_hessian_power)
151 |
152 | optimizer.hessian_power_scheduler = hessian_power_scheduler
153 |
154 | # import and init wandb
155 | if wandb_log:
156 | import wandb
157 | os.environ["WANDB__SERVICE_WAIT"] = "300"
158 | wandb.init(project=wandb_project, name=wandb_run_name)
159 | wandb.config.update(args)
160 |
161 | best_acc = 0.0
162 | iter_num = 0
163 | # training loop
164 | for epoch in range(1, args.epochs + 1):
165 | print('Current Epoch: ', epoch)
166 | train_loss = 0.
167 | total_num = 0
168 | correct = 0
169 |
170 | scheduler.step()
171 | model.train()
172 |
173 | if args.optimizer == 'msassha':
174 | optimizer.move_up_to_momentumAscent()
175 |
176 | with tqdm(total=len(train_loader.dataset)) as progressbar:
177 | for batch_idx, (data, target) in enumerate(train_loader):
178 | data, target = data.cuda(), target.cuda()
179 |
180 | if two_steps:
181 | enable_running_stats(model)
182 | output = model(data)
183 | loss = criterion(output, target)
184 | loss.backward()
185 |
186 | if args.optimizer == 'sassha':
187 | optimizer.perturb_weights(zero_grad=True)
188 |
189 | elif args.optimizer in ['samsgd', 'samadamw']:
190 | optimizer.first_step(zero_grad=True)
191 |
192 | disable_running_stats(model)
193 | criterion(model(data), target).backward(create_graph=create_graph)
194 |
195 | if args.optimizer == 'sassha':
196 | optimizer.unperturb()
197 | optimizer.step()
198 | optimizer.zero_grad()
199 |
200 | elif args.optimizer in ['samsgd', 'samadamw']:
201 | optimizer.second_step(zero_grad=True)
202 |
203 | else:
204 | output = model(data)
205 | loss = criterion(output, target)
206 | loss.backward(create_graph=create_graph)
207 | optimizer.step()
208 | optimizer.zero_grad()
209 |
210 | # for records
211 | train_loss += loss.item() * target.size()[0]
212 | total_num += target.size()[0]
213 | _, predicted = output.max(1)
214 | correct += predicted.eq(target).sum().item()
215 | progressbar.update(target.size(0))
216 | iter_num += 1
217 |
218 | if args.optimizer == 'msassha':
219 | optimizer.move_back_from_momentumAscent()
220 |
221 | acc, val_loss = test(model, test_loader, criterion)
222 |
223 | train_loss /= total_num
224 | train_acc = correct / total_num * 100
225 |
226 | if acc > best_acc:
227 | best_acc = acc
228 |
229 | if wandb_log:
230 | wandb.log({
231 | "iter": iter_num,
232 | "train/loss": train_loss,
233 | "train/acc": train_acc,
234 | "val/acc": acc*100,
235 | "val/loss": val_loss,
236 | "lr": scheduler.get_last_lr()[-1],
237 | 'best_accuracy': best_acc,
238 | "hessian_power": optimizer.hessian_power_t if args.optimizer == 'sassha' else 0},
239 | step=epoch)
240 |
241 | print(f"Training Loss of Epoch {epoch}: {np.around(train_loss, 2)}")
242 | print(f"Testing of Epoch {epoch}: {np.around(acc * 100, 2)} \n")
243 |
244 | print(f'Best Acc: {np.around(best_acc * 100, 2)}')
245 |
--------------------------------------------------------------------------------
/label_noise/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torchvision import datasets, transforms
8 | from torch.utils.data import Dataset, DataLoader
9 | from torch.autograd import Variable
10 |
11 | def introduce_label_noise(labels, num_classes, noise_level=0.1):
12 | n_samples = len(labels)
13 | n_noisy = int(n_samples * noise_level)
14 | noisy_indices = torch.randperm(n_samples)[:n_noisy]
15 | # Generate new random labels for the selected indices
16 | new_labels = torch.randint(0, num_classes, (n_noisy,))
17 | labels[noisy_indices] = new_labels
18 | return labels
19 |
20 | def getNoisyData(
21 | name='cifar10',
22 | train_bs=128,
23 | test_bs=1000,
24 | noise_level=0.1,
25 | ):
26 |
27 | if name == 'mnist':
28 |
29 | train_loader = torch.utils.data.DataLoader(
30 | datasets.MNIST('../data', train=True, download=True,
31 | transform=transforms.Compose([
32 | transforms.ToTensor(),
33 | transforms.Normalize((0.1307,), (0.3081,))
34 | ])),
35 | batch_size=train_bs, shuffle=True)
36 |
37 | test_loader = torch.utils.data.DataLoader(
38 | datasets.MNIST(
39 | '../data',
40 | train=False,
41 | transform=transforms.Compose(
42 | [
43 | transforms.ToTensor(),
44 | transforms.Normalize(
45 | (0.1307,
46 | ),
47 | (0.3081,
48 | ))])),
49 | batch_size=test_bs,
50 | shuffle=False)
51 |
52 | if name == 'cifar10':
53 | transform_train = transforms.Compose([
54 | transforms.RandomCrop(32, padding=4),
55 | transforms.RandomHorizontalFlip(),
56 | transforms.ToTensor(),
57 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
58 | ])
59 |
60 | trainset = datasets.CIFAR10(
61 | root='../data',
62 | train=True,
63 | download=True,
64 | transform=transform_train)
65 |
66 | trainset.targets = introduce_label_noise(torch.tensor(trainset.targets),
67 | num_classes=10,
68 | noise_level=noise_level)
69 |
70 | train_loader = torch.utils.data.DataLoader(
71 | trainset, batch_size=train_bs, shuffle=True, drop_last=False)
72 |
73 | transform_test = transforms.Compose([
74 | transforms.ToTensor(),
75 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
76 | ])
77 |
78 | testset = datasets.CIFAR10(
79 | root='../data',
80 | train=False,
81 | download=False,
82 | transform=transform_test)
83 |
84 | test_loader = torch.utils.data.DataLoader(
85 | testset, batch_size=test_bs, shuffle=False)
86 |
87 | if name == 'cifar100':
88 |
89 | transform_train = transforms.Compose([
90 | transforms.RandomCrop(32, padding=4),
91 | transforms.RandomHorizontalFlip(),
92 | transforms.ToTensor(),
93 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
94 | ])
95 |
96 | trainset = datasets.CIFAR100(
97 | root='../data',
98 | train=True,
99 | download=True,
100 | transform=transform_train)
101 |
102 | trainset.targets = introduce_label_noise(torch.tensor(trainset.targets),
103 | num_classes=100,
104 | noise_level=noise_level)
105 |
106 | train_loader = torch.utils.data.DataLoader(
107 | trainset, batch_size=train_bs, shuffle=True, drop_last=False)
108 |
109 | transform_test = transforms.Compose([
110 | transforms.ToTensor(),
111 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
112 | ])
113 |
114 | testset = datasets.CIFAR100(
115 | root='../data',
116 | train=False,
117 | download=False,
118 | transform=transform_test)
119 |
120 | test_loader = torch.utils.data.DataLoader(
121 | testset, batch_size=test_bs, shuffle=False)
122 |
123 | return train_loader, test_loader
124 |
125 |
126 | def test(model, test_loader, criterion):
127 | # print('Testing')
128 | model.eval()
129 | correct = 0
130 | total_num = 0
131 | val_loss = 0.0
132 | with torch.no_grad():
133 | for data, target in test_loader:
134 | data, target = data.cuda(), target.cuda()
135 | output = model(data)
136 | loss = criterion(output, target)
137 | # get the index of the max log-probability
138 | pred = output.data.max(1, keepdim=True)[1]
139 | correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
140 | val_loss += loss.item() * target.size()[0]
141 | total_num += len(data)
142 | # print('testing_correct: ', correct / total_num, '\n')
143 | return (correct / total_num, val_loss / total_num)
144 |
145 |
146 | def get_params_grad(model):
147 | """
148 | get model parameters and corresponding gradients
149 | """
150 | params = []
151 | grads = []
152 | for param in model.parameters():
153 | if not param.requires_grad:
154 | continue
155 | params.append(param)
156 | grads.append(0. if param.grad is None else param.grad + 0.)
157 | return params, grads
158 |
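159 | if __name__ == "__main__":
160 |     # Hypothetical sanity check (added for illustration, not in the original file):
161 |     # introduce_label_noise resamples `noise_level` of the labels uniformly, so the
162 |     # observed flip rate is slightly below noise_level because a resampled label can
163 |     # coincide with the original one (expected ~0.18 for noise_level=0.2, 10 classes).
164 |     torch.manual_seed(0)
165 |     labels = torch.randint(0, 10, (10000,))
166 |     noisy = introduce_label_noise(labels.clone(), num_classes=10, noise_level=0.2)
167 |     print('fraction of labels changed:', (noisy != labels).float().mean().item())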
--------------------------------------------------------------------------------
/language_tasks/README.md:
--------------------------------------------------------------------------------
1 | # We plan to release code for pretraining large language models that exceed the scale of the mini-gpt1 model.
--------------------------------------------------------------------------------
/language_tasks/config/adahessian/mnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adahessian \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-3 \
10 | --weight_decay 1e-6 \
11 | --eps 1e-6 \
12 | --lazy_hessian 1 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/adahessian/mrpc.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mrpc \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adahessian \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-6 \
11 | --eps 1e-4 \
12 | --lazy_hessian 1 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/adahessian/qnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adahessian \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-3 \
10 | --weight_decay 1e-7 \
11 | --eps 1e-4 \
12 | --lazy_hessian 1 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/adahessian/qqp.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qqp \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adahessian \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-7 \
11 | --eps 1e-4 \
12 | --lazy_hessian 1 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/adahessian/rte.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name rte \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adahessian \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-3 \
10 | --weight_decay 1e-6 \
11 | --eps 1e-4 \
12 | --lazy_hessian 1 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/adahessian/sst2.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name sst2 \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adahessian \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-3 \
10 | --weight_decay 1e-6 \
11 | --eps 1e-8 \
12 | --lazy_hessian 1 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/adahessian/stsb.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name stsb \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 |   --optimizer adahessian \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-6 \
11 | --eps 1e-4 \
12 | --lazy_hessian 1 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/adamw/mnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-6 \
11 | --eps 1e-6 \
12 | --seed 0, 1, 2
13 |
--------------------------------------------------------------------------------
/language_tasks/config/adamw/mrpc.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mrpc \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-4 \
10 | --weight_decay 1e-7 \
11 | --eps 1e-4 \
12 | --seed 0, 1, 2
13 |
--------------------------------------------------------------------------------
/language_tasks/config/adamw/qnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-7 \
11 | --eps 1e-8 \
12 | --seed 0, 1, 2
13 |
--------------------------------------------------------------------------------
/language_tasks/config/adamw/qqp.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qqp \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-6 \
11 | --eps 1e-6 \
12 | --seed 0, 1, 2
13 |
--------------------------------------------------------------------------------
/language_tasks/config/adamw/rte.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name rte \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-8 \
11 | --eps 1e-8 \
12 | --lazy_hessian 1 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/adamw/sst2.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name sst2 \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-8 \
11 | --eps 1e-8 \
12 | --seed 0, 1, 2
13 |
--------------------------------------------------------------------------------
/language_tasks/config/adamw/stsb.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name stsb \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer adamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-5 \
11 | --eps 1e-8 \
12 | --seed 0, 1, 2
13 |
--------------------------------------------------------------------------------
/language_tasks/config/msassha/mnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer msassha \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-6 \
11 | --rho 1e-2 \
12 | --eps 1e-2 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/msassha/mrpc.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mrpc \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer msassha \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-8 \
11 | --rho 1e-5 \
12 | --eps 1e-4 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/msassha/qnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer msassha \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-6 \
11 | --rho 1e-5 \
12 | --eps 1e-2 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/msassha/qqp.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qqp \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer msassha \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-6 \
11 | --rho 1e-5 \
12 | --eps 1e-4 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/msassha/rte.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name rte \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer msassha \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-6 \
11 | --rho 1e-3 \
12 | --eps 1e-8 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/msassha/sst2.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name sst2 \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer msassha \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-2 \
10 | --weight_decay 1e-6 \
11 | --rho 1e-4 \
12 | --eps 1e-8 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/msassha/stsb.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name stsb \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer msassha \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-3 \
10 | --weight_decay 1e-5 \
11 | --rho 1e-5 \
12 | --eps 1e-4 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/sam_adamw/mnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer samadamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-6 \
11 | --rho 1e-2 \
12 | --eps 1e-6 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/sam_adamw/mrpc.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mrpc \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer samadamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-7 \
11 | --rho 1e-4 \
12 | --eps 1e-6 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/sam_adamw/qnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer samadamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-7 \
11 | --rho 1e-2 \
12 | --eps 1e-8 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/sam_adamw/qqp.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qqp \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer samadamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-4 \
10 | --weight_decay 1e-6 \
11 | --rho 1e-2 \
12 | --eps 1e-6 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/sam_adamw/rte.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name rte \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer samadamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-4 \
10 | --weight_decay 1e-7 \
11 | --rho 1e-3 \
12 | --eps 1e-6 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/sam_adamw/sst2.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name sst2 \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer samadamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-5 \
11 | --rho 1e-2 \
12 | --eps 1e-8 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/sam_adamw/stsb.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name stsb \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer samadamw \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-4 \
10 | --weight_decay 1e-7 \
11 | --rho 1e-3 \
12 | --eps 1e-6 \
13 | --seed 0, 1, 2
14 |
--------------------------------------------------------------------------------
/language_tasks/config/sassha/mnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sassha \
8 | --lr_scheduler_type polynomial \
9 | --hessian_power_scheduler constant \
10 | --lr 1e-2 \
11 | --weight_decay 1e-6 \
12 | --rho 1e-2 \
13 | --eps 1e-6 \
14 | --lazy_hessian 1 \
15 | --seed 0, 1, 2
16 |
--------------------------------------------------------------------------------
/language_tasks/config/sassha/mrpc.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mrpc \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sassha \
8 | --lr_scheduler_type polynomial \
9 | --hessian_power_scheduler constant \
10 | --lr 1e-2 \
11 | --weight_decay 1e-8 \
12 | --rho 1e-4 \
13 | --eps 1e-4 \
14 | --lazy_hessian 1 \
15 | --seed 0, 1, 2
16 |
--------------------------------------------------------------------------------
/language_tasks/config/sassha/qnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sassha \
8 | --lr_scheduler_type polynomial \
9 | --hessian_power_scheduler constant \
10 | --lr 1e-2 \
11 | --weight_decay 1e-7 \
12 | --rho 1e-2 \
13 | --eps 1e-6 \
14 | --lazy_hessian 1 \
15 | --seed 0, 1, 2
16 |
--------------------------------------------------------------------------------
/language_tasks/config/sassha/qqp.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qqp \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sassha \
8 | --lr_scheduler_type polynomial \
9 | --hessian_power_scheduler constant \
10 | --lr 1e-2 \
11 | --weight_decay 1e-6 \
12 | --rho 1e-2 \
13 | --eps 1e-4 \
14 | --lazy_hessian 1 \
15 | --seed 0, 1, 2
16 |
--------------------------------------------------------------------------------
/language_tasks/config/sassha/rte.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name rte \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sassha \
8 | --lr_scheduler_type polynomial \
9 | --hessian_power_scheduler constant \
10 | --lr 1e-2 \
11 | --weight_decay 1e-4 \
12 | --rho 1e-2 \
13 | --eps 1e-4 \
14 | --lazy_hessian 1 \
15 | --seed 0, 1, 2
16 |
--------------------------------------------------------------------------------
/language_tasks/config/sassha/sst2.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name sst2 \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sassha \
8 | --lr_scheduler_type polynomial \
9 | --hessian_power_scheduler constant \
10 | --lr 1e-2 \
11 | --weight_decay 1e-8 \
12 | --rho 1e-4 \
13 | --eps 1e-6 \
14 | --lazy_hessian 1 \
15 | --seed 0, 1, 2
16 |
--------------------------------------------------------------------------------
/language_tasks/config/sassha/stsb.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name stsb \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sassha \
8 | --lr_scheduler_type polynomial \
9 | --hessian_power_scheduler constant \
10 | --lr 1e-2 \
11 | --weight_decay 1e-5 \
12 | --rho 1e-2 \
13 | --eps 1e-8 \
14 | --lazy_hessian 1 \
15 | --seed 0, 1, 2
16 |
--------------------------------------------------------------------------------
/language_tasks/config/sophiah/mnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sophiah \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 0 \
11 | --clip_threshold 1e-4 \
12 | --eps 1e-6 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/sophiah/mrpc.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name mrpc \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sophiah \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-6 \
11 | --clip_threshold 1e-2 \
12 | --eps 1e-6 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/sophiah/qnli.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qnli \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sophiah \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-7 \
11 | --clip_threshold 1e-2 \
12 | --eps 1e-4 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/sophiah/qqp.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name qqp \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sophiah \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-6 \
11 | --clip_threshold 0.1 \
12 | --eps 1e-4 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/sophiah/rte.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name rte \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sophiah \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-4 \
11 | --clip_threshold 1e-5 \
12 | --eps 1e-4 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/sophiah/sst2.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name sst2 \
4 | --max_length 512 \
5 | --num_train_epochs 5 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sophiah \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 0 \
11 | --clip_threshold 1e-4 \
12 | --eps 1e-6 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/sophiah/stsb.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 | --model_name_or_path squeezebert/squeezebert-uncased \
3 | --task_name stsb \
4 | --max_length 512 \
5 | --num_train_epochs 10 \
6 | --per_device_train_batch_size 16 \
7 | --optimizer sophiah \
8 | --lr_scheduler_type polynomial \
9 | --lr 1e-5 \
10 | --weight_decay 1e-5 \
11 | --clip_threshold 1e-4 \
12 | --eps 1e-6 \
13 | --lazy_hessian 1 \
14 | --seed 0, 1, 2
15 |
--------------------------------------------------------------------------------
/language_tasks/config/wandb/rte_sophia.sh:
--------------------------------------------------------------------------------
1 | project: squeezebert-rte # here
2 | program: finetune.py # here
3 | method: grid
4 | metric:
5 | name: accuracy.accuracy
6 | goal: maximize
7 | parameters:
8 | optimizer:
9 | values: ['sophiah'] # here
10 | task_name:
11 | values: ['rte'] # here
12 | learning_rate:
13 | values: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
14 | weight_decay:
15 | values: [0, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
16 | update_each:
17 | values: [1, 2, 3, 4, 5, 10]
18 | clip_threshold:
19 | values: [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
20 | eps:
21 | values: [1e-4, 1e-6, 1e-8]
22 | model_name_or_path:
23 | values: ['squeezebert/squeezebert-uncased']
24 | max_length:
25 | values: [512]
26 | num_train_epochs:
27 | values: [10]
28 | lr_scheduler_type:
29 | values: ['polynomial']
30 | per_device_train_batch_size:
31 | values: [16]
32 |
--------------------------------------------------------------------------------
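The file above is a wandb sweep definition written in YAML (despite its `.sh` extension); the `# here` comments mark the fields to adapt for other tasks or optimizers. As a hedged sketch of one way to launch it with the wandb Python API, assuming `wandb login` has been run and the file is read from its path in this repository:

```python
# Minimal sketch: create the sweep defined above and print the command for starting a worker.
import yaml
import wandb

with open("language_tasks/config/wandb/rte_sophia.sh") as f:  # the file holds YAML despite the .sh suffix
    sweep_config = yaml.safe_load(f)

sweep_id = wandb.sweep(sweep_config, project="squeezebert-rte")
print(f"Start a worker with: wandb agent {sweep_id}")
```

Each agent run then executes `program: finetune.py` with one sampled hyperparameter combination from the grid.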
/language_tasks/finetune.py:
--------------------------------------------------------------------------------
1 | # Acknowledgement: this code is based on : https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue_no_trainer.py
2 |
3 | import sys
4 | import argparse
5 | import json
6 | import logging
7 | import math
8 | import os
9 | import random
10 | from pathlib import Path
11 |
12 | import datasets
13 | import evaluate
14 | import torch
15 | from transformers import set_seed
16 |
17 | from datasets import load_dataset
18 | from huggingface_hub import HfApi
19 | from torch.utils.data import DataLoader
20 | import torch.nn as nn
21 |
22 | from tqdm.auto import tqdm
23 |
24 | import transformers
25 | from transformers import (
26 | AutoConfig,
27 | AutoModelForSequenceClassification,
28 | AutoTokenizer,
29 | DataCollatorWithPadding,
30 | PretrainedConfig,
31 | SchedulerType,
32 | default_data_collator,
33 | get_scheduler,
34 | )
35 | from transformers.utils import check_min_version, send_example_telemetry
36 | from transformers.utils.versions import require_version
37 |
38 | import wandb
39 | os.environ["WANDB__SERVICE_WAIT"] = "300"
40 |
41 | # Load optimizers
42 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('optimizers'))))  # add the repository root (parent of the current working directory) so the optimizers package is importable
43 | from optimizers import get_optimizer
44 |
45 | # Load the Hessian power scheduler
46 | from optimizers.hessian_scheduler import ConstantScheduler, ProportionScheduler, LinearScheduler, CosineScheduler
47 |
48 |
49 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
50 | check_min_version("4.41.0.dev0")
51 |
52 | #logger = get_logger(__name__)
53 |
54 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
55 |
56 | task_to_keys = {
57 | "cola": ("sentence", None),
58 | "mnli": ("premise", "hypothesis"),
59 | "mrpc": ("sentence1", "sentence2"),
60 | "qnli": ("question", "sentence"),
61 | "qqp": ("question1", "question2"),
62 | "rte": ("sentence1", "sentence2"),
63 | "sst2": ("sentence", None),
64 | "stsb": ("sentence1", "sentence2"),
65 | "wnli": ("sentence1", "sentence2"),
66 | }
67 |
68 |
69 | def parse_args():
70 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
71 | parser.add_argument(
72 | "--task_name",
73 | type=str,
74 | default=None,
75 | help="The name of the glue task to train on.",
76 | choices=list(task_to_keys.keys()),
77 | )
78 | parser.add_argument(
79 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
80 | )
81 | parser.add_argument(
82 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
83 | )
84 | parser.add_argument(
85 | "--max_length",
86 | type=int,
87 | default=128,
88 | help=(
89 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
90 | " sequences shorter will be padded if `--pad_to_max_length` is passed."
91 | ),
92 | )
93 | parser.add_argument(
94 | "--pad_to_max_length",
95 | action="store_true",
96 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
97 | )
98 | parser.add_argument(
99 | "--model_name_or_path",
100 | type=str,
101 | help="Path to pretrained model or model identifier from huggingface.co/models.",
102 | required=True,
103 | )
104 | parser.add_argument(
105 | "--use_slow_tokenizer",
106 | action="store_true",
107 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
108 | )
109 | parser.add_argument(
110 | "--per_device_train_batch_size",
111 | type=int,
112 | default=8,
113 | help="Batch size (per device) for the training dataloader.",
114 | )
115 | parser.add_argument(
116 | "--per_device_eval_batch_size",
117 | type=int,
118 | default=8,
119 | help="Batch size (per device) for the evaluation dataloader.",
120 | )
121 | parser.add_argument(
122 | "--lr",
123 | type=float,
124 | default=5e-5,
125 | help="Initial learning rate (after the potential warmup period) to use.",
126 | )
127 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
128 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
129 | parser.add_argument(
130 | "--max_train_steps",
131 | type=int,
132 | default=None,
133 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
134 | )
135 | parser.add_argument(
136 | "--gradient_accumulation_steps",
137 | type=int,
138 | default=1,
139 | help="Number of updates steps to accumulate before performing a backward/update pass.",
140 | )
141 | parser.add_argument(
142 | "--lr_scheduler_type",
143 | type=SchedulerType,
144 | default="linear",
145 | help="The scheduler type to use.",
146 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
147 | )
148 | parser.add_argument(
149 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
150 | )
151 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
152 | parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
153 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
154 | parser.add_argument(
155 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
156 | )
157 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
158 | parser.add_argument(
159 | "--trust_remote_code",
160 | type=bool,
161 | default=False,
162 | help=(
163 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
164 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
165 | "execute code present on the Hub on your local machine."
166 | ),
167 | )
168 | parser.add_argument(
169 | "--checkpointing_steps",
170 | type=str,
171 | default=None,
172 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
173 | )
174 | parser.add_argument(
175 | "--resume_from_checkpoint",
176 | type=str,
177 | default=None,
178 | help="If the training should continue from a checkpoint folder.",
179 | )
180 | parser.add_argument(
181 | "--with_tracking",
182 | action="store_true",
183 | help="Whether to enable experiment trackers for logging.",
184 | )
185 | parser.add_argument(
186 | "--report_to",
187 | type=str,
188 | default="all",
189 | help=(
190 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
191 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
192 | "Only applicable when `--with_tracking` is passed."
193 | ),
194 | )
195 |
196 | parser.add_argument(
197 | "--clip_norm",
198 | type=int,
199 | default=10,
200 | help="Clip norm for SGD optimizer (default: %(default)s).",
201 | )
202 |
203 | parser.add_argument(
204 | "--ignore_mismatched_sizes",
205 | action="store_true",
206 | help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
207 | )
208 |
209 | parser.add_argument('--optimizer', type=str, default='sassha', help='choose optim')
210 |
211 | # Second-order optimization settings
212 | parser.add_argument("--n_samples", default=1, type=int, help="the number of sampling")
213 | parser.add_argument('--betas', type=float, nargs='*', default=[0.9, 0.999], help='betas')
214 | parser.add_argument("--eps", default=1e-4, type=float, help="add a small number for stability")
215 | parser.add_argument("--lazy_hessian", default=10, type=int, help="Delayed hessian update")
216 | parser.add_argument("--clip_threshold", default=0.01, type=float, help="sophia clipping")
217 |
218 | # Hessian power scheduler
219 | parser.add_argument('--hessian_power_scheduler', type=str, default='constant', help="choose Hessian power 1. 'constant', 2. 'proportion', 3. 'linear', 4. 'cosine'")
220 | parser.add_argument('--max_hessian_power', type=float, default=1)
221 | parser.add_argument('--min_hessian_power', type=float, default=0.5)
222 | parser.add_argument('--min_lr', type=float, default=0.0, help="the minimum value of learning rate")
223 |
224 | # Sharpness minimization settings
225 | parser.add_argument("--rho", default=0.05, type=float, help="Rho parameter for sharpness minimization")
226 | parser.add_argument("--adaptive", default=False, type=bool, help="True if you want to use the Adaptive sharpness")
227 | parser.add_argument('--project_name', type=str, default='project_name', help="project_name")
228 |
229 | args = parser.parse_args()
230 |
231 | # Sanity checks
232 | if args.task_name is None and args.train_file is None and args.validation_file is None:
233 | raise ValueError("Need either a task name or a training/validation file.")
234 | else:
235 | if args.train_file is not None:
236 | extension = args.train_file.split(".")[-1]
237 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
238 | if args.validation_file is not None:
239 | extension = args.validation_file.split(".")[-1]
240 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
241 |
242 | if args.push_to_hub:
243 | assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
244 |
245 | return args
246 |
247 |
248 | def main():
249 | args = parse_args()
250 |
251 | # Control randomness
252 | random.seed(args.seed)
253 | torch.manual_seed(args.seed)
254 | torch.cuda.manual_seed_all(args.seed)
255 | set_seed(args.seed)
256 | torch.backends.cudnn.benchmark = False
257 | torch.backends.cudnn.deterministic = True
258 |
259 | # Load Dataset
260 | if args.task_name is not None:
261 | # Downloading and loading a dataset from the hub.
262 | raw_datasets = load_dataset("nyu-mll/glue", args.task_name)
263 | else:
264 | # Loading the dataset from local csv or json file.
265 | data_files = {}
266 | if args.train_file is not None:
267 | data_files["train"] = args.train_file
268 | if args.validation_file is not None:
269 | data_files["validation"] = args.validation_file
270 | extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
271 | raw_datasets = load_dataset(extension, data_files=data_files)
272 | # See more about loading any type of standard or custom dataset at
273 | # https://huggingface.co/docs/datasets/loading_datasets.
274 |
275 | # Labels
276 | if args.task_name is not None:
277 | is_regression = args.task_name == "stsb"
278 | if not is_regression:
279 | label_list = raw_datasets["train"].features["label"].names
280 | num_labels = len(label_list)
281 | else:
282 | num_labels = 1
283 | else:
284 | # Trying to have good defaults here, don't hesitate to tweak to your needs.
285 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
286 | if is_regression:
287 | num_labels = 1
288 | else:
289 | # A useful fast method:
290 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
291 | label_list = raw_datasets["train"].unique("label")
292 | label_list.sort() # Let's sort it for determinism
293 | num_labels = len(label_list)
294 |
295 | # Load pretrained model and tokenizer
296 |
297 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
298 | # download model & vocab.
299 |
300 | # Get config
301 | config = AutoConfig.from_pretrained(
302 | args.model_name_or_path,
303 | num_labels=num_labels,
304 | finetuning_task=args.task_name,
305 | trust_remote_code=args.trust_remote_code,
306 | )
307 |
308 | # Get tokenizer
309 | tokenizer = AutoTokenizer.from_pretrained(
310 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code
311 | )
312 | if tokenizer.pad_token is None:
313 | tokenizer.pad_token = tokenizer.eos_token
314 | config.pad_token_id = tokenizer.pad_token_id
315 |
316 | # Get model
317 | model = AutoModelForSequenceClassification.from_pretrained(
318 | args.model_name_or_path,
319 | from_tf=bool(".ckpt" in args.model_name_or_path),
320 | config=config,
321 | ignore_mismatched_sizes=args.ignore_mismatched_sizes,
322 | trust_remote_code=args.trust_remote_code,
323 | )
324 |
325 | # Move the model to GPU if available
326 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
327 | model.to(device)
328 |
329 | # Preprocessing the datasets
330 | if args.task_name is not None:
331 | sentence1_key, sentence2_key = task_to_keys[args.task_name]
332 | else:
333 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
334 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
335 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
336 | sentence1_key, sentence2_key = "sentence1", "sentence2"
337 | else:
338 | if len(non_label_column_names) >= 2:
339 | sentence1_key, sentence2_key = non_label_column_names[:2]
340 | else:
341 | sentence1_key, sentence2_key = non_label_column_names[0], None
342 |
343 | # Some models have set the order of the labels to use, so let's make sure we do use it.
344 | label_to_id = None
345 | if (
346 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
347 | and args.task_name is not None
348 | and not is_regression
349 | ):
350 | # Some have all caps in their config, some don't.
351 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
352 | if sorted(label_name_to_id.keys()) == sorted(label_list):
353 | label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
354 |
355 | else:
356 | pass
357 |
358 | elif args.task_name is None and not is_regression:
359 | label_to_id = {v: i for i, v in enumerate(label_list)}
360 |
361 | if label_to_id is not None:
362 | model.config.label2id = label_to_id
363 | model.config.id2label = {id: label for label, id in config.label2id.items()}
364 | elif args.task_name is not None and not is_regression:
365 | model.config.label2id = {l: i for i, l in enumerate(label_list)}
366 | model.config.id2label = {id: label for label, id in config.label2id.items()}
367 |
368 | # Set target padding
369 | padding = "max_length" if args.pad_to_max_length else False
370 |
371 | def preprocess_function(examples):
372 | # Tokenize the texts
373 | texts = (
374 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
375 | )
376 | result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True)
377 |
378 | if "label" in examples:
379 | if label_to_id is not None:
380 | # Map labels to IDs (not necessary for GLUE tasks)
381 | result["labels"] = [label_to_id[l] for l in examples["label"]]
382 | else:
383 | # In all cases, rename the column to labels because the model will expect that.
384 | result["labels"] = examples["label"]
385 | return result
386 |
387 | processed_datasets = raw_datasets.map(
388 | preprocess_function,
389 | batched=True,
390 | remove_columns=raw_datasets["train"].column_names,
391 | desc="Running tokenizer on dataset",
392 | )
393 |
394 | train_dataset = processed_datasets["train"]
395 | eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
396 |
397 | # DataLoaders creation:
398 | if args.pad_to_max_length:
399 |         # If padding was already done to max length, we use the default data collator that will just convert everything
400 | # to tensors.
401 | data_collator = default_data_collator
402 | else:
403 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
404 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
405 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
406 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=None)
407 |
408 | train_dataloader = DataLoader(
409 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
410 | )
411 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
412 |
413 | wandb_name = f"{args.task_name}-{args.optimizer}-{args.lr}-{args.weight_decay}-{args.eps}"
414 |
415 | # get an optimizer
416 | optimizer, create_graph, two_steps = get_optimizer(model, args)
417 |
418 | # Scheduler and math around the number of training steps.
419 | overrode_max_train_steps = False
420 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
421 | if args.max_train_steps is None:
422 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
423 | overrode_max_train_steps = True
424 |
425 | lr_scheduler = get_scheduler(
426 | name=args.lr_scheduler_type,
427 | optimizer=optimizer,
428 | num_warmup_steps=args.num_warmup_steps,
429 | num_training_steps=args.max_train_steps,
430 | )
431 |
432 | # select a hessian power scheduler
433 | if args.optimizer == 'sassha':
434 | if args.hessian_power_scheduler == 'constant':
435 | hessian_power_scheduler = ConstantScheduler(
436 | T_max=args.num_train_epochs*len(train_dataloader),
437 | max_value=0.5,
438 | min_value=0.5)
439 |
440 | elif args.hessian_power_scheduler == 'proportion':
441 | hessian_power_scheduler = ProportionScheduler(
442 |                 pytorch_lr_scheduler=lr_scheduler,
443 | max_lr=args.lr,
444 | min_lr=args.min_lr,
445 | max_value=args.max_hessian_power,
446 | min_value=args.min_hessian_power)
447 |
448 | elif args.hessian_power_scheduler == 'linear':
449 | hessian_power_scheduler = LinearScheduler(
450 | T_max=args.num_train_epochs*len(train_dataloader),
451 | max_value=args.max_hessian_power,
452 | min_value=args.min_hessian_power)
453 |
454 | elif args.hessian_power_scheduler == 'cosine':
455 | hessian_power_scheduler = CosineScheduler(
456 | T_max=args.num_train_epochs*len(train_dataloader),
457 | max_value=args.max_hessian_power,
458 | min_value=args.min_hessian_power)
459 |
460 | optimizer.hessian_power_scheduler = hessian_power_scheduler
461 |
462 | # We need to recalculate our total training steps as the size of the training dataloader may have changed
463 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
464 | if overrode_max_train_steps:
465 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
466 | # Afterwards we recalculate our number of training epochs
467 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
468 |
469 | # Get the metric function
470 | if args.task_name is not None:
471 | metric = evaluate.load("glue", args.task_name, experiment_id=wandb_name)
472 | else:
473 | metric = evaluate.load("accuracy", experiment_id=wandb_name)
474 |
475 | def compute_metric(eval_pred):
476 | predictions, labels = eval_pred
477 | if args.task_name != "stsb":
478 | predictions = torch.argmax(predictions, axis=1)
479 | else:
480 | predictions = predictions[:, 0]
481 | return metric.compute(predictions=predictions, references=labels)
482 |
483 | # Train!
484 | total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
485 |
486 | # Only show the progress bar once on each machine.
487 | progress_bar = tqdm(range(args.max_train_steps))
488 | completed_steps = 0
489 | starting_epoch = 0
490 |
491 | # Potentially load in the weights and states from a previous save
492 | if args.resume_from_checkpoint:
493 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
494 | checkpoint_path = args.resume_from_checkpoint
495 | path = os.path.basename(args.resume_from_checkpoint)
496 | else:
497 | # Get the most recent checkpoint
498 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
499 | dirs.sort(key=os.path.getctime)
500 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
501 | checkpoint_path = path
502 | path = os.path.basename(checkpoint_path)
503 |
504 | # Extract `epoch_{i}` or `step_{i}`
505 | training_difference = os.path.splitext(path)[0]
506 |
507 | if "epoch" in training_difference:
508 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
509 | resume_step = None
510 | completed_steps = starting_epoch * num_update_steps_per_epoch
511 | else:
512 | # need to multiply `gradient_accumulation_steps` to reflect real steps
513 | resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
514 | starting_epoch = resume_step // len(train_dataloader)
515 | completed_steps = resume_step // args.gradient_accumulation_steps
516 | resume_step -= starting_epoch * len(train_dataloader)
517 |
518 | # wandb
519 | wandb_project = args.project_name
520 | wandb.init(project=wandb_project, name=wandb_name)
521 | wandb.config.update(args)
522 |
523 | # update the progress_bar if load from checkpoint
524 | progress_bar.update(completed_steps)
525 |
526 | for epoch in range(starting_epoch, args.num_train_epochs):
527 | model.train()
528 | total_loss = 0
529 |
530 |         if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
531 |             # We skip the first `resume_step` batches of the dataloader when resuming from a checkpoint
532 |             active_dataloader = (b for i, b in enumerate(train_dataloader) if i >= resume_step)
533 | else:
534 | active_dataloader = train_dataloader
535 |
536 | for step, batch in enumerate(active_dataloader):
537 | batch = {k: v.to(device) for k, v in batch.items()}
538 |
539 | if two_steps:
540 | outputs = model(**batch)
541 | loss = outputs.loss
542 | # We keep track of the loss at each epoch
543 | total_loss += loss.item()
544 | loss = loss / args.gradient_accumulation_steps
545 | loss.backward()
546 |
547 | if args.optimizer == 'sassha':
548 | optimizer.perturb_weights(zero_grad=True)
549 |
550 | elif args.optimizer in ['samsgd', 'samadamw']:
551 | optimizer.first_step(zero_grad=True)
552 |
553 | model(**batch).loss.backward(create_graph=create_graph)
554 |
555 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
556 | if args.optimizer == 'sassha':
557 | optimizer.unperturb()
558 | optimizer.step()
559 | optimizer.zero_grad()
560 |
561 | elif args.optimizer in ['samsgd', 'samadamw']:
562 |                         if args.clip_norm != 0:
563 |                             torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
564 | optimizer.second_step(zero_grad=True)
565 |
566 | lr_scheduler.step()
567 | progress_bar.update(1)
568 | completed_steps += 1
569 |
570 | else:
571 | outputs = model(**batch)
572 | loss = outputs.loss
573 | # We keep track of the loss at each epoch
574 | total_loss += loss.item()
575 | loss = loss / args.gradient_accumulation_steps
576 | loss.backward(create_graph=create_graph)
577 |
578 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
579 |
580 |                     if args.clip_norm != 0:
581 |                         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
582 |
583 | optimizer.step()
584 | optimizer.zero_grad()
585 | lr_scheduler.step()
586 |
587 | progress_bar.update(1)
588 | completed_steps += 1
589 |
590 | if completed_steps >= args.max_train_steps:
591 | break
592 |
593 | model.eval()
594 | all_labels, all_logits = [], []
595 | val_loss = 0
596 | for step, batch in enumerate(eval_dataloader):
597 | batch = {k: v.to(device) for k, v in batch.items()}
598 | with torch.no_grad():
599 | outputs = model(**batch)
600 | val_loss += outputs.loss.item()
601 | all_labels.append(batch['labels'])
602 | all_logits.append(outputs.logits)
603 |
604 | all_labels = torch.cat(all_labels, dim=0)
605 | all_logits = torch.cat(all_logits, dim=0)
606 | eval_metric = compute_metric((all_logits, all_labels))
607 |
608 | wandb.log({
609 | "accuracy" if args.task_name is not None else "glue": eval_metric,
610 | "train_loss": total_loss / len(train_dataloader),
611 | "val_loss": val_loss / len(eval_dataloader),
612 | "epoch": epoch,
613 | "step": completed_steps,
614 | "lr": optimizer.param_groups[0]['lr'],
615 | "hessian_power": optimizer.hessian_power_t if args.optimizer == 'sassha' else 0,
616 | })
617 |
618 | if args.task_name == "mnli":
619 | # Final evaluation on mismatched validation set
620 | eval_dataset = processed_datasets["validation_mismatched"]
621 | eval_dataloader = DataLoader(
622 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
623 | )
624 |
625 | model.eval()
626 | all_labels, all_logits = [], []
627 | for step, batch in enumerate(eval_dataloader):
628 | batch = {k: v.to(device) for k, v in batch.items()}
629 | with torch.no_grad():
630 | outputs = model(**batch)
631 | all_labels.append(batch['labels'])
632 | all_logits.append(outputs.logits)
633 |
634 | all_labels = torch.cat(all_labels, dim=0)
635 | all_logits = torch.cat(all_logits, dim=0)
636 | eval_metric = compute_metric((all_logits, all_labels))
637 |
638 | wandb.log({
639 | "accuracy-mm" if args.task_name is not None else "glue": eval_metric,
640 | })
641 |
642 | if args.output_dir is not None:
643 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
644 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
645 | json.dump(all_results, f)
646 |
647 |
648 | if __name__ == "__main__":
649 | main()
650 |
--------------------------------------------------------------------------------
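A note on the `two_steps` branch above: sharpness-aware optimizers need two forward/backward passes per update, one at the current weights to compute the perturbation (SASSHA's `perturb_weights` or SAM's `first_step`) and one at the perturbed weights to drive the actual update. The sketch below strips that flow down to a toy problem using the repository's SAM wrapper; the model, data, and hyperparameters are placeholders, and it assumes the script is run from the repository root so `optimizers` is importable.

```python
# Stripped-down version of the two-pass update used in finetune.py, on a toy regression task.
import torch
import torch.nn as nn
from optimizers.sam import SAM

model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)
criterion = nn.MSELoss()
optimizer = SAM(model.parameters(), torch.optim.AdamW, rho=1e-2, lr=1e-4)

for _ in range(3):
    # pass 1: gradients at the current weights, then climb to "w + e(w)"
    criterion(model(x), y).backward()
    optimizer.first_step(zero_grad=True)

    # pass 2: gradients at the perturbed weights, then the actual parameter update at "w"
    criterion(model(x), y).backward()
    optimizer.second_step(zero_grad=True)
```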
/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from optimizers.sam import SAM
2 | from optimizers.msassha import MSASSHA
3 | from optimizers.adahessian import Adahessian
4 | from optimizers.sophiaH import SophiaH
5 | from optimizers.shampoo import Shampoo
6 | from optimizers.sassha import SASSHA
7 | import torch.optim as optim
8 |
9 | def configure_sassha(model, args):
10 | create_graph = True
11 | two_steps = True
12 |
13 | optimizer = SASSHA(
14 | model.parameters(),
15 | betas=tuple(args.betas),
16 | lr=args.lr,
17 | weight_decay=args.weight_decay/args.lr,
18 | rho=args.rho,
19 | lazy_hessian=args.lazy_hessian,
20 | eps=args.eps,
21 | seed=args.seed)
22 |
23 | return optimizer, create_graph, two_steps
24 |
25 |
26 | def configure_samsgd(model, args):
27 | create_graph = False
28 | two_steps = True
29 |
30 | optimizer = SAM(
31 | model.parameters(), optim.SGD, rho=args.rho, adaptive=args.adaptive,
32 | momentum=0.9,
33 | lr=args.lr,
34 | weight_decay=args.weight_decay)
35 |
36 | return optimizer, create_graph, two_steps
37 |
38 |
39 | def configure_samadamw(model, args):
40 | create_graph = False
41 | two_steps = True
42 |
43 | optimizer = SAM(
44 | model.parameters(), optim.AdamW, rho=args.rho, adaptive=args.adaptive,
45 | betas=tuple(args.betas),
46 | lr=args.lr,
47 | weight_decay=args.weight_decay/args.lr)
48 |
49 | return optimizer, create_graph, two_steps
50 |
51 |
52 | def configure_sgd(model, args):
53 | create_graph = False
54 | two_steps = False
55 |
56 | optimizer = optim.SGD(
57 | model.parameters(),
58 | lr=args.lr,
59 | momentum=0.9,
60 | weight_decay=args.weight_decay)
61 |
62 | return optimizer, create_graph, two_steps
63 |
64 |
65 | def configure_adamw(model, args):
66 | create_graph = False
67 | two_steps = False
68 |
69 | optimizer = optim.AdamW(
70 | model.parameters(),
71 | betas=tuple(args.betas),
72 | lr=args.lr,
73 | weight_decay=args.weight_decay/args.lr)
74 |
75 | return optimizer, create_graph, two_steps
76 |
77 |
78 | def configure_adahessian(model, args):
79 | create_graph = True
80 | two_steps = False
81 |
82 | optimizer = Adahessian(
83 | model.parameters(),
84 | betas=tuple(args.betas),
85 | lr=args.lr,
86 | weight_decay=args.weight_decay/args.lr,
87 | lazy_hessian=args.lazy_hessian,
88 | eps=args.eps,
89 | seed=args.seed)
90 |
91 | return optimizer, create_graph, two_steps
92 |
93 |
94 | def configure_sophiah(model, args):
95 | create_graph = True
96 | two_steps = False
97 |
98 | optimizer = SophiaH(
99 | model.parameters(),
100 | betas=tuple(args.betas),
101 | lr=args.lr,
102 | weight_decay=args.weight_decay / args.lr,
103 | clip_threshold=args.clip_threshold,
104 | eps=args.eps,
105 | lazy_hessian=args.lazy_hessian,
106 | seed=args.seed)
107 |
108 | return optimizer, create_graph, two_steps
109 |
110 |
111 | def configure_msassha(model, args):
112 | create_graph = True
113 | two_steps = False
114 |
115 | optimizer = MSASSHA(
116 | model.parameters(),
117 | lr=args.lr,
118 | rho=args.rho,
119 | weight_decay=args.weight_decay / args.lr,
120 | lazy_hessian=args.lazy_hessian,
121 | eps=args.eps,
122 | seed=args.seed)
123 |
124 | return optimizer, create_graph, two_steps
125 |
126 |
127 | def configure_shampoo(model, args):
128 | create_graph = False
129 | two_steps = False
130 |
131 | optimizer = Shampoo(
132 | params=model.parameters(),
133 | lr=args.lr,
134 | momentum=0.9,
135 | weight_decay=args.weight_decay,
136 | epsilon=args.eps,
137 | update_freq=1)
138 |
139 | return optimizer, create_graph, two_steps
140 |
141 |
142 | def get_optimizer(model, args):
143 | optimizer_map = {
144 | 'sassha': configure_sassha,
145 | 'samsgd': configure_samsgd,
146 | 'samadamw': configure_samadamw,
147 | 'adahessian': configure_adahessian,
148 | 'adamw': configure_adamw,
149 | 'sgd': configure_sgd,
150 | 'sophiah': configure_sophiah,
151 | 'msassha': configure_msassha,
152 | 'shampoo': configure_shampoo,
153 | }
154 |
155 | if args.optimizer not in optimizer_map:
156 | raise ValueError(f"Unsupported optimizer: {args.optimizer}")
157 |
158 | return optimizer_map[args.optimizer](model, args)
159 |
160 |
--------------------------------------------------------------------------------
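Each `configure_*` helper above returns a `(optimizer, create_graph, two_steps)` triple: `create_graph` tells the training loop to retain the autograd graph for Hutchinson Hessian-vector products, and `two_steps` requests the two-pass sharpness-aware update. A minimal usage sketch with a hypothetical argument namespace (the field names mirror the argparse options of the training scripts, and it assumes execution from the repository root):

```python
# Minimal sketch of the optimizer factory.
import argparse
import torch.nn as nn
from optimizers import get_optimizer

args = argparse.Namespace(
    optimizer="sassha", lr=1e-2, weight_decay=1e-4, betas=[0.9, 0.999],
    rho=1e-2, eps=1e-6, lazy_hessian=1, seed=0,
)

model = nn.Linear(10, 2)
optimizer, create_graph, two_steps = get_optimizer(model, args)
print(type(optimizer).__name__, create_graph, two_steps)  # expected: SASSHA True True
```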
/optimizers/adahessian.py:
--------------------------------------------------------------------------------
1 | # Acknowledgement: This code is based on https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
2 |
3 | import torch
4 | import math
5 |
6 | class Adahessian(torch.optim.Optimizer):
7 | """
8 |     Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning".
9 |     In earlier experiments, results varied depending on how 'denom' was defined, so this implementation
10 |     follows the denominator convention of the official PyTorch Adam code.
11 |
12 | Arguments:
13 | params (iterable) -- iterable of parameters to optimize or dicts defining parameter groups
14 | lr (float, optional) -- learning rate (default: 0.15)
15 | betas ((float, float), optional) -- coefficients used for computing running averages of gradient and the squared hessian trace (default: (0.9, 0.999))
16 | eps (float, optional) -- term added to the denominator to improve numerical stability (default: 1e-4)
17 | weight_decay (float, optional) -- weight decay (L2 penalty) (default: 0.0)
18 | hessian_power (float, optional) -- exponent of the hessian trace (default: 1.0)
19 | lazy_hessian (int, optional) -- compute the hessian trace approximation only after *this* number of steps (to save time) (default: 1)
20 | n_samples (int, optional) -- how many times to sample `z` for the approximation of the hessian trace (default: 1)
21 | """
22 |
23 | def __init__(self,
24 | params,
25 | lr=0.15,
26 | betas=(0.9, 0.999),
27 | eps=1e-4,
28 | weight_decay=0.0,
29 | hessian_power=1,
30 | lazy_hessian=1,
31 | n_samples=1,
32 | seed=0):
33 |
34 | if not 0.0 <= lr:
35 | raise ValueError(f"Invalid learning rate: {lr}")
36 | if not 0.0 <= eps:
37 | raise ValueError(f"Invalid epsilon value: {eps}")
38 | if not 0.0 <= betas[0] < 1.0:
39 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
40 | if not 0.0 <= betas[1] < 1.0:
41 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
42 | if not 0.0 <= hessian_power <= 1.0:
43 | raise ValueError(f"Invalid Hessian power value: {hessian_power}")
44 |
45 | self.n_samples = n_samples
46 | self.lazy_hessian = lazy_hessian
47 | self.seed = seed
48 |
49 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
50 | self.generator = torch.Generator().manual_seed(self.seed)
51 |
52 | defaults = dict(lr=lr,
53 | betas=betas,
54 | eps=eps,
55 | weight_decay=weight_decay,
56 | hessian_power=hessian_power)
57 |
58 | super(Adahessian, self).__init__(params, defaults)
59 |
60 | for p in self.get_params():
61 | p.hess = 0.0
62 | self.state[p]["hessian step"] = 0
63 |
64 | def get_params(self):
65 | """
66 | Gets all parameters in all param_groups with gradients
67 | """
68 |
69 | return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
70 |
71 | def zero_hessian(self):
72 | """
73 |         Zeros out the accumulated hessian traces.
74 | """
75 |
76 | for p in self.get_params():
77 | if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.lazy_hessian == 0:
78 | p.hess.zero_()
79 |
80 | @torch.no_grad()
81 | def set_hessian(self):
82 | """
83 | Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
84 | """
85 | params = []
86 | for p in filter(lambda p: p.grad is not None, self.get_params()):
87 | if self.state[p]["hessian step"] % self.lazy_hessian == 0: # compute a new Hessian per `lazy_hessian` step
88 | params.append(p)
89 | self.state[p]["hessian step"] += 1
90 |
91 | if len(params) == 0:
92 | return
93 |
94 | if self.generator.device != params[0].device and self.seed is not None: # hackish way of casting the generator to the right device
95 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
96 |
97 | grads = [p.grad for p in params]
98 |
99 | last_sample = self.n_samples - 1
100 | for i in range(self.n_samples):
101 | # Rademacher distribution {-1.0, 1.0}
102 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
103 | h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < last_sample)
104 | for h_z, z, p in zip(h_zs, zs, params):
105 | p.hess += h_z * z / self.n_samples
106 |
107 | @torch.no_grad()
108 | def step(self, closure=None):
109 | """
110 | Performs a single optimization step.
111 | Arguments:
112 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
113 | """
114 |
115 | loss = None
116 | if closure is not None:
117 | loss = closure()
118 |
119 | self.zero_hessian()
120 | self.set_hessian()
121 |
122 | for group in self.param_groups:
123 | for p in group['params']:
124 | if p.grad is None or p.hess is None:
125 | continue
126 |
127 | if p.dim() <= 2:
128 | p.hess = p.hess.abs().clone()
129 |
130 | if p.dim() == 4:
131 | p.hess = torch.mean(p.hess.abs(), dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
132 |
133 | # Perform correct stepweight decay as in AdamW
134 | p.mul_(1 - group['lr'] * group['weight_decay'])
135 |
136 | state = self.state[p]
137 |
138 | # State initialization
139 | if len(state) == 1:
140 | state['step'] = 0
141 | state['exp_avg'] = torch.zeros_like(p.data) # Exponential moving average of gradient values
142 | state['exp_hessian_diag_sq'] = torch.zeros_like(p.data) # Exponential moving average of Hessian diagonal square values
143 |
144 | exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
145 | beta1, beta2 = group['betas']
146 | state['step'] += 1
147 |
148 | # Decay the first and second moment running average coefficient
149 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
150 | exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
151 |
152 | bias_correction1 = 1 - beta1 ** state['step']
153 | bias_correction2 = 1 - beta2 ** state['step']
154 |
155 | step_size = group['lr'] / bias_correction1
156 | step_size_neg = -step_size
157 |
158 | k = group['hessian_power']
159 | denom = (exp_hessian_diag_sq / bias_correction2).pow_(k/2).add_(group['eps'])
160 |
161 | # make update
162 | p.addcdiv_(exp_avg, denom, value=step_size_neg)
163 |
164 | return loss
165 |
--------------------------------------------------------------------------------
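`set_hessian` above implements Hutchinson's estimator: for a Rademacher vector z, the expectation of z ⊙ (Hz) equals the diagonal of the Hessian H, and the Hessian-vector product Hz comes from a second backward pass through the gradient. The self-contained check below illustrates the identity on a small quadratic with a known Hessian; it is an illustration of the idea, not part of the optimizer.

```python
# Hutchinson's diagonal estimator on f(w) = 0.5 * w^T A w, whose Hessian is exactly A.
import torch

torch.manual_seed(0)
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
w = torch.randn(2, requires_grad=True)

loss = 0.5 * w @ A @ w
(grad,) = torch.autograd.grad(loss, w, create_graph=True)

n_samples = 2000
estimate = torch.zeros(2)
for _ in range(n_samples):
    z = torch.randint(0, 2, (2,)).float() * 2.0 - 1.0  # Rademacher {-1, +1}, as in set_hessian
    (h_z,) = torch.autograd.grad(grad, w, grad_outputs=z, retain_graph=True)
    estimate += h_z * z / n_samples  # z * (Hz) averages to diag(H)

print(estimate)  # close to tensor([3., 2.]) == diag(A)
```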
/optimizers/hessian_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | class ProportionScheduler:
5 | def __init__(self, pytorch_lr_scheduler, max_lr, min_lr, max_value, min_value):
6 | """
7 |         This scheduler outputs a value that evolves inversely with pytorch_lr_scheduler, i.e.
8 |         (max_value - value) / (max_value - min_value) = (lr - min_lr) / (max_lr - min_lr)
9 | """
10 | self.t = 0
11 | self.pytorch_lr_scheduler = pytorch_lr_scheduler
12 | self.max_lr = max_lr
13 | self.min_lr = min_lr
14 | self.max_value = max_value
15 | self.min_value = min_value
16 |
17 | assert (max_lr > min_lr) or ((max_lr==min_lr) and (max_value==min_value)), "Current scheduler for `value` is scheduled to evolve proportionally to `lr`," \
18 | "e.g. `(lr - min_lr) / (max_lr - min_lr) = (value - min_value) / (max_value - min_value)`. Please check `max_lr >= min_lr` and `max_value >= min_value`;" \
19 | "if `max_lr==min_lr` hence `lr` is constant with step, please set 'max_value == min_value' so 'value' is constant with step."
20 |
21 | assert max_value >= min_value
22 |
23 | self.step() # take 1 step during initialization to get self._last_lr
24 |
25 | def lr(self):
26 | return self._last_lr[0]
27 |
28 | def step(self):
29 | self.t += 1
30 | if hasattr(self.pytorch_lr_scheduler, "_last_lr"):
31 | lr = self.pytorch_lr_scheduler._last_lr[0]
32 | else:
33 | lr = self.pytorch_lr_scheduler.optimizer.param_groups[0]['lr']
34 |
35 | if self.max_lr > self.min_lr:
36 | value = self.max_value - (self.max_value - self.min_value) * (lr - self.min_lr) / (self.max_lr - self.min_lr)
37 | else:
38 | value = self.max_value
39 |
40 | self._last_lr = [value]
41 | return value
42 |
43 | class SchedulerBase:
44 | def __init__(self, T_max, max_value, min_value=0.0, init_value=0.0, warmup_steps=0, optimizer=None):
45 | super(SchedulerBase, self).__init__()
46 | self.t = 0
47 | self.min_value = min_value
48 | self.max_value = max_value
49 | self.init_value = init_value
50 | self.warmup_steps = warmup_steps
51 | self.total_steps = T_max
52 |
53 | # record current value in self._last_lr to match API from torch.optim.lr_scheduler
54 | self._last_lr = [init_value]
55 |
56 | # If optimizer is not None, will set learning rate to all trainable parameters in optimizer.
57 | # If optimizer is None, only output the value of lr.
58 | self.optimizer = optimizer
59 |
60 | def step(self):
61 | if self.t < self.warmup_steps:
62 | value = self.init_value + (self.max_value - self.init_value) * self.t / self.warmup_steps
63 | elif self.t == self.warmup_steps:
64 | value = self.min_value
65 | else:
66 | value = self.step_func()
67 | self.t += 1
68 |
69 | # apply the lr to optimizer if it's provided
70 | if self.optimizer is not None:
71 | for param_group in self.optimizer.param_groups:
72 | param_group['lr'] = value
73 |
74 | self._last_lr = [value]
75 | return value
76 |
77 | def step_func(self):
78 | pass
79 |
80 | def lr(self):
81 | return self._last_lr[0]
82 |
83 | class LinearScheduler(SchedulerBase):
84 | def step_func(self):
85 | value = self.min_value + (self.max_value - self.min_value) * (self.t - self.warmup_steps) / (
86 | self.total_steps - self.warmup_steps)
87 | return value
88 |
89 | class CosineScheduler(SchedulerBase):
90 | def step_func(self):
91 | phase = (self.t-self.warmup_steps) / (self.total_steps-self.warmup_steps) * math.pi
92 | value = self.max_value - (self.max_value-self.min_value) * (np.cos(phase) + 1.) / 2.0
93 | return value
94 |
95 | class ConstantScheduler(SchedulerBase):
96 | def step_func(self):
97 | value = self.min_value
98 | return value
99 |
100 |
--------------------------------------------------------------------------------
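These schedulers mirror the `torch.optim.lr_scheduler` interface (a `step()` method and a `_last_lr` attribute), and SASSHA reads the scheduled value as its Hessian power at every step. A short standalone sketch of a `CosineScheduler`, assuming it is run from the repository root so `optimizers` is importable:

```python
# Minimal sketch: with this implementation the scheduled value ramps from min_value to max_value.
from optimizers.hessian_scheduler import CosineScheduler

scheduler = CosineScheduler(T_max=100, max_value=1.0, min_value=0.5)
values = [scheduler.step() for _ in range(100)]
print(round(values[0], 3), round(values[-1], 3))  # starts at 0.5 and approaches 1.0
```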
/optimizers/msassha.py:
--------------------------------------------------------------------------------
1 | # Acknowledgement: This code is based on https://github.com/MarlonBecker/MSAM/blob/main/optimizer/adamW_msam.py
2 |
3 | from typing import List
4 | from torch import Tensor
5 | import torch
6 | import math
7 |
8 | # cf. https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py
9 | class MSASSHA(torch.optim.Optimizer):
10 | def __init__(
11 | self,
12 | params,
13 | lr: float = 0.15,
14 |         betas: tuple = (0.9, 0.999),
15 | weight_decay: float = 1e-2,
16 | lazy_hessian: int = 10,
17 | rho: float = 0.3,
18 | n_samples: int = 1,
19 | eps: float = 1e-4,
20 | hessian_power: int = 1,
21 | seed: int = 0,
22 | maximize: bool = False
23 | ):
24 |
25 | defaults = dict(
26 | lr=lr,
27 | betas=betas,
28 | weight_decay=weight_decay,
29 | rho=rho,
30 | eps=eps,
31 | maximize=maximize
32 | )
33 |
34 | self.lazy_hessian = lazy_hessian
35 | self.n_samples = n_samples
36 | self.seed = seed
37 | self.hessian_power = hessian_power
38 |
39 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
40 | self.generator = torch.Generator().manual_seed(self.seed)
41 |
42 | super(MSASSHA, self).__init__(params, defaults)
43 |
44 | # init momentum buffer to zeros
45 | # needed to make implementation of first ascent step cleaner (before SGD.step() was ever called)
46 |
47 | for p in self.get_params():
48 | p.hess = 0.0
49 |             if getattr(self, "track_hessian", False):  # optional diagnostic flag; not set by this implementation
50 |                 p.real_hess = 0.0
51 |
52 | state = self.state[p]
53 | state["hessian_step"] = 0
54 |
55 | for group in self.param_groups:
56 | group["norm_factor"] = [0,]
57 |
58 | def get_params(self):
59 | """
60 | Gets all parameters in all param_groups with gradients
61 | """
62 |
63 | return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
64 |
65 |
66 | def zero_hessian(self):
67 | """
68 | Zeros out the accumulated hessian traces.
69 | """
70 |
71 | for p in self.get_params():
72 | if not isinstance(p.hess, float) and self.state[p]["hessian_step"] % self.lazy_hessian == 0:
73 | p.hess.zero_()
74 |
75 |
76 | @torch.no_grad()
77 | def set_hessian(self):
78 | """
79 | Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
80 | """
81 |
82 | params = []
83 | for p in filter(lambda p: p.grad is not None, self.get_params()):
84 | if self.state[p]["hessian_step"] % self.lazy_hessian == 0: # compute a new Hessian per `lazy_hessian` step
85 | params.append(p)
86 | self.state[p]["hessian_step"] += 1
87 |
88 | if len(params) == 0:
89 | return
90 |
91 | if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
92 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
93 |
94 | grads = [p.grad for p in params]
95 |
96 | last_sample = self.n_samples - 1
97 | for i in range(self.n_samples):
98 | # Rademacher distribution {-1.0, 1.0}
99 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
100 | h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < last_sample)
101 | for h_z, z, p in zip(h_zs, zs, params):
102 | p.hess += h_z * z / self.n_samples
103 |
104 | @torch.no_grad()
105 | def step(self, closure=None):
106 | """ Performs a single optimization step.
107 |
108 | Args:
109 | closure (callable, optional): A closure that reevaluates the model
110 | and returns the loss.
111 | """
112 | loss = None
113 | if closure is not None:
114 | with torch.enable_grad():
115 | loss = closure()
116 | if len(self.param_groups) > 1:
117 |             raise RuntimeError("MSASSHA currently supports only a single parameter group")
118 |
119 | group = self.param_groups[0]
120 | params_with_grad = []
121 | grads = []
122 | hesses = []
123 | grad_momentums = []
124 | hess_momentums = []
125 | state_steps = []
126 |
127 | beta1, beta2 = group['betas']
128 |
129 | self.zero_hessian()
130 | self.set_hessian()
131 |
132 | for p in group['params']:
133 | if p.grad is None:
134 | continue
135 |
136 | params_with_grad.append(p)
137 | if p.grad.is_sparse:
138 | raise RuntimeError('Msassha does not support sparse gradients')
139 |
140 | grads.append(p.grad)
141 | hesses.append(p.hess)
142 |
143 | state = self.state[p]
144 | # State initialization
145 | if len(state) == 1:
146 | state['step'] = 0
147 | state['grad_momentum'] = torch.zeros_like(p, memory_format=torch.preserve_format)
148 | state['hess_momentum'] = torch.zeros_like(p, memory_format=torch.preserve_format)
149 |
150 | grad_momentums.append(state['grad_momentum'])
151 | hess_momentums.append(state['hess_momentum'])
152 |
153 | # update the steps for each param group update
154 | state['step'] += 1
155 | # record the step after step update
156 | state_steps.append(state['step'])
157 |
158 | samanda_msam(params_with_grad,
159 | grads,
160 | hesses,
161 | grad_momentums,
162 | hess_momentums,
163 | state_steps,
164 | beta1=beta1,
165 | beta2=beta2,
166 | lr=group['lr'],
167 | weight_decay=group['weight_decay'],
168 |                  lazy_hessian=self.lazy_hessian, hessian_power=self.hessian_power,
169 | rho=group['rho'],
170 | norm_factor=group['norm_factor'],
171 | eps=group['eps']
172 | )
173 |
174 | return loss
175 |
176 |
177 | @torch.no_grad()
178 | def move_up_to_momentumAscent(self):
179 | for group in self.param_groups:
180 | for p in group['params']:
181 | if "grad_momentum" in self.state[p]:
182 | p.sub_(self.state[p]["grad_momentum"], alpha=group["norm_factor"][0])
183 |
184 |
185 | @torch.no_grad()
186 | def move_back_from_momentumAscent(self):
187 | for group in self.param_groups:
188 | for p in group['params']:
189 | if "grad_momentum" in self.state[p]:
190 | p.add_(self.state[p]["grad_momentum"], alpha=group["norm_factor"][0])
191 |
192 | bias_correction2 = 0  # module-level cache so lazy-Hessian steps reuse the most recent bias correction
193 | def samanda_msam(params: List[Tensor],
194 | grads: List[Tensor],
195 | hesses: List[Tensor],
196 | grad_momentums: List[Tensor],
197 | hess_momentums: List[Tensor],
198 | state_steps: List[int],
199 | *,
200 | beta1: float,
201 | beta2: float,
202 | lr: float,
203 | weight_decay: float,
204 |                  lazy_hessian: int, hessian_power: float,
205 | rho:float,
206 | norm_factor: list,
207 | eps: float
208 | ):
209 |     r"""Functional API that performs the MSASSHA update: an AdamW-style step preconditioned by a
210 |     Hutchinson estimate of the Hessian diagonal, followed by a momentum-based ascent perturbation.
211 |     See :class:`~torch.optim.AdamW` for the underlying AdamW computation.
212 | """
213 |
214 | for i, param in enumerate(params):
215 | grad = grads[i]
216 | hess = hesses[i]
217 | grad_momentum = grad_momentums[i]
218 | hess_momentum = hess_momentums[i]
219 | step = state_steps[i]
220 |
221 | # remove last perturbation (descent) w_t <- \tilde{w_t} + rho*m_t/||m_t||
222 | param.add_(grad_momentum, alpha=norm_factor[0])
223 |
224 | # weight decay
225 | param.mul_(1 - lr * weight_decay)
226 |
227 | # Decay the first and second moment running average coefficient
228 | grad_momentum.mul_(beta1).add_(grad, alpha=1-beta1)
229 | bias_correction1 = 1 - beta1 ** step
230 |
231 | if (step-1) % lazy_hessian == 0:
232 | hess_momentum.mul_(beta2).add_(hess.abs(), alpha=1-beta2)
233 | global bias_correction2
234 | bias_correction2 = 1 - beta2 ** step
235 |
236 |         denom = ((hess_momentum ** hessian_power) / (bias_correction2 ** hessian_power)).add_(eps)
237 |
238 | step_size = lr / bias_correction1
239 |
240 | # make update
241 | param.addcdiv_(grad_momentum, denom, value=-step_size)
242 |
243 | #calculate ascent step norm
244 | ascent_norm = torch.norm(
245 | torch.stack([
246 | grad_momentum.norm(p=2)
247 | for grad_momentum in grad_momentums
248 | ]),
249 | p=2
250 | )
251 | norm_factor[0] = 1/(ascent_norm+1e-12) * rho
252 |
253 | # perturb for next iteration (ascent)
254 | for i, param in enumerate(params):
255 | param.sub_(grad_momentums[i], alpha = norm_factor[0])
256 |
257 |
--------------------------------------------------------------------------------
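One detail of the update above worth spelling out: `norm_factor` rescales the momentum so that the ascent perturbation applied at the end of each step has Euclidean norm `rho`, regardless of the momentum's magnitude. A tiny standalone check of that rescaling with arbitrary tensors:

```python
# The perturbation m * norm_factor always has global L2 norm rho (up to the 1e-12 guard).
import torch

torch.manual_seed(0)
rho = 0.3
grad_momentums = [torch.randn(5), torch.randn(3, 3)]

ascent_norm = torch.norm(torch.stack([m.norm(p=2) for m in grad_momentums]), p=2)
norm_factor = 1 / (ascent_norm + 1e-12) * rho

perturb_norm = torch.norm(torch.stack([(norm_factor * m).norm(p=2) for m in grad_momentums]), p=2)
print(perturb_norm.item())  # ~0.3 == rho
```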
/optimizers/sam.py:
--------------------------------------------------------------------------------
1 | # Acknowledgement: This code is based on https://github.com/davda54/sam/blob/main/sam.py
2 | import torch
3 |
4 |
5 | class SAM(torch.optim.Optimizer):
6 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
7 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
8 |
9 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
10 | super(SAM, self).__init__(params, defaults)
11 |
12 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
13 | self.param_groups = self.base_optimizer.param_groups
14 | self.defaults.update(self.base_optimizer.defaults)
15 |
16 | @torch.no_grad()
17 | def first_step(self, zero_grad=False):
18 | grad_norm = self._grad_norm()
19 | for group in self.param_groups:
20 | scale = group["rho"] / (grad_norm + 1e-12)
21 |
22 | for p in group["params"]:
23 | if p.grad is None: continue
24 | self.state[p]["old_p"] = p.data.clone()
25 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
26 | p.add_(e_w) # climb to the local maximum "w + e(w)"
27 |
28 | if zero_grad: self.zero_grad()
29 |
30 | @torch.no_grad()
31 | def second_step(self, zero_grad=False):
32 | for group in self.param_groups:
33 | for p in group["params"]:
34 | if p.grad is None: continue
35 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
36 |
37 | self.base_optimizer.step() # do the actual "sharpness-aware" update
38 |
39 | if zero_grad: self.zero_grad()
40 |
41 | @torch.no_grad()
42 | def step(self, closure=None):
43 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
44 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
45 |
46 | self.first_step(zero_grad=True)
47 | closure()
48 | self.second_step()
49 |
50 | def _grad_norm(self):
51 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
52 | norm = torch.norm(
53 | torch.stack([
54 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
55 | for group in self.param_groups for p in group["params"]
56 | if p.grad is not None
57 | ]),
58 | p=2
59 | )
60 | return norm
61 |
62 | def load_state_dict(self, state_dict):
63 | super().load_state_dict(state_dict)
64 | self.base_optimizer.param_groups = self.param_groups
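65 |
66 | # A minimal usage sketch, assuming `model`, `criterion`, and a data `loader` are
67 | # defined elsewhere; SAM needs two forward-backward passes per batch:
68 | #
69 | #   base_optimizer = torch.optim.SGD
70 | #   optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)
71 | #   for inputs, targets in loader:
72 | #       criterion(model(inputs), targets).backward()
73 | #       optimizer.first_step(zero_grad=True)
74 | #       criterion(model(inputs), targets).backward()
75 | #       optimizer.second_step(zero_grad=True)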
--------------------------------------------------------------------------------
/optimizers/sassha.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import contextlib
4 | from torch.distributed import ReduceOp
5 |
6 |
7 | class SASSHA(torch.optim.Optimizer):
8 | """Implements the Sharpness-Aware Second-Order optimization with Stable Hessian Approximation (SASSHA) algorithm.
9 |
10 | Args:
11 | params (iterable): Iterable of parameters to optimize or dicts defining parameter groups.
12 |         hessian_power_scheduler (optional): Scheduler that produces the Hessian power used at each step. It may be None at construction and attached later, before training starts.
13 |         lr (float, optional): Learning rate.
14 |         betas (Tuple[float, float], optional): Coefficients for computing moving averages of the gradient and Hessian.
15 |         weight_decay (float, optional): Weight decay coefficient (decoupled, as in AdamW).
16 |         rho (float, optional): Size of the neighborhood for computing the max loss.
17 |         hessian_power (float): Exponent of the Hessian in the update rule; supplied each step by hessian_power_scheduler rather than passed to the constructor.
18 |         lazy_hessian (int, optional): Number of optimization steps to perform before updating the Hessian.
19 |         n_samples (int, optional): Number of samples to draw for the Hutchinson approximation.
20 |         perturb_eps (float, optional): Small constant added to the gradient norm when scaling the sharpness-aware perturbation.
21 |         eps (float, optional): Term added to the denominator to improve numerical stability.
22 |         adaptive (bool, optional): Set to True to use an experimental implementation of element-wise Adaptive SAM. Default is False.
23 |         grad_reduce (str, optional): Reduction method for gradients ('mean' or 'sum'). Default is 'mean'.
24 |         seed (int, optional): Random seed for reproducibility. Default is 0.
25 |         **kwargs: Additional keyword arguments for compatibility with other optimizers.
26 | """
27 |
28 | def __init__(self, params,
29 | hessian_power_scheduler=None,
30 | lr=0.15,
31 | betas=(0.9, 0.999),
32 | weight_decay=0.0,
33 | rho=0.0,
34 | lazy_hessian=10,
35 | n_samples=1,
36 | perturb_eps=1e-12,
37 | eps=1e-4,
38 | adaptive=False,
39 | grad_reduce='mean',
40 | seed=0,
41 | **kwargs):
42 |
43 | if not 0.0 <= lr:
44 | raise ValueError(f"Invalid learning rate: {lr}")
45 | if not 0.0 <= eps:
46 | raise ValueError(f"Invalid epsilon value: {eps}")
47 | if not 0.0 <= betas[0] < 1.0:
48 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
49 | if not 0.0 <= betas[1] < 1.0:
50 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
51 |
52 | self.hessian_power_scheduler = hessian_power_scheduler
53 | self.lazy_hessian = lazy_hessian
54 | self.n_samples = n_samples
55 | self.adaptive = adaptive
56 | self.seed = seed
57 |
58 | defaults = dict(lr=lr,
59 | betas=betas,
60 | weight_decay=weight_decay,
61 | rho=rho,
62 | perturb_eps=perturb_eps,
63 | eps=eps)
64 |
65 | super(SASSHA, self).__init__(params, defaults)
66 |
67 | for p in self.get_params():
68 | p.hess = 0.0
69 | self.state[p]["hessian step"] = 0
70 |
71 | # set up reduction for gradient across workers
72 | if grad_reduce.lower() == 'mean':
73 | if hasattr(ReduceOp, 'AVG'):
74 | self.grad_reduce = ReduceOp.AVG
75 | self.manual_average = False
76 | else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
77 | self.grad_reduce = ReduceOp.SUM
78 | self.manual_average = True
79 | elif grad_reduce.lower() == 'sum':
80 | self.grad_reduce = ReduceOp.SUM
81 | self.manual_average = False
82 | else:
83 | raise ValueError('"grad_reduce" should be one of ["mean", "sum"].')
84 |
85 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
86 | self.generator = torch.Generator().manual_seed(self.seed)
87 |
88 |
89 | def get_params(self):
90 | """
91 | Gets all parameters in all param_groups with gradients
92 | """
93 |
94 | return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
95 |
96 | def zero_hessian(self):
97 | """
98 |         Zeros out the accumulated Hessian traces.
99 | """
100 |
101 | for p in self.get_params():
102 | if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.lazy_hessian == 0:
103 | p.hess.zero_()
104 |
105 |
106 | @torch.no_grad()
107 | def update_hessian_power(self):
108 | """
109 | Update the Hessian power at every training step.
110 | """
111 | if self.hessian_power_scheduler is not None:
112 | self.hessian_power_t = self.hessian_power_scheduler.step()
113 | else:
114 | self.hessian_power_t = None
115 | return self.hessian_power_t
116 |
117 |
118 | @torch.no_grad()
119 | def set_hessian(self):
120 | """
121 | Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
122 | """
123 | params = []
124 | for p in filter(lambda p: p.grad is not None, self.get_params()):
125 | if self.state[p]["hessian step"] % self.lazy_hessian == 0: # compute a new Hessian once per 'lazy hessian' steps
126 | params.append(p)
127 | self.state[p]["hessian step"] += 1
128 |
129 | if len(params) == 0:
130 | return
131 |
132 | if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
133 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
134 |
135 | grads = [p.grad for p in params]
136 |
137 | last_sample = self.n_samples - 1
138 | for i in range(self.n_samples):
139 | # Rademacher distribution {-1.0, 1.0}
140 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
141 | h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < last_sample)
142 | for h_z, z, p in zip(h_zs, zs, params):
143 | p.hess += h_z * z / self.n_samples
144 |
145 | @torch.no_grad()
146 | def perturb_weights(self, zero_grad=True):
147 | grad_norm = self._grad_norm(weight_adaptive=self.adaptive)
148 | for group in self.param_groups:
149 | scale = group["rho"] / (grad_norm + group["perturb_eps"])
150 |
151 | for p in group["params"]:
152 | if p.grad is None: continue
153 | e_w = p.grad * scale.to(p)
154 | if self.adaptive:
155 | e_w *= torch.pow(p, 2)
156 | p.add_(e_w) # climb to the local maximum "w + e(w)"
157 | self.state[p]['e_w'] = e_w
158 |
159 | if zero_grad: self.zero_grad()
160 |
161 | @torch.no_grad()
162 | def unperturb(self):
163 | for group in self.param_groups:
164 | for p in group['params']:
165 | if 'e_w' in self.state[p].keys():
166 | p.data.sub_(self.state[p]['e_w'])
167 |
168 | @torch.no_grad()
169 | def _grad_norm(self, by=None, weight_adaptive=False):
170 | if not by:
171 | norm = torch.norm(
172 | torch.stack([
173 | ( (torch.abs(p.data) if weight_adaptive else 1.0) * p.grad).norm(p=2)
174 | for group in self.param_groups for p in group["params"]
175 | if p.grad is not None
176 | ]),
177 | p=2
178 | )
179 | else:
180 | norm = torch.norm(
181 | torch.stack([
182 | ( (torch.abs(p.data) if weight_adaptive else 1.0) * self.state[p][by]).norm(p=2)
183 | for group in self.param_groups for p in group["params"]
184 | if p.grad is not None
185 | ]),
186 | p=2
187 | )
188 | return norm
189 |
190 | @torch.no_grad()
191 | def _sync_gradients(self):
192 | for group in self.param_groups:
193 | for p in group['params']:
194 | if p.grad is None: continue
195 |                 if torch.distributed.is_initialized(): # synchronize final gradients
196 | if self.manual_average:
197 | torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
198 | world_size = torch.distributed.get_world_size()
199 | p.grad.div_(float(world_size))
200 | else:
201 | torch.distributed.all_reduce(p.grad, op=self.grad_reduce)
202 | return
203 |
204 | @torch.no_grad()
205 | def _sync_hessians(self):
206 | for group in self.param_groups:
207 | for p in group['params']:
208 | if p.hess is None: continue
209 | if torch.distributed.is_initialized(): # synchronize final hessian
210 | if not p.hess.is_contiguous():
211 | p.hess = p.hess.contiguous()
212 |
213 | if self.manual_average:
214 | torch.distributed.all_reduce(p.hess, op=self.grad_reduce)
215 | world_size = torch.distributed.get_world_size()
216 | p.hess.div_(float(world_size))
217 | else:
218 | torch.distributed.all_reduce(p.hess, op=self.grad_reduce)
219 | return
220 |
221 | @torch.no_grad()
222 | def step(self, closure=None):
223 | """
224 | Performs a single optimization step.
225 | Arguments:
226 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
227 | """
228 |
229 | self.update_hessian_power()
230 |
231 | loss = None
232 | if closure is not None:
233 | loss = closure()
234 |
235 | self.zero_hessian()
236 | self.set_hessian()
237 | self._sync_gradients()
238 | self._sync_hessians()
239 |
240 | # prepare to update parameters
241 | for group in self.param_groups:
242 | for p in group['params']:
243 | if p.grad is None or p.hess is None:
244 | continue
245 |
246 | p.hess = p.hess.abs().clone()
247 |
248 |                 # Perform decoupled step weight decay as in AdamW
249 | p.mul_(1 - group['lr'] * group['weight_decay'])
250 |
251 | state = self.state[p]
252 | # State initialization
253 |                 if 'exp_avg' not in state:  # initialize moment buffers on the first update
254 | state['step'] = 0
255 | state['exp_avg'] = torch.zeros_like(p.data)
256 | state['exp_hessian_diag'] = torch.zeros_like(p.data)
257 | state['bias_correction2'] = 0
258 |
259 | exp_avg, exp_hessian_diag = state['exp_avg'], state['exp_hessian_diag']
260 | beta1, beta2 = group['betas']
261 | state['step'] += 1
262 |
263 | # Decay the first and second moment running average coefficient
264 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
265 | bias_correction1 = 1 - beta1 ** state['step']
266 |
267 | if (state['hessian step']-1) % self.lazy_hessian == 0:
268 | exp_hessian_diag.mul_(beta2).add_(p.hess, alpha=1 - beta2)
269 | bias_correction2 = 1 - beta2 ** state['step']
270 | state['bias_correction2'] = bias_correction2 ** self.hessian_power_t
271 |
272 | step_size = group['lr'] / bias_correction1
273 | step_size_neg = -step_size
274 |
275 | denom = ((exp_hessian_diag**self.hessian_power_t) / state['bias_correction2']).add_(group['eps'])
276 |
277 | # make update
278 | p.addcdiv_(exp_avg, denom, value=step_size_neg)
279 |
280 | return loss
281 |
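282 | # A rough per-batch usage sketch, assuming `model`, `criterion`, `inputs`, and
283 | # `targets` are defined elsewhere and that a scheduler from
284 | # optimizers/hessian_scheduler.py has been assigned to
285 | # `optimizer.hessian_power_scheduler` before training:
286 | #
287 | #   criterion(model(inputs), targets).backward()                    # gradients at w
288 | #   optimizer.perturb_weights(zero_grad=True)                       # move to w + e(w)
289 | #   criterion(model(inputs), targets).backward(create_graph=True)   # keep the graph for Hutchinson
290 | #   optimizer.unperturb()                                           # return to w
291 | #   optimizer.step()                                                 # sharpness-aware second-order update
292 | #   optimizer.zero_grad(set_to_none=True)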
--------------------------------------------------------------------------------
/optimizers/shampoo.py:
--------------------------------------------------------------------------------
1 | # Acknowledgement: This code is based on https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/shampoo.py
2 |
3 | import torch
4 | from torch.optim.optimizer import Optimizer
5 |
6 | def _matrix_power(matrix: torch.Tensor, power: float) -> torch.Tensor:
7 |     # run the SVD on CPU, which is typically faster than the GPU for these small factor matrices
8 | device = matrix.device
9 | matrix = matrix.cpu()
10 | u, s, v = torch.svd(matrix)
11 | return (u @ s.pow_(power).diag() @ v.t()).to(device)
12 |
13 |
14 | class Shampoo(Optimizer):
15 | r"""Implements Shampoo Optimizer Algorithm.
16 |
17 | It has been proposed in `Shampoo: Preconditioned Stochastic Tensor
18 | Optimization`__.
19 |
20 | Arguments:
21 | params: iterable of parameters to optimize or dicts defining
22 | parameter groups
23 |         lr: learning rate (default: 1e-1)
24 | momentum: momentum factor (default: 0)
25 | weight_decay: weight decay (L2 penalty) (default: 0)
26 | epsilon: epsilon added to each mat_gbar_j for numerical stability
27 | (default: 1e-4)
28 | update_freq: update frequency to compute inverse (default: 1)
29 |
30 | Example:
31 |         >>> from optimizers.shampoo import Shampoo
32 |         >>> optimizer = Shampoo(model.parameters(), lr=0.01)
33 | >>> optimizer.zero_grad()
34 | >>> loss_fn(model(input), target).backward()
35 | >>> optimizer.step()
36 |
37 | __ https://arxiv.org/abs/1802.09568
38 |
39 | Note:
40 | Reference code: https://github.com/moskomule/shampoo.pytorch
41 | """
42 |
43 | def __init__(
44 | self,
45 | params,
46 | lr: float = 1e-1,
47 | momentum: float = 0.0,
48 | weight_decay: float = 0.0,
49 | epsilon: float = 1e-4,
50 | update_freq: int = 1,
51 | ):
52 | if lr <= 0.0:
53 | raise ValueError("Invalid learning rate: {}".format(lr))
54 | if momentum < 0.0:
55 | raise ValueError("Invalid momentum value: {}".format(momentum))
56 | if weight_decay < 0.0:
57 | raise ValueError(
58 | "Invalid weight_decay value: {}".format(weight_decay)
59 | )
60 |         if epsilon < 0.0:
61 |             raise ValueError("Invalid epsilon value: {}".format(epsilon))
62 |         if update_freq < 1:
63 |             raise ValueError("Invalid update_freq value: {}".format(update_freq))
64 |
65 | defaults = dict(
66 | lr=lr,
67 | momentum=momentum,
68 | weight_decay=weight_decay,
69 | epsilon=epsilon,
70 | update_freq=update_freq,
71 | )
72 | super(Shampoo, self).__init__(params, defaults)
73 |
74 | def step(self, closure=None):
75 | """Performs a single optimization step.
76 |
77 | Arguments:
78 | closure: A closure that reevaluates the model and returns the loss.
79 | """
80 | loss = None
81 | if closure is not None:
82 | loss = closure()
83 |
84 | for group in self.param_groups:
85 | for p in group["params"]:
86 | if p.grad is None:
87 | continue
88 | grad = p.grad.data
89 | order = grad.ndimension()
90 | original_size = grad.size()
91 | state = self.state[p]
92 | momentum = group["momentum"]
93 | weight_decay = group["weight_decay"]
94 | if len(state) == 0:
95 | state["step"] = 0
96 | if momentum > 0:
97 | state["momentum_buffer"] = grad.clone()
98 | for dim_id, dim in enumerate(grad.size()):
99 | # precondition matrices
100 | state["precond_{}".format(dim_id)] = group[
101 | "epsilon"
102 | ] * torch.eye(dim, out=grad.new(dim, dim))
103 | state[
104 | "inv_precond_{dim_id}".format(dim_id=dim_id)
105 | ] = grad.new(dim, dim).zero_()
106 |
107 | if momentum > 0:
108 | grad.mul_(1 - momentum).add_(
109 | state["momentum_buffer"], alpha=momentum
110 | )
111 |
112 | if weight_decay > 0:
113 | grad.add_(p.data, alpha=group["weight_decay"])
114 |
115 | # See Algorithm 2 for detail
116 | for dim_id, dim in enumerate(grad.size()):
117 | precond = state["precond_{}".format(dim_id)]
118 | inv_precond = state["inv_precond_{}".format(dim_id)]
119 |
120 | # mat_{dim_id}(grad)
121 | grad = grad.transpose_(0, dim_id).contiguous()
122 | transposed_size = grad.size()
123 | grad = grad.view(dim, -1)
124 |
125 | grad_t = grad.t()
126 | precond.add_(grad @ grad_t)
127 | if state["step"] % group["update_freq"] == 0:
128 | inv_precond.copy_(_matrix_power(precond, -1 / order))
129 |
130 | if dim_id == order - 1:
131 | # finally
132 | grad = grad_t @ inv_precond
133 | # grad: (-1, last_dim)
134 | grad = grad.view(original_size)
135 | else:
136 | # if not final
137 | grad = inv_precond @ grad
138 | # grad (dim, -1)
139 | grad = grad.view(transposed_size)
140 |
141 | state["step"] += 1
142 | state["momentum_buffer"] = grad
143 | p.data.add_(grad, alpha=-group["lr"])
144 |
145 | return loss
146 |
--------------------------------------------------------------------------------
/optimizers/sophiaH.py:
--------------------------------------------------------------------------------
1 | # Acknowledgement: This code is based on https://github.com/Liuhong99/Sophia/blob/main/sophia.py
2 |
3 | import torch
4 | import math
5 |
6 | class SophiaH(torch.optim.Optimizer):
7 | def __init__(self,
8 | params,
9 | lr=0.15,
10 | betas=(0.965, 0.99),
11 | eps=1e-15,
12 | weight_decay=1e-1,
13 | lazy_hessian=10,
14 | n_samples=1,
15 | clip_threshold=0.04,
16 | seed=0):
17 |
18 | if not 0.0 <= lr:
19 | raise ValueError(f"Invalid learning rate: {lr}")
20 | if not 0.0 <= eps:
21 | raise ValueError(f"Invalid epsilon value: {eps}")
22 | if not 0.0 <= betas[0] < 1.0:
23 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
24 | if not 0.0 <= betas[1] < 1.0:
25 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
26 | if not 0.0 <= clip_threshold:
27 | raise ValueError(f"Invalid threshold parameter: {clip_threshold}")
28 |
29 | self.n_samples = n_samples
30 | self.lazy_hessian = lazy_hessian
31 | self.seed = seed
32 |
33 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
34 | self.generator = torch.Generator().manual_seed(self.seed)
35 |
36 | defaults = dict(lr=lr,
37 | betas=betas,
38 | eps=eps,
39 | weight_decay=weight_decay,
40 | clip_threshold=clip_threshold)
41 |
42 | super(SophiaH, self).__init__(params, defaults)
43 |
44 | for p in self.get_params():
45 | p.hess = 0.0
46 | self.state[p]["hessian step"] = 0
47 |
48 | def get_params(self):
49 | """
50 | Gets all parameters in all param_groups with gradients
51 | """
52 |
53 | return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
54 |
55 | def zero_hessian(self):
56 | """
57 |         Zeros out the accumulated Hessian traces.
58 | """
59 |
60 | for p in self.get_params():
61 | if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.lazy_hessian == 0:
62 | p.hess.zero_()
63 |
64 | @torch.no_grad()
65 | def set_hessian(self):
66 | """
67 | Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
68 | """
69 |
70 | params = []
71 | for p in filter(lambda p: p.grad is not None, self.get_params()):
72 | if self.state[p]["hessian step"] % self.lazy_hessian == 0: # compute a new Hessian once per 'lazy hessian' steps
73 | params.append(p)
74 | self.state[p]["hessian step"] += 1
75 |
76 | if len(params) == 0:
77 | return
78 |
79 | if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
80 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
81 |
82 | grads = [p.grad for p in params]
83 |
84 | last_sample = self.n_samples - 1
85 | for i in range(self.n_samples):
86 | # Rademacher distribution {-1.0, 1.0}
87 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
88 | h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < last_sample)
89 | for h_z, z, p in zip(h_zs, zs, params):
90 | p.hess += h_z * z / self.n_samples
91 |
92 | @torch.no_grad()
93 | def step(self, closure=None):
94 | """
95 | Performs a single optimization step.
96 | Arguments:
97 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
98 | """
99 |
100 | loss = None
101 | if closure is not None:
102 | loss = closure()
103 |
104 | self.zero_hessian()
105 | self.set_hessian()
106 |
107 | for group in self.param_groups:
108 | for p in group['params']:
109 | if p.grad is None or p.hess is None:
110 | continue
111 |
112 |                 # Perform decoupled step weight decay as in AdamW
113 | p.mul_(1 - group['lr'] * group['weight_decay'])
114 |
115 | state = self.state[p]
116 |
117 | # State initialization
118 |                 if 'exp_avg' not in state:  # initialize moment buffers on the first update
119 | state['step'] = 0
120 | state['exp_avg'] = torch.zeros_like(p.data)
121 | state['exp_hessian_diag'] = torch.zeros_like(p.data)
122 |
123 | exp_avg, exp_hessian_diag = state['exp_avg'], state['exp_hessian_diag']
124 | beta1, beta2 = group['betas']
125 | state['step'] += 1
126 |
127 | # Decay the first and second moment running average coefficient
128 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
129 |
130 | if (state['hessian step']-1) % self.lazy_hessian == 0:
131 | exp_hessian_diag.mul_(beta2).add_(p.hess, alpha=1 - beta2)
132 |
133 | step_size = group['lr']
134 | step_size_neg = -step_size
135 |
136 | denom = group['clip_threshold'] * exp_hessian_diag.clamp(0, None) + group['eps']
137 | ratio = (exp_avg.abs() / denom).clamp(None, 1)
138 |
139 | # make update
140 | p.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
141 |
142 | return loss
143 |
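144 | # A minimal usage sketch, assuming `model`, `criterion`, and a data `loader` are
145 | # defined elsewhere. The backward pass must keep the graph (create_graph=True) so
146 | # that `set_hessian` can form Hessian-vector products via torch.autograd.grad
147 | # (strictly only required on steps where the Hessian is refreshed):
148 | #
149 | #   optimizer = SophiaH(model.parameters())
150 | #   for inputs, targets in loader:
151 | #       criterion(model(inputs), targets).backward(create_graph=True)
152 | #       optimizer.step()
153 | #       optimizer.zero_grad(set_to_none=True)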
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==2.19.0
2 | einops==0.7.0
3 | evaluate==0.4.2
4 | huggingface_hub
5 | numpy==1.26.2
6 | torch==2.1.1
7 | torchvision==0.16.1
8 | tqdm
9 | transformers==4.48.0
10 | wandb
--------------------------------------------------------------------------------