├── pics ├── pgm_vsgd_2.png └── cifar100_res.png ├── configs ├── train │ ├── optimizer │ │ ├── sgd.yaml │ │ ├── adam.yaml │ │ ├── adamW.yaml │ │ └── vsgd.yaml │ ├── scheduler │ │ ├── stepLR.yaml │ │ ├── cosine.yaml │ │ └── plateau.yaml │ └── defaults.yaml ├── wandb │ └── defaults.yaml ├── defaults.yaml ├── model │ ├── resnext.yaml │ ├── vgg.yaml │ ├── convmixer.yaml │ └── resnet.yaml ├── dataset │ ├── cifar100.yaml │ └── tiny_imagenet.yaml └── experiment │ ├── cifar100_vgg.yaml │ ├── tiny_imagenet_vgg.yaml │ ├── cifar100_convmixer.yaml │ ├── tiny_imagenet_convmixer.yaml │ ├── tiny_imagenet_resnext.yaml │ └── cifar100_resnext.yaml ├── _config.yml ├── src ├── utils │ ├── wandb.py │ ├── tester.py │ └── trainer.py ├── model │ ├── classifier.py │ ├── convmixer.py │ ├── vgg.py │ └── resnext.py ├── dataset │ ├── cifar10.py │ ├── cifar100.py │ ├── data_module.py │ └── tiny_imagenet.py ├── run_experiment.py └── vsgd.py ├── environment.yml ├── LICENSE ├── README.md ├── .gitignore └── notebooks └── vsgd_example.ipynb /pics/pgm_vsgd_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generativeai-tue/vsgd/HEAD/pics/pgm_vsgd_2.png -------------------------------------------------------------------------------- /pics/cifar100_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/generativeai-tue/vsgd/HEAD/pics/cifar100_res.png -------------------------------------------------------------------------------- /configs/train/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.SGD 2 | params: null 3 | lr: 0.1 4 | momentum: 0. 5 | weight_decay: 0. 
-------------------------------------------------------------------------------- /configs/train/scheduler/stepLR.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | _target_: torch.optim.lr_scheduler.StepLR 3 | step_size: 5000 4 | gamma: 0.5 5 | last_epoch: -1 -------------------------------------------------------------------------------- /configs/train/scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR 3 | T_max: ${train.max_iter} 4 | eta_min: 0. 5 | -------------------------------------------------------------------------------- /configs/wandb/defaults.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | setup: 3 | project: vsgd 4 | mode: online 5 | watch: 6 | log: gradients 7 | log_freq: 1000 8 | group: null -------------------------------------------------------------------------------- /configs/train/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | params: null 3 | lr: 0.1 4 | eps: 1e-08 5 | weight_decay: 0. 6 | betas: 7 | - 0.9 8 | - 0.999 -------------------------------------------------------------------------------- /configs/train/optimizer/adamW.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | params: null 3 | lr: 0.1 4 | eps: 1e-08 5 | weight_decay: 0. 
6 | betas: 7 | - 0.9 8 | - 0.999 -------------------------------------------------------------------------------- /configs/train/optimizer/vsgd.yaml: -------------------------------------------------------------------------------- 1 | _target_: vsgd.VSGD 2 | params: null 3 | ghattg: 30.0 4 | ps: 1e-8 5 | tau2: 0.9 6 | tau1: 0.81 7 | lr: 0.1 8 | weight_decay: 0.0 -------------------------------------------------------------------------------- /configs/train/scheduler/plateau.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 3 | factor: 0.5 4 | patience: 10 5 | threshold: 1e-3 6 | cooldown: 0 7 | min_lr: 1e-5 -------------------------------------------------------------------------------- /configs/defaults.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - model: resnet18 6 | - dataset: cifar10 7 | - train: defaults 8 | - wandb: defaults 9 | - experiment: null -------------------------------------------------------------------------------- /configs/model/resnext.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | _target_: model.resnext.ResNeXt 3 | cardinality: 8 4 | depth: 29 5 | widen_factor: 4 6 | dropRate: 0 7 | num_classes: ${dataset.num_classes} 8 | name: resnext 9 | -------------------------------------------------------------------------------- /configs/model/vgg.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | _target_: model.vgg.VGG 3 | cfg_id: A # A, B, D, E for 11, 13, 16, 19 layers of VGG 4 | batch_norm: true 5 | num_classes: ${dataset.num_classes} 6 | name: vgg 7 | -------------------------------------------------------------------------------- /configs/model/convmixer.yaml: -------------------------------------------------------------------------------- 1 | --- 
2 | _target_: model.convmixer.ConvMixer 3 | dim: 256 4 | depth: 8 5 | kernel_size: 5 6 | patch_size: 2 7 | num_classes: ${dataset.num_classes} 8 | name: convmixer 9 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | 3 | 4 | title: "Variational Stochastic Gradient Descent" 5 | description: "Code repository of the paper Variational Stochastic Gradient Descent for Deep Neural Networks" 6 | 7 | -------------------------------------------------------------------------------- /configs/dataset/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | data_module: 3 | _target_: dataset.cifar100.Cifar100 4 | batch_size: 64 5 | test_batch_size: 1024 6 | use_augmentations: true 7 | x_dim: 32 8 | num_classes: 100 9 | name: cifar100 -------------------------------------------------------------------------------- /configs/dataset/tiny_imagenet.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | data_module: 3 | _target_: dataset.tiny_imagenet.TinyImagenet 4 | batch_size: 64 5 | test_batch_size: 1024 6 | use_augmentations: true 7 | x_dim: 64 8 | num_classes: 200 9 | name: tiny-imagenet-200 -------------------------------------------------------------------------------- /src/utils/wandb.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import wandb 5 | 6 | 7 | def get_checkpoint(entity: str, project: str, idx: str, device: str = "cpu"): 8 | # download the checkpoint from wandb to the local machine. 
9 | file = wandb.restore( 10 | "last_chpt.pth", run_path=os.path.join(entity, project, idx), replace=True 11 | ) 12 | # load the checkpoint 13 | chpt = torch.load(file.name, map_location=device) 14 | return chpt 15 | -------------------------------------------------------------------------------- /configs/train/defaults.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | defaults: 3 | - optimizer: adamW 4 | - scheduler: null 5 | 6 | seed: 0 7 | resume_id: null 8 | device: cuda 9 | start_iter: 0 10 | max_iter: 10000 11 | grad_clip: 0 12 | grad_skip_thr: 0 # skip the update step is maximal grad norm is larger then this value (ignore if 0) 13 | save_freq: 1 # how often to save the checkpoint (in iterations) 14 | eval_test_freq: 10000 # how often to run evaluation on test dataset (in iterations) 15 | experiment_name: null 16 | optimizer_log_freq: -1 17 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: vsgd 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - anaconda 6 | - conda-forge 7 | - defaults 8 | - huggingface 9 | dependencies: 10 | - python=3.10.4 11 | - numpy 12 | - scipy 13 | - matplotlib 14 | - wandb 15 | - tqdm 16 | - imageio 17 | - pip 18 | - wheel 19 | - pytorch=1.12.1 20 | - torchvision=0.13.1 21 | - cudatoolkit=11.3 22 | - torchmetrics 23 | - transformers 24 | - pip: 25 | - hydra-core==1.3 26 | - torch-fidelity==0.3.0 27 | - black 28 | 29 | -------------------------------------------------------------------------------- /configs/experiment/cifar100_vgg.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /dataset: cifar100 4 | - override /model: vgg 5 | - override /train/scheduler: stepLR 6 | 7 | dataset: 8 | data_module: 9 | batch_size: 256 10 | test_batch_size: 256 11 | 
use_augmentations: true 12 | model: 13 | cfg_id: D 14 | batch_norm: true 15 | train: 16 | experiment_name: test 17 | resume_id: null 18 | seed: 124 19 | device: 'cuda:0' 20 | save_freq: 500 21 | eval_test_freq: 500 22 | grad_clip: 0 23 | grad_skip_thr: 0 24 | max_iter: 30000 25 | optimizer_log_freq: 100 26 | scheduler: 27 | step_size: 10000 28 | gamma: 0.5 29 | -------------------------------------------------------------------------------- /configs/experiment/tiny_imagenet_vgg.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /dataset: tiny_imagenet 4 | - override /model: vgg 5 | - override /train/scheduler: stepLR 6 | 7 | dataset: 8 | data_module: 9 | batch_size: 128 10 | test_batch_size: 128 11 | use_augmentations: true 12 | model: 13 | cfg_id: E 14 | batch_norm: true 15 | train: 16 | experiment_name: test 17 | resume_id: null 18 | seed: 124 19 | device: 'cuda:0' 20 | save_freq: 500 21 | eval_test_freq: 500 22 | grad_clip: 0 23 | grad_skip_thr: 0 24 | max_iter: 60000 25 | optimizer_log_freq: 100 26 | scheduler: 27 | step_size: 20000 28 | gamma: 0.5 -------------------------------------------------------------------------------- /configs/experiment/cifar100_convmixer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /dataset: cifar100 4 | - override /model: convmixer 5 | - override /train/scheduler: stepLR 6 | 7 | dataset: 8 | data_module: 9 | batch_size: 256 10 | test_batch_size: 256 11 | use_augmentations: true 12 | model: 13 | dim: 256 14 | depth: 8 15 | kernel_size: 5 16 | patch_size: 2 17 | train: 18 | experiment_name: test 19 | resume_id: null 20 | seed: 124 21 | device: 'cuda:0' 22 | save_freq: 500 23 | eval_test_freq: 500 24 | grad_clip: 0 25 | grad_skip_thr: 0 26 | max_iter: 30000 27 | optimizer_log_freq: 100 28 | scheduler: 29 | step_size: 10000 30 | gamma: 0.5 31 | 
-------------------------------------------------------------------------------- /configs/experiment/tiny_imagenet_convmixer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /dataset: tiny_imagenet 4 | - override /model: convmixer 5 | - override /train/scheduler: stepLR 6 | 7 | dataset: 8 | data_module: 9 | batch_size: 128 10 | test_batch_size: 128 11 | use_augmentations: true 12 | model: 13 | dim: 256 14 | depth: 8 15 | kernel_size: 5 16 | patch_size: 2 17 | train: 18 | experiment_name: test 19 | resume_id: null 20 | seed: 124 21 | device: 'cuda:0' 22 | save_freq: 500 23 | eval_test_freq: 500 24 | grad_clip: 0 25 | grad_skip_thr: 0 26 | max_iter: 60000 27 | optimizer_log_freq: 100 28 | scheduler: 29 | step_size: 20000 30 | gamma: 0.5 -------------------------------------------------------------------------------- /configs/experiment/tiny_imagenet_resnext.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /dataset: tiny_imagenet 4 | - override /model: resnext 5 | - override /train/scheduler: stepLR 6 | 7 | dataset: 8 | data_module: 9 | batch_size: 64 10 | test_batch_size: 64 11 | use_augmentations: true 12 | model: 13 | cardinality: 8 14 | depth: 18 15 | widen_factor: 4 16 | dropRate: 0 17 | train: 18 | experiment_name: test 19 | resume_id: null 20 | seed: 124 21 | device: 'cuda:0' 22 | save_freq: 500 23 | eval_test_freq: 500 24 | grad_clip: 0 25 | grad_skip_thr: 0 26 | max_iter: 60000 27 | optimizer_log_freq: 100 28 | scheduler: 29 | step_size: 20000 30 | gamma: 0.5 -------------------------------------------------------------------------------- /configs/experiment/cifar100_resnext.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /dataset: cifar100 4 | - override /model: resnext 5 | - override 
/train/scheduler: stepLR 6 | 7 | dataset: 8 | data_module: 9 | batch_size: 128 10 | test_batch_size: 128 11 | use_augmentations: true 12 | model: 13 | cardinality: 8 14 | depth: 18 15 | widen_factor: 4 16 | dropRate: 0 17 | train: 18 | experiment_name: test 19 | resume_id: null 20 | seed: 124 21 | device: 'cuda:0' 22 | save_freq: 500 23 | eval_test_freq: 500 24 | grad_clip: 0 25 | grad_skip_thr: 0 26 | max_iter: 30000 27 | optimizer_log_freq: 100 28 | scheduler: 29 | step_size: 10000 30 | gamma: 0.5 31 | 32 | 33 | -------------------------------------------------------------------------------- /src/utils/tester.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | from tqdm import tqdm 4 | 5 | 6 | def test(args, loader, model): 7 | model.eval() 8 | history = {} 9 | N = 0 10 | with torch.no_grad(): 11 | for _, batch in tqdm(enumerate(loader)): 12 | if "cuda" in args.device: 13 | for i in range(len(batch)): 14 | batch[i] = batch[i].cuda(non_blocking=True) 15 | 16 | N += batch[0].shape[0] 17 | logs = model.test_step( 18 | batch=batch, 19 | ) 20 | 21 | for k in logs.keys(): 22 | if f"test/{k}" not in history.keys(): 23 | history[f"test/{k}"] = 0.0 24 | history[f"test/{k}"] += logs[k] 25 | 26 | for k in history.keys(): 27 | history[k] /= len(loader.dataset) 28 | 29 | wandb.log(history) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Generativ/e AI group at the TU/e 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons 
to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/model/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ClassifierWrapper(nn.Module): 6 | def __init__(self, backbone, loss_fn=nn.CrossEntropyLoss(), **kwargs): 7 | super().__init__() 8 | self.backbone = backbone 9 | self.loss_fn = loss_fn 10 | 11 | def forward(self, batch): 12 | output = self.backbone(pixel_values=batch[0], return_dict=False) 13 | return output[0] 14 | 15 | def train_step(self, batch, scaler=None, device=None): 16 | if scaler is not None: 17 | with torch.autocast(device_type=device, dtype=torch.float16): 18 | logits = self.forward(batch) 19 | loss = self.loss_fn(logits, batch[1]) 20 | else: 21 | logits = self.forward(batch) 22 | loss = self.loss_fn(logits, batch[1]) 23 | 24 | logs = { 25 | "loss": loss.data, 26 | "accuracy": (logits.argmax(dim=1) == batch[1]).float().mean(), 27 | } 28 | return loss, logs 29 | 30 | def test_step(self, batch): 31 | logits = self.forward(batch) 32 | loss = self.loss_fn(logits, batch[1]) 33 | 34 | logs = { 35 | "loss": loss.data * batch[0].shape[0], 36 | "accuracy": (logits.argmax(dim=1) == 
batch[1]).float().sum(), 37 | } 38 | return logs 39 | -------------------------------------------------------------------------------- /configs/model/resnet.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | _target_: model.classifier.ClassifierWrapper 3 | backbone: 4 | _target_: transformers.ResNetForImageClassification 5 | config: 6 | _target_: transformers.ResNetConfig 7 | num_channels: 3 8 | embedding_size: 64 # Dimensionality (hidden size) for the embedding layer. 9 | hidden_sizes: # (List[int], optional, defaults to [256, 512, 1024, 2048]) — Dimensionality (hidden size) at each stage. 10 | - 64 11 | - 128 12 | - 256 13 | - 512 14 | depths: # (List[int], optional, defaults to [3, 4, 6, 3]) — Depth (number of layers) for each stage. 15 | - 2 16 | - 2 17 | - 2 18 | - 2 19 | layer_type: "basic" # (str, optional, defaults to "bottleneck") — The layer to use, it can be either "basic" (used for smaller models, like resnet-18 or resnet-34) or "bottleneck" (used for larger models like resnet-50 and above). 20 | hidden_act: "relu" # (str, optional, defaults to "relu") — The non-linear activation function in each block. If string, "gelu", "relu", "selu" and "gelu_new" are supported. 21 | downsample_in_first_stage: False # (bool, optional, defaults to False) — If True, the first stage will downsample the inputs using a stride of 2. 22 | downsample_in_bottleneck: False # (bool, optional, defaults to False) — If True, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a stride of 2. 23 | num_labels: ${dataset.num_classes} 24 | 25 | name: resnet 26 | # see https://huggingface.co/docs/transformers/main/en/model_doc/resnet#transformers.ResNetConfig for more details -------------------------------------------------------------------------------- /src/model/convmixer.py: -------------------------------------------------------------------------------- 1 | """VGG for CIFAR10. FC layers are removed. 
2 | (c) YANG, Wei 3 | """ 4 | 5 | import torch.nn as nn 6 | 7 | from model.classifier import ClassifierWrapper 8 | 9 | # adapted from https://github.com/locuslab/convmixer-cifar10/blob/main/train.py 10 | 11 | 12 | class Residual(nn.Module): 13 | def __init__(self, fn): 14 | super().__init__() 15 | self.fn = fn 16 | 17 | def forward(self, x): 18 | return self.fn(x) + x 19 | 20 | 21 | class ConvMixer(ClassifierWrapper): 22 | def __init__( 23 | self, 24 | dim, 25 | depth, 26 | kernel_size=5, 27 | patch_size=2, 28 | num_classes=100, 29 | loss_fn=nn.CrossEntropyLoss(), 30 | **kwargs 31 | ): 32 | bb = nn.Sequential( 33 | nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), 34 | nn.GELU(), 35 | nn.BatchNorm2d(dim), 36 | *[ 37 | nn.Sequential( 38 | Residual( 39 | nn.Sequential( 40 | nn.Conv2d( 41 | dim, dim, kernel_size, groups=dim, padding="same" 42 | ), 43 | nn.GELU(), 44 | nn.BatchNorm2d(dim), 45 | ) 46 | ), 47 | nn.Conv2d(dim, dim, kernel_size=1), 48 | nn.GELU(), 49 | nn.BatchNorm2d(dim), 50 | ) 51 | for i in range(depth) 52 | ], 53 | nn.AdaptiveAvgPool2d((1, 1)), 54 | nn.Flatten(), 55 | nn.Linear(dim, num_classes) 56 | ) 57 | super().__init__(backbone=bb, loss_fn=loss_fn) 58 | 59 | def forward(self, batch): 60 | x = self.backbone(batch[0]) 61 | return x 62 | -------------------------------------------------------------------------------- /src/dataset/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils.data import random_split 4 | from torchvision import datasets, transforms 5 | 6 | from dataset.data_module import DataModule, ToTensor 7 | 8 | 9 | class Cifar10(DataModule): 10 | def __init__( 11 | self, 12 | batch_size, 13 | test_batch_size, 14 | root, 15 | use_augmentations, 16 | ): 17 | super(Cifar10, self).__init__( 18 | batch_size=batch_size, 19 | test_batch_size=test_batch_size, 20 | root=root, 21 | ) 22 | self.__dict__.update(locals()) 23 | if use_augmentations: 24 | 
self.transforms = transforms.Compose( 25 | [ 26 | transforms.RandomCrop(32, padding=4), 27 | transforms.RandomHorizontalFlip(), 28 | transforms.ToTensor(), 29 | transforms.Normalize( 30 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 31 | ), 32 | ] 33 | ) 34 | self.test_transforms = transforms.Compose( 35 | [ 36 | transforms.ToTensor(), 37 | transforms.Normalize( 38 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 39 | ), 40 | ] 41 | ) 42 | 43 | else: 44 | self.transforms = transforms.Compose( 45 | [ 46 | transforms.RandomHorizontalFlip(), 47 | ToTensor(), 48 | ] 49 | ) 50 | self.test_transforms = transforms.Compose([ToTensor()]) 51 | self.prepare_data() 52 | 53 | def prepare_data(self): 54 | datasets.CIFAR10(self.root, train=True, download=True) 55 | datasets.CIFAR10(self.root, train=False, download=True) 56 | 57 | def setup(self): 58 | cifar_full = datasets.CIFAR10(self.root, train=True, transform=self.transforms) 59 | cifar_full.processed_folder = os.path.join(self.root, cifar_full.base_folder) 60 | N = len(cifar_full) 61 | self.train = cifar_full 62 | self.train, self.val = random_split(cifar_full, [N - 256, 256]) 63 | self.test = datasets.CIFAR10( 64 | self.root, train=False, transform=self.test_transforms 65 | ) 66 | self.test.processed_folder = os.path.join(self.root, self.test.base_folder) 67 | -------------------------------------------------------------------------------- /src/dataset/cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import random_split 6 | from torchvision import datasets, transforms 7 | 8 | from dataset.data_module import DataModule, ToTensor 9 | 10 | 11 | class Cifar100(DataModule): 12 | def __init__( 13 | self, 14 | batch_size, 15 | test_batch_size, 16 | root, 17 | use_augmentations, 18 | ): 19 | super().__init__( 20 | batch_size=batch_size, 21 | test_batch_size=test_batch_size, 22 | root=root, 23 | 
) 24 | self.__dict__.update(locals()) 25 | if use_augmentations: 26 | self.transforms = transforms.Compose( 27 | [ 28 | transforms.RandomCrop(32, padding=4), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor(), 31 | transforms.Normalize( 32 | (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) 33 | ), 34 | ] 35 | ) 36 | self.test_transforms = transforms.Compose( 37 | [ 38 | transforms.ToTensor(), 39 | transforms.Normalize( 40 | (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) 41 | ), 42 | ] 43 | ) 44 | 45 | else: 46 | self.transforms = transforms.Compose( 47 | [ 48 | transforms.RandomHorizontalFlip(), 49 | ToTensor(), 50 | ] 51 | ) 52 | self.test_transforms = transforms.Compose([ToTensor()]) 53 | self.prepare_data() 54 | 55 | def prepare_data(self): 56 | datasets.CIFAR100(self.root, train=True, download=True) 57 | datasets.CIFAR100(self.root, train=False, download=True) 58 | 59 | def setup(self): 60 | cifar_full = datasets.CIFAR100(self.root, train=True, transform=self.transforms) 61 | cifar_full.processed_folder = os.path.join(self.root, cifar_full.base_folder) 62 | N = len(cifar_full) 63 | # self.train = cifar_full 64 | # self.val = None 65 | self.train, self.val = random_split(cifar_full, [N - 256, 256]) 66 | self.test = datasets.CIFAR100( 67 | self.root, train=False, transform=self.test_transforms 68 | ) 69 | self.test.processed_folder = os.path.join(self.root, self.test.base_folder) 70 | -------------------------------------------------------------------------------- /src/dataset/data_module.py: -------------------------------------------------------------------------------- 1 | from itertools import permutations 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms 8 | 9 | 10 | class ToTensor: 11 | def __call__(self, x): 12 | x = torch.FloatTensor(np.asarray(x, dtype=np.float32)).permute(2, 0, 1) 13 | return x 14 | 15 | 16 | class 
Random90Rotation: 17 | def __call__(self, x): 18 | k = torch.ceil(3.0 * torch.rand(1)).long() 19 | u = torch.rand(1) 20 | if u < 0.5: 21 | x = x.rotate(90 * k) 22 | return x 23 | 24 | 25 | class ChannelSwap: 26 | def __call__(self, x): 27 | permutation = list(permutations(range(3), 3))[np.random.randint(0, 5)] 28 | u = torch.rand(1) 29 | if u < 0.5: 30 | x = np.array(x)[..., permutation] 31 | x = Image.fromarray(x) 32 | return x 33 | 34 | 35 | class DataModule: 36 | def __init__( 37 | self, 38 | batch_size, 39 | test_batch_size, 40 | root="data/", 41 | ): 42 | self.__dict__.update(locals()) 43 | self.transforms = transforms.Compose( 44 | [ 45 | ToTensor(), 46 | ] 47 | ) 48 | self.test_transforms = transforms.Compose( 49 | [ 50 | ToTensor(), 51 | ] 52 | ) 53 | self.prepare_data() 54 | 55 | def prepare_data(self) -> None: 56 | """ 57 | Download the data. Do preprocessing if necessary. 58 | :return: 59 | """ 60 | raise NotImplementedError 61 | 62 | def setup(self) -> None: 63 | """ 64 | Create self.train and self.val, self.test dataset 65 | :return: None 66 | """ 67 | raise NotImplementedError 68 | 69 | def train_dataloader(self): 70 | params = { 71 | "pin_memory": True, 72 | "drop_last": True, 73 | "shuffle": True, 74 | "num_workers": 1, 75 | } 76 | train_loader = DataLoader(self.train, self.batch_size, **params) 77 | while True: 78 | yield from train_loader 79 | 80 | def val_dataloader(self): 81 | params = { 82 | "pin_memory": True, 83 | "drop_last": True, 84 | "shuffle": True, 85 | "num_workers": 1, 86 | } 87 | val_loader = DataLoader(self.val, self.test_batch_size, **params) 88 | while True: 89 | yield from val_loader 90 | 91 | def test_dataloader(self): 92 | test_loader = DataLoader( 93 | self.test, 94 | self.test_batch_size, 95 | num_workers=1, 96 | shuffle=False, 97 | pin_memory=True, 98 | drop_last=False, 99 | ) 100 | return test_loader 101 | -------------------------------------------------------------------------------- /src/model/vgg.py: 
-------------------------------------------------------------------------------- 1 | """VGG for CIFAR10. FC layers are removed. 2 | (c) YANG, Wei 3 | """ 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from model.classifier import ClassifierWrapper 10 | 11 | # adapted from https://github.com/alecwangcq/KFAC-Pytorch/blob/master/models/cifar/vgg.py 12 | 13 | __all__ = [ 14 | "VGG", 15 | "vgg11", 16 | "vgg11_bn", 17 | "vgg13", 18 | "vgg13_bn", 19 | "vgg16", 20 | "vgg16_bn", 21 | "vgg19_bn", 22 | "vgg19", 23 | ] 24 | 25 | 26 | cfg = { 27 | # vgg11 28 | "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 29 | # vgg13: 30 | "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 31 | # vgg16: 32 | "D": [ 33 | 64, 34 | 64, 35 | "M", 36 | 128, 37 | 128, 38 | "M", 39 | 256, 40 | 256, 41 | 256, 42 | "M", 43 | 512, 44 | 512, 45 | 512, 46 | "M", 47 | 512, 48 | 512, 49 | 512, 50 | "M", 51 | ], 52 | # vgg19: 53 | "E": [ 54 | 64, 55 | 64, 56 | "M", 57 | 128, 58 | 128, 59 | "M", 60 | 256, 61 | 256, 62 | 256, 63 | 256, 64 | "M", 65 | 512, 66 | 512, 67 | 512, 68 | 512, 69 | "M", 70 | 512, 71 | 512, 72 | 512, 73 | 512, 74 | "M", 75 | ], 76 | } 77 | 78 | 79 | class VGG(ClassifierWrapper): 80 | def __init__( 81 | self, 82 | cfg_id, 83 | batch_norm=False, 84 | num_classes=1000, 85 | loss_fn=nn.CrossEntropyLoss(), 86 | **kwargs 87 | ): 88 | super().__init__( 89 | backbone=VGG.make_layers(cfg[cfg_id], batch_norm=batch_norm), 90 | loss_fn=loss_fn, 91 | ) 92 | self.classifier = nn.Linear(512, num_classes) 93 | self._initialize_weights() 94 | 95 | def forward(self, batch): 96 | x = self.backbone(batch[0]) 97 | x = F.avg_pool2d(x, kernel_size=x.shape[-1], stride=1) 98 | x = x.view(x.size(0), -1) 99 | x = self.classifier(x) 100 | return x 101 | 102 | def _initialize_weights(self): 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 
106 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 107 | if m.bias is not None: 108 | m.bias.data.zero_() 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | elif isinstance(m, nn.Linear): 113 | n = m.weight.size(1) 114 | m.weight.data.normal_(0, 0.01) 115 | m.bias.data.zero_() 116 | 117 | @staticmethod 118 | def make_layers(cfg, batch_norm=False): 119 | layers = [] 120 | in_channels = 3 121 | for v in cfg: 122 | if v == "M": 123 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 124 | else: 125 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 126 | if batch_norm: 127 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 128 | else: 129 | layers += [conv2d, nn.ReLU(inplace=True)] 130 | in_channels = v 131 | return nn.Sequential(*layers) 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Variational Stochastic Gradient Descent for Deep Neural Networks 2 | 3 | This repository contains the source code accompanying the paper: 4 | 5 | [Variational Stochastic Gradient Descent for Deep Neural Networks](https://openreview.net/forum?id=xu4ATNjcdy) 6 |
[[Demos]](https://github.com/generativeai-tue/vsgd/blob/main/notebooks) 7 | 8 |
**[Anna Kuzina\*](https://akuzina.github.io/), [Haotian Chen\*](https://www.linkedin.com/in/haotian-chen-359b4520b/), [Babak Esmaeili](https://babak0032.github.io), & [Jakub M. Tomczak](https://jmtomczak.github.io/)**. 9 | 10 | 11 | #### Abstract 12 | *Optimizing deep neural networks is one of the main tasks in successful deep learning. Current state-of-the-art optimizers are adaptive gradient-based optimization methods such as Adam. Recently, there has been an increasing interest in formulating gradient-based optimizers in a probabilistic framework for better estimation of gradients and modeling uncertainties. Here, we propose to combine both approaches, resulting in the Variational Stochastic Gradient Descent (VSGD) optimizer. We model gradient updates as a probabilistic model and utilize stochastic variational inference (SVI) to derive an efficient and effective update rule. Further, we show how our VSGD method relates to other adaptive gradient-based optimizers like Adam. 13 | Lastly, 14 | we carry out experiments on two image classification datasets and four deep neural network architectures, where we show that VSGD outperforms Adam and SGD.* 15 | 16 | 17 | 18 | 19 | 20 | ### Repository structure 21 | 22 | #### Folders 23 | 24 | This repository is organized as follows: 25 | 26 | * `src` contains the main PyTorch library 27 | * `configs` contains the default configuration for `src/run_experiment.py` 28 | * `notebooks` contains a demo of using the VSGD optimizer 29 | 30 | 31 | ---- 32 | ### Reproduce 33 | 34 | ###### Create the conda environment *(recommended)* 35 | 36 | ```bash 37 | conda env create -f environment.yml 38 | conda activate vsgd 39 | ``` 40 | 41 | ###### Log in to wandb *(recommended)* 42 | ```bash 43 | wandb login 44 | ``` 45 | 46 | ###### Download the TinyImageNet dataset 47 | 48 | ```bash 49 | mkdir -p data && cd data/ 50 | wget http://cs231n.stanford.edu/tiny-imagenet-200.zip 51 | unzip tiny-imagenet-200.zip 52 | ``` 53 | 54 | ###### Starting an experiment 55 | All the experiments are
run with `src/run_experiment.py`. Experiment configuration is handled by [Hydra](https://hydra.cc); the default configuration can be found in the `configs/` folder. 56 | 57 | `configs/experiment/` contains configs for dataset-architecture pairs. For example, to train a VGG model on the CIFAR-100 dataset with the VSGD optimizer, run: 58 | ```bash 59 | PYTHONPATH=src/ python src/run_experiment.py experiment=cifar100_vgg train/optimizer=vsgd 60 | ``` 61 | 62 | Any default hyperparameter can also be overridden from the command line: 63 | ```bash 64 | PYTHONPATH=src/ python src/run_experiment.py experiment=cifar100_vgg train/optimizer=vsgd train.optimizer.weight_decay=0.01 65 | ``` 66 | 67 | 68 | ---- 69 | 70 | ### Cite 71 | If you find this work useful in your research, please consider citing: 72 | 73 | ``` 74 | @article{ 75 | chen2024variational, 76 | title={Variational Stochastic Gradient Descent for Deep Neural Networks}, 77 | author={Chen, Haotian and Kuzina, Anna and Esmaeili, Babak and Tomczak, Jakub}, 78 | year={2024}, 79 | } 80 | ``` 81 | 82 | ### Acknowledgements 83 | *Anna Kuzina is funded by the Hybrid Intelligence Center, a 10-year programme funded by the Dutch Ministry of Education, Culture and Science through the Netherlands Organisation for Scientific Research, https://hybrid-intelligence-centre.nl.
84 | This work was carried out on the Dutch national e-infrastructure with the support of SURF Cooperative.* 85 | 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | scripts/ 3 | notebooks/ 4 | pics/ 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to 
pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/ -------------------------------------------------------------------------------- /src/run_experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pprint import pprint 3 | 4 | import hydra.utils 5 | import numpy as np 6 | import omegaconf 7 | import torch 8 | import wandb 9 | from hydra.utils import instantiate 10 | 11 | import utils.tester as tester 12 | import utils.trainer as trainer 13 | from utils.wandb import get_checkpoint 14 | 15 | 16 | def params_to(param, device): 17 | param.data = param.data.to(device) 18 | if param._grad is not None: 19 | param._grad.data = param._grad.data.to(device) 20 | 21 | 22 | def optimizer_to(optim, device): 23 | for param in optim.state.values(): 24 | if isinstance(param, torch.Tensor): 25 | params_to(param, device) 26 | elif isinstance(param, dict): 27 | for subparam in param.values(): 28 | if isinstance(subparam, torch.Tensor): 29 | params_to(subparam, device) 30 | 31 | 32 | def load_from_checkpoint(args, model, optimizer=None, scheduler=None): 33 | chpt = get_checkpoint( 34 | args.wandb.setup.entity, 35 | args.wandb.setup.project, 36 | args.train.resume_id, 37 | device="cpu", 38 | ) 39 | args.train.start_iter = chpt["iteration"] 40 | # Load model and ema model 41 | model.load_state_dict(chpt["model_state_dict"]) 42 | 43 | # Load optimizer 44 | if optimizer is not None: 45 | opt_state_dict = chpt["optimizer_state_dict"] 46 | optimizer.load_state_dict(opt_state_dict) 47 | optimizer_to(optimizer, args.train.device) 48 | 49 | # Load scheduler 50 | if scheduler is not None: 51 | scheduler_state_dict = chpt["scheduler_state_dict"] 52 | scheduler.load_state_dict(scheduler_state_dict) 53 | 54 | return args, model, optimizer, scheduler 55 | 56 | 57 | def init_wandb(args): 58 | wandb.require("service") 59 | 60 | tags = [ 61 | args.dataset.name, 62 | args.model.name, 63 | args.train.optimizer._target_, 64 | args.train.experiment_name, 65 | ] 66 | if 
args.train.resume_id is not None: 67 | wandb.init( 68 | **args.wandb.setup, 69 | id=args.train.resume_id, 70 | resume="must", 71 | settings=wandb.Settings(start_method="thread"), 72 | ) 73 | else: 74 | wandb_cfg = omegaconf.OmegaConf.to_container( 75 | args, resolve=True, throw_on_missing=True 76 | ) 77 | wandb.init( 78 | **args.wandb.setup, 79 | config=wandb_cfg, 80 | group=f"{args.model.name}_{args.dataset.name}" 81 | if args.wandb.group is None 82 | else args.wandb.group, 83 | tags=tags, 84 | dir=hydra.utils.get_original_cwd(), 85 | settings=wandb.Settings(start_method="thread"), 86 | ) 87 | pprint(wandb.run.config) 88 | # define our custom x axis metric 89 | wandb.define_metric("iter") 90 | for pref in ["train", "val", "pic"]: 91 | wandb.define_metric(f"{pref}/*", step_metric="iter") 92 | wandb.define_metric("val/loss", summary="min", step_metric="iter") 93 | wandb.define_metric("test/loss", summary="min", step_metric="iter") 94 | 95 | 96 | def compute_params(model, args): 97 | # add network size 98 | num_param = sum(p.numel() for p in model.parameters() if p.requires_grad) 99 | print(num_param) 100 | wandb.run.summary["num_parameters"] = num_param 101 | 102 | 103 | @hydra.main(version_base="1.3", config_path="../configs", config_name="defaults.yaml") 104 | def run(args: omegaconf.DictConfig) -> None: 105 | # set cuda visible devices 106 | if args.train.device[-1] == "0": 107 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 108 | args.train.device = "cuda" 109 | elif args.train.device[-1] == "1": 110 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 111 | args.train.device = "cuda" 112 | 113 | # Set the seed 114 | torch.manual_seed(args.train.seed) 115 | torch.cuda.manual_seed(args.train.seed) 116 | np.random.seed(args.train.seed) 117 | torch.backends.cudnn.deterministic = True 118 | torch.backends.cudnn.benchmark = False 119 | 120 | # ------------ 121 | # data 122 | # ------------ 123 | dset_params = {"root": os.path.join(hydra.utils.get_original_cwd(), "data/")} 124 | 
data_module = instantiate(args.dataset.data_module, **dset_params) 125 | data_module.setup() 126 | train_loader = data_module.train_dataloader() 127 | val_loader = data_module.val_dataloader() 128 | test_loader = data_module.test_dataloader() 129 | 130 | # ------------ 131 | # model & optimizer 132 | # ------------ 133 | model = instantiate(args.model) 134 | optimizer = instantiate(args.train.optimizer, params=model.parameters()) 135 | scheduler = None 136 | if hasattr(args.train, "scheduler"): 137 | scheduler = instantiate(args.train.scheduler, optimizer=optimizer) 138 | 139 | if args.train.resume_id is not None: 140 | print(f"Resume training {args.train.resume_id}") 141 | args, model, optimizer, scheduler = load_from_checkpoint( 142 | args, model, optimizer, scheduler 143 | ) 144 | 145 | model.train() 146 | model.to(args.train.device) 147 | 148 | # ------------ 149 | # logging 150 | # ------------ 151 | init_wandb(args) 152 | wandb.watch(model, **args.wandb.watch) 153 | compute_params(model, args) 154 | 155 | # ------------ 156 | # training 157 | # ------------ 158 | if args.train.start_iter < args.train.max_iter: 159 | trainer.train( 160 | args.train, 161 | train_loader, 162 | val_loader, 163 | test_loader, 164 | model, 165 | optimizer, 166 | scheduler, 167 | ) 168 | 169 | # ------------ 170 | # testing 171 | # ------------ 172 | model = instantiate(args.model) 173 | with omegaconf.open_dict(args): 174 | args.train.resume_id = wandb.run.id 175 | _, model, _, _ = load_from_checkpoint(args, model) 176 | model.to(args.train.device) 177 | 178 | tester.test( 179 | args.train, 180 | test_loader, 181 | model, 182 | ) 183 | print("Test finished") 184 | wandb.finish() 185 | 186 | 187 | if __name__ == "__main__": 188 | run() 189 | -------------------------------------------------------------------------------- /src/dataset/tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | from PIL 
import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | from dataset.data_module import DataModule, ToTensor 9 | 10 | 11 | class TinyImagenet(DataModule): 12 | def __init__( 13 | self, 14 | batch_size, 15 | test_batch_size, 16 | root, 17 | use_augmentations, 18 | ): 19 | super().__init__( 20 | batch_size=batch_size, 21 | test_batch_size=test_batch_size, 22 | root=root, 23 | ) 24 | self.__dict__.update(locals()) 25 | if use_augmentations: 26 | self.transforms = transforms.Compose( 27 | [ 28 | transforms.RandomCrop(size=64, padding=4), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.RandomAffine( 31 | degrees=45, translate=(0.1, 0.1), scale=(0.9, 1.1) 32 | ), 33 | transforms.ColorJitter( 34 | brightness=0.2, contrast=0.2, saturation=0.2 35 | ), 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 38 | ] 39 | ) 40 | self.test_transforms = transforms.Compose( 41 | [ 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 44 | ] 45 | ) 46 | 47 | else: 48 | self.transforms = transforms.Compose( 49 | [ 50 | transforms.RandomHorizontalFlip(), 51 | ToTensor(), 52 | ] 53 | ) 54 | self.test_transforms = transforms.Compose([ToTensor()]) 55 | 56 | def prepare_data(self) -> None: 57 | pass 58 | 59 | def setup(self): 60 | self.train = TinyImageNet(self.root, split="train", transform=self.transforms) 61 | self.val = TinyImageNet(self.root, split="val", transform=self.test_transforms) 62 | self.test = TinyImageNet(self.root, split="val", transform=self.test_transforms) 63 | 64 | 65 | EXTENSION = "JPEG" 66 | NUM_IMAGES_PER_CLASS = 500 67 | CLASS_LIST_FILE = "wnids.txt" 68 | VAL_ANNOTATION_FILE = "val_annotations.txt" 69 | 70 | 71 | class TinyImageNet(Dataset): 72 | """Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`. 
73 | Dataset code adapted from https://github.com/leemengtw/tiny-imagenet/blob/master/TinyImageNet.py 74 | 75 | Parameters 76 | ---------- 77 | root: string 78 | Root directory including `train`, `test` and `val` subdirectories. 79 | split: string 80 | Indicating which split to return as a data set. 81 | Valid option: [`train`, `test`, `val`] 82 | transform: torchvision.transforms 83 | A (series) of valid transformation(s). 84 | in_memory: bool 85 | Set to True if there is enough memory (about 5G) and want to minimize disk IO overhead. 86 | """ 87 | 88 | def __init__( 89 | self, 90 | root, 91 | split="train", 92 | transform=None, 93 | target_transform=None, 94 | in_memory=False, 95 | ): 96 | self.root = os.path.join(os.path.expanduser(root), "tiny-imagenet-200") 97 | self.split = split 98 | self.transform = transform 99 | self.target_transform = target_transform 100 | self.in_memory = in_memory 101 | self.split_dir = os.path.join(self.root, self.split) 102 | self.image_paths = sorted( 103 | glob.iglob( 104 | os.path.join(self.split_dir, "**", "*.%s" % EXTENSION), recursive=True 105 | ) 106 | ) 107 | self.labels = {} # fname - label number mapping 108 | self.images = [] # used for in-memory processing 109 | 110 | # build class label - number mapping 111 | with open(os.path.join(self.root, CLASS_LIST_FILE), "r") as fp: 112 | self.label_texts = sorted([text.strip() for text in fp.readlines()]) 113 | self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)} 114 | 115 | if self.split == "train": 116 | for label_text, i in self.label_text_to_number.items(): 117 | for cnt in range(NUM_IMAGES_PER_CLASS): 118 | self.labels["%s_%d.%s" % (label_text, cnt, EXTENSION)] = i 119 | elif self.split == "val": 120 | with open(os.path.join(self.split_dir, VAL_ANNOTATION_FILE), "r") as fp: 121 | for line in fp.readlines(): 122 | terms = line.split("\t") 123 | file_name, label_text = terms[0], terms[1] 124 | self.labels[file_name] = 
self.label_text_to_number[label_text] 125 | 126 | # read all images into torch tensor in memory to minimize disk IO overhead 127 | if self.in_memory: 128 | self.images = [self.read_image(path) for path in self.image_paths] 129 | 130 | def __len__(self): 131 | return len(self.image_paths) 132 | 133 | def __getitem__(self, index): 134 | file_path = self.image_paths[index] 135 | 136 | if self.in_memory: 137 | img = self.images[index] 138 | else: 139 | img = self.read_image(file_path) 140 | if self.split == "test": 141 | return img 142 | else: 143 | # file_name = file_path.split('/')[-1] 144 | return img, self.labels[os.path.basename(file_path)] 145 | 146 | def __repr__(self): 147 | fmt_str = "Dataset " + self.__class__.__name__ + "\n" 148 | fmt_str += " Number of datapoints: {}\n".format(self.__len__()) 149 | tmp = self.split 150 | fmt_str += " Split: {}\n".format(tmp) 151 | fmt_str += " Root Location: {}\n".format(self.root) 152 | tmp = " Transforms (if any): " 153 | fmt_str += "{0}{1}\n".format( 154 | tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 155 | ) 156 | tmp = " Target Transforms (if any): " 157 | fmt_str += "{0}{1}".format( 158 | tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 159 | ) 160 | return fmt_str 161 | 162 | def read_image(self, path): 163 | # img = imageio.imread(path, pilmode='RGB') 164 | img = Image.open(path) 165 | img = img.convert("RGB") 166 | return self.transform(img) if self.transform else img 167 | -------------------------------------------------------------------------------- /src/vsgd.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | 8 | class VSGD(Optimizer): 9 | def __init__( 10 | self, 11 | params: required, 12 | ghattg: float = 30.0, 13 | ps: float = 1e-8, 14 | tau1: float = 0.81, 15 | tau2: float = 0.9, 16 | 
lr: float = 0.1, 17 | weight_decay: float = 0.0, 18 | eps: float = 1e-8, 19 | ): 20 | """ 21 | Args: 22 | ghattg: prior variance ratio between ghat and g, 23 | Var(ghat_t-g_t)/Var(g_t-g_{t-1}). 24 | ps: piror strength. 25 | tau1: remember rate for the gamma parameters of g 26 | tau2: remember rate for the gamma parameter of ghat 27 | lr: learning rate. 28 | weight_decay (float): weight decay coefficient (default: 0.0) 29 | """ 30 | 31 | if not 0.0 <= weight_decay: 32 | raise ValueError(f"Invalid weight_decay value: {weight_decay}") 33 | defaults = dict( 34 | ghattg=ghattg, 35 | ps=ps, 36 | tau1=tau1, 37 | tau2=tau2, 38 | lr=lr, 39 | weight_decay=weight_decay, 40 | eps=eps, 41 | ) 42 | super().__init__(params, defaults) 43 | 44 | def __setstate__(self, state): 45 | super(VSGD, self).__setstate__(state) 46 | 47 | def step(self, closure=None): 48 | """Performs a single optimization step. 49 | 50 | Args: 51 | closure (Callable, optional): A closure that reevaluates the model 52 | and returns the loss. 
53 | """ 54 | # self._cuda_graph_capture_health_check() 55 | 56 | loss = None 57 | if closure is not None: 58 | with torch.enable_grad(): 59 | loss = closure() 60 | 61 | for group in self.param_groups: 62 | params_with_grad = [] 63 | grads = [] 64 | mug_list = [] 65 | step_list = [] 66 | pa2_list = [] 67 | pbg2_list = [] 68 | pbhg2_list = [] 69 | bg_list = [] 70 | bhg_list = [] 71 | 72 | self._init_group( 73 | group, 74 | params_with_grad, 75 | grads, 76 | mug_list, 77 | step_list, 78 | pa2_list, 79 | pbg2_list, 80 | pbhg2_list, 81 | bg_list, 82 | bhg_list, 83 | group["ghattg"], 84 | group["ps"], 85 | ) 86 | 87 | vsgd( 88 | params_with_grad, 89 | grads, 90 | mug_list, 91 | step_list, 92 | pa2_list, 93 | pbg2_list, 94 | pbhg2_list, 95 | bg_list, 96 | bhg_list, 97 | group["tau1"], 98 | group["tau2"], 99 | group["lr"], 100 | group["weight_decay"], 101 | group["eps"], 102 | ) 103 | 104 | return loss 105 | 106 | def _init_group( 107 | self, 108 | group, 109 | params_with_grad: List[Tensor], 110 | grads: List[Tensor], 111 | mug_list: List, 112 | step_list: List, 113 | pa2_list: List, 114 | pbg2_list: List, 115 | pbhg2_list: List, 116 | bg_list: List, 117 | bhg_list: List, 118 | ghattg: float, 119 | ps: float, 120 | ): 121 | for p in group["params"]: 122 | if p.grad is None: 123 | continue 124 | params_with_grad.append(p) 125 | 126 | grads.append(p.grad) 127 | state = self.state[p] 128 | 129 | # State initialization 130 | if len(state) == 0: 131 | for k in ["mug", "bg", "bhg"]: 132 | # set a non zero small number to represent prior ignornance 133 | state[k] = torch.zeros_like(p, memory_format=torch.preserve_format) 134 | # initialize 2*a_0 and 2*b_0 as constants 135 | state["pa2"] = torch.tensor(2.0 * ps + 1.0 + 1e-4) 136 | state["pbg2"] = torch.tensor(2.0 * ps) 137 | state["pbhg2"] = torch.tensor(2.0 * ghattg * ps) 138 | state["step"] = torch.tensor(0.0) 139 | 140 | mug_list.append(state["mug"]) 141 | bg_list.append(state["bg"]) 142 | bhg_list.append(state["bhg"]) 143 | 
step_list.append(state["step"]) 144 | pa2_list.append(state["pa2"]) 145 | pbg2_list.append(state["pbg2"]) 146 | pbhg2_list.append(state["pbhg2"]) 147 | 148 | def get_current_beta1_estimate(self) -> Tensor: 149 | betas = [] 150 | for group in self.param_groups: 151 | for p in group["params"]: 152 | state = self.state[p] 153 | bg = state["bg"] 154 | bhg = state["bhg"] 155 | betas.append((bhg / (bg + bhg)).data) 156 | return betas 157 | 158 | 159 | def vsgd( 160 | params_with_grad: List[Tensor], 161 | grads: List[Tensor], 162 | mug_list: List[Tensor], 163 | step_list: List[Tensor], 164 | pa2_list: List[Tensor], 165 | pbg2_list: List[Tensor], 166 | pbhg2_list: List[Tensor], 167 | bg_list: List[Tensor], 168 | bhg_list: List[Tensor], 169 | tau1: float, 170 | tau2: float, 171 | lr: float, 172 | weight_decay: float, 173 | eps: float, 174 | ): 175 | for i, param in enumerate(params_with_grad): 176 | ghat = grads[i] 177 | mug = mug_list[i] 178 | mug1 = torch.clone(mug) 179 | step = step_list[i] 180 | step += 1 181 | pa2 = pa2_list[i] 182 | pbg2 = pbg2_list[i] 183 | pbhg2 = pbhg2_list[i] 184 | bg = bg_list[i] 185 | bhg = bhg_list[i] 186 | # weight decay following AdamW 187 | param.data.mul_(1 - lr * weight_decay) 188 | 189 | # variances of g and ghat 190 | if step == 1.0: 191 | sg = pbg2 / (pa2 - 1.0) 192 | shg = pbhg2 / (pa2 - 1.0) 193 | else: 194 | sg = bg / pa2 195 | shg = bhg / pa2 196 | # update muh, mug, Sigg and Sigh 197 | mug.copy_((ghat * sg + mug1 * shg) / (sg + shg)) 198 | sigg = sg * shg / (sg + shg) 199 | 200 | # update 2*b 201 | mug_sq = sigg + mug**2 202 | bg2 = pbg2 + mug_sq - 2.0 * mug * mug1 + mug1**2 203 | bhg2 = pbhg2 + mug_sq - 2.0 * ghat * mug + ghat**2 204 | 205 | rho1 = step ** (-tau1) 206 | rho2 = step ** (-tau2) 207 | bg.mul_(1.0 - rho1).add_(bg2, alpha=rho1) 208 | bhg.mul_(1.0 - rho2).add_(bhg2, alpha=rho2) 209 | 210 | # update param 211 | param.data.add_(lr / (torch.sqrt(mug_sq) + eps) * mug, alpha=-1.0) 212 | 
-------------------------------------------------------------------------------- /src/model/resnext.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | """ 4 | Creates a ResNeXt Model as defined in: 5 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). 6 | Aggregated residual transformations for deep neural networks. 7 | arXiv preprint arXiv:1611.05431. 8 | import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py 9 | """ 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import init 13 | 14 | from model.classifier import ClassifierWrapper 15 | 16 | __all__ = ["resnext"] 17 | 18 | 19 | class ResNeXtBottleneck(nn.Module): 20 | """ 21 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor): 25 | """Constructor 26 | Args: 27 | in_channels: input channel dimensionality 28 | out_channels: output channel dimensionality 29 | stride: conv stride. Replaces pooling layer. 30 | cardinality: num of convolution groups. 31 | widen_factor: factor to reduce the input dimensionality before convolution. 
32 | """ 33 | super(ResNeXtBottleneck, self).__init__() 34 | D = cardinality * out_channels // widen_factor 35 | self.conv_reduce = nn.Conv2d( 36 | in_channels, D, kernel_size=1, stride=1, padding=0, bias=False 37 | ) 38 | self.bn_reduce = nn.BatchNorm2d(D) 39 | self.conv_conv = nn.Conv2d( 40 | D, 41 | D, 42 | kernel_size=3, 43 | stride=stride, 44 | padding=1, 45 | groups=cardinality, 46 | bias=False, 47 | ) 48 | self.bn = nn.BatchNorm2d(D) 49 | self.conv_expand = nn.Conv2d( 50 | D, out_channels, kernel_size=1, stride=1, padding=0, bias=False 51 | ) 52 | self.bn_expand = nn.BatchNorm2d(out_channels) 53 | 54 | self.shortcut = nn.Sequential() 55 | if in_channels != out_channels: 56 | self.shortcut.add_module( 57 | "shortcut_conv", 58 | nn.Conv2d( 59 | in_channels, 60 | out_channels, 61 | kernel_size=1, 62 | stride=stride, 63 | padding=0, 64 | bias=False, 65 | ), 66 | ) 67 | self.shortcut.add_module("shortcut_bn", nn.BatchNorm2d(out_channels)) 68 | 69 | def forward(self, x): 70 | bottleneck = self.conv_reduce.forward(x) 71 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True) 72 | bottleneck = self.conv_conv.forward(bottleneck) 73 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True) 74 | bottleneck = self.conv_expand.forward(bottleneck) 75 | bottleneck = self.bn_expand.forward(bottleneck) 76 | residual = self.shortcut.forward(x) 77 | return F.relu(residual + bottleneck, inplace=True) 78 | 79 | 80 | class ResNeXt(ClassifierWrapper): 81 | """ 82 | ResNext optimized for the Cifar dataset, as specified in 83 | https://arxiv.org/pdf/1611.05431.pdf 84 | """ 85 | 86 | def __init__( 87 | self, 88 | cardinality, 89 | depth, 90 | num_classes, 91 | widen_factor=4, 92 | dropRate=0, 93 | loss_fn=nn.CrossEntropyLoss(), 94 | **kwargs 95 | ): 96 | """Constructor 97 | Args: 98 | cardinality: number of convolution groups. 99 | depth: number of layers. 
100 | num_classes: number of classes 101 | widen_factor: factor to adjust the channel dimensionality 102 | """ 103 | super().__init__(backbone=nn.Identity(), loss_fn=loss_fn) 104 | self.cardinality = cardinality 105 | self.depth = depth 106 | self.block_depth = (self.depth - 2) // 9 107 | self.widen_factor = widen_factor 108 | self.num_classes = num_classes 109 | self.output_size = 64 110 | self.stages = [ 111 | 64, 112 | 64 * self.widen_factor, 113 | 128 * self.widen_factor, 114 | 256 * self.widen_factor, 115 | ] 116 | 117 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 118 | self.bn_1 = nn.BatchNorm2d(64) 119 | self.stage_1 = self.block("stage_1", self.stages[0], self.stages[1], 1) 120 | self.stage_2 = self.block("stage_2", self.stages[1], self.stages[2], 2) 121 | self.stage_3 = self.block("stage_3", self.stages[2], self.stages[3], 2) 122 | self.classifier = nn.Linear(1024, num_classes) 123 | init.kaiming_normal(self.classifier.weight) 124 | 125 | for key in self.state_dict(): 126 | if key.split(".")[-1] == "weight": 127 | if "conv" in key: 128 | init.kaiming_normal(self.state_dict()[key], mode="fan_out") 129 | if "bn" in key: 130 | self.state_dict()[key][...] = 1 131 | elif key.split(".")[-1] == "bias": 132 | self.state_dict()[key][...] = 0 133 | 134 | def block(self, name, in_channels, out_channels, pool_stride=2): 135 | """Stack n bottleneck modules where n is inferred from the depth of the network. 136 | Args: 137 | name: string name of the current block. 138 | in_channels: number of input channels 139 | out_channels: number of output channels 140 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 141 | Returns: a Module consisting of n sequential bottlenecks. 
142 | """ 143 | block = nn.Sequential() 144 | for bottleneck in range(self.block_depth): 145 | name_ = "%s_bottleneck_%d" % (name, bottleneck) 146 | if bottleneck == 0: 147 | block.add_module( 148 | name_, 149 | ResNeXtBottleneck( 150 | in_channels, 151 | out_channels, 152 | pool_stride, 153 | self.cardinality, 154 | self.widen_factor, 155 | ), 156 | ) 157 | else: 158 | block.add_module( 159 | name_, 160 | ResNeXtBottleneck( 161 | out_channels, 162 | out_channels, 163 | 1, 164 | self.cardinality, 165 | self.widen_factor, 166 | ), 167 | ) 168 | return block 169 | 170 | def forward(self, batch): 171 | x = self.conv_1_3x3.forward(batch[0]) 172 | x = F.relu(self.bn_1.forward(x), inplace=True) 173 | x = self.stage_1.forward(x) 174 | x = self.stage_2.forward(x) 175 | x = self.stage_3.forward(x) 176 | x = F.avg_pool2d(x, kernel_size=x.shape[-1], stride=1) 177 | x = x.view(-1, 1024) 178 | return self.classifier(x) 179 | -------------------------------------------------------------------------------- /src/utils/trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import time 4 | 5 | import torch 6 | import wandb 7 | 8 | from utils.tester import test 9 | 10 | 11 | def save_chpt(args, iteration, model, optimizer, scheduler, loss, name="last_chpt"): 12 | chpt = { 13 | "iteration": iteration, 14 | "model_state_dict": model.state_dict(), 15 | "optimizer_state_dict": optimizer.state_dict(), 16 | "scheduler_state_dict": None if scheduler is None else scheduler.state_dict(), 17 | "loss": loss, 18 | } 19 | torch.save(chpt, os.path.join(wandb.run.dir, f"{name}.pth")) 20 | wandb.save(os.path.join(wandb.run.dir, f"{name}.pth"), base_path=wandb.run.dir) 21 | print("->model saved<-\n") 22 | 23 | 24 | def train( 25 | args, 26 | train_loader, 27 | val_loader, 28 | test_loader, 29 | model, 30 | optimizer, 31 | scheduler, 32 | ): 33 | with torch.no_grad(): 34 | if val_loader is not None: 35 | # compute metrics on 
initialization 36 | batch = next(val_loader) 37 | history_val = run_iter( 38 | args=args, 39 | iteration=args.start_iter, 40 | batch=batch, 41 | model=model, 42 | optimizer=None, 43 | mode="val", 44 | ) 45 | wandb.log({**history_val, "iter": args.start_iter}) 46 | 47 | for iteration in range(args.start_iter, args.max_iter): 48 | batch = next(train_loader) 49 | 50 | time_start = time.time() 51 | history_train = run_iter( 52 | args, 53 | iteration=iteration, 54 | batch=batch, 55 | model=model, 56 | optimizer=optimizer, 57 | mode="train", 58 | ) 59 | 60 | train_elapsed = time.time() - time_start 61 | time_start = time.time() 62 | history_val = {} 63 | 64 | if val_loader is not None: 65 | batch = next(val_loader) 66 | with torch.no_grad(): 67 | history_val = run_iter( 68 | args, 69 | iteration=iteration + 1, 70 | batch=batch, 71 | model=model, 72 | optimizer=None, 73 | mode="val", 74 | ) 75 | 76 | if scheduler is not None: 77 | if scheduler.__class__.__name__ == "ReduceLROnPlateau": 78 | scheduler.step(history_val["val/loss"]) 79 | else: 80 | scheduler.step() 81 | 82 | val_elapsed = time.time() - time_start 83 | hist = { 84 | **history_train, 85 | **history_val, 86 | "train_time": train_elapsed, 87 | "val_time": val_elapsed, 88 | } 89 | 90 | # save metrics to wandb 91 | wandb.log(hist) 92 | # save checkpoint 93 | if iteration % args.save_freq == 0 or iteration == args.max_iter: 94 | loss = hist["train/loss"] 95 | if "val/loss" in hist.keys(): 96 | loss = hist["val/loss"] 97 | save_chpt( 98 | args, 99 | iteration, 100 | model, 101 | optimizer, 102 | scheduler, 103 | loss, 104 | ) 105 | 106 | if iteration % 100 == 0: 107 | print( 108 | "Iteration: {}/{}, Time elapsed: {:.2f}s\n" 109 | "* Train loss: {:.2f} \n".format( 110 | iteration + 1, 111 | args.max_iter, 112 | val_elapsed + train_elapsed, 113 | hist["train/loss"], 114 | ) 115 | ) 116 | if "val/loss" in hist.keys(): 117 | if math.isnan(hist["val/loss"]): 118 | print("Nan loss, stopping training") 119 | break 120 | 
121 | # run test eval to track the performance 122 | if (iteration + 1) % args.eval_test_freq == 0 and ( 123 | iteration + 1 124 | ) < args.max_iter: 125 | print("Run test evaluation...") 126 | with torch.no_grad(): 127 | test( 128 | args=args, 129 | loader=test_loader, 130 | model=model, 131 | ) 132 | 133 | print("Save last checkpoint") 134 | loss = hist["train/loss"] 135 | if "val/loss" in hist.keys(): 136 | loss = hist["val/loss"] 137 | save_chpt(args, args.max_iter, model, optimizer, scheduler, loss) 138 | 139 | 140 | def run_iter(args, iteration, batch, model, optimizer, mode="train"): 141 | if mode == "train": 142 | model.train() 143 | try: 144 | lr = optimizer.param_groups[0]["lr"] 145 | except: 146 | lr = 0.0 147 | history = {"lr": lr, "iter": iteration + 1} 148 | if iteration > 0: 149 | if args.optimizer_log_freq > 0 and iteration % args.optimizer_log_freq == 0: 150 | if hasattr(optimizer, "get_current_beta1_estimate"): 151 | vals = optimizer.get_current_beta1_estimate() 152 | vals = torch.cat([x.reshape(-1) for x in vals]).reshape(1, -1).cpu() 153 | history["beta1"] = wandb.Histogram(vals) 154 | history["beta1_median"] = vals.median() 155 | history["beta1_mean"] = vals.mean() 156 | elif "betas" in optimizer.param_groups[0]: 157 | beta1 = optimizer.param_groups[0]["betas"][0] 158 | bias_correction = 1 - beta1**iteration 159 | history["beta1_median"] = beta1 / bias_correction 160 | history["beta1_mean"] = beta1 / bias_correction 161 | elif "momentum" in optimizer.param_groups[0]: 162 | beta1 = optimizer.param_groups[0]["momentum"] 163 | history["beta1_median"] = beta1 164 | history["beta1_mean"] = beta1 165 | 166 | elif mode == "val": 167 | model.eval() 168 | history = {} 169 | 170 | if "cuda" in args.device: 171 | for i in range(len(batch)): 172 | batch[i] = batch[i].cuda(non_blocking=True) 173 | # Loss 174 | logs = {} 175 | if mode == "train": 176 | loss, logs = model.train_step(batch, device=args.device) 177 | elif mode == "val": 178 | with 
torch.no_grad(): 179 | loss, logs = model.train_step(batch, device=args.device) 180 | 181 | if mode == "train": 182 | optimize(args, loss, model, optimizer) 183 | 184 | # Get the history 185 | for k in logs.keys(): 186 | h_key = k 187 | if "/" not in k: 188 | h_key = f"{mode}/{k}" 189 | if "hist" in k: 190 | history[h_key] = wandb.Histogram(logs[k]) 191 | else: 192 | history[h_key] = logs[k] 193 | 194 | return history 195 | 196 | 197 | def optim_step(params, optimizer, grad_clip_val, grad_skip_val): 198 | # clip gradient 199 | grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip_val).item() 200 | 201 | if grad_skip_val == 0 or grad_norm < grad_skip_val: 202 | optimizer.step() 203 | return grad_norm 204 | 205 | 206 | def optimize(args, loss, model, optimizer): 207 | if args.grad_clip > 0: 208 | clip_to = args.grad_clip 209 | else: 210 | clip_to = 1e6 211 | 212 | logs = {"skipped_steps": 1} 213 | nans = torch.isnan(loss).sum().item() 214 | if nans == 0: 215 | logs["skipped_steps"] = 0 216 | # backprop through the main loss 217 | optimizer.zero_grad() 218 | loss.backward() 219 | params = [p for n, p in model.named_parameters() if p.requires_grad] 220 | 221 | grad_norm = optim_step( 222 | params=params, 223 | optimizer=optimizer, 224 | grad_clip_val=clip_to, 225 | grad_skip_val=args.grad_skip_thr, 226 | ) 227 | logs["grad_norm"] = grad_norm 228 | 229 | wandb.log(logs) 230 | -------------------------------------------------------------------------------- /notebooks/vsgd_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import matplotlib as mpl\n", 10 | "import matplotlib\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "\n", 13 | "import numpy as np\n", 14 | "import os\n", 15 | "import torch\n", 16 | "\n", 17 | "mpl.rcParams['text.usetex'] = True\n", 18 | 
"mpl.rcParams['text.latex.preamble'] = r'\\usepackage{amsmath}'\n", 19 | "plt.rcParams['figure.figsize'] = [9, 7]\n", 20 | "\n", 21 | "\n", 22 | "# Append ../src to path\n", 23 | "import sys\n", 24 | "source_path = os.path.join(os.getcwd(), '../src')\n", 25 | "if source_path not in sys.path:\n", 26 | " sys.path.append(source_path)\n", 27 | " " 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## 1. Get the datasets" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "from torchvision.datasets import MNIST\n", 44 | "from torch.utils.data import DataLoader\n", 45 | "from torchvision import transforms\n", 46 | "from torchvision.transforms import ToTensor\n", 47 | "\n", 48 | "train_dataset = MNIST(root='../data/', download=True, train=True, transform=ToTensor())\n", 49 | "test_dataset = MNIST(root='../data/', download=True, train=False, transform=ToTensor())\n" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## 2. 
Define NN\n", 57 | "\n", 58 | "In this example, we use the convmixer architecture\n", 59 | "\n", 60 | "\n", 61 | "code source: https://github.com/locuslab/convmixer-cifar10/blob/main/train.py" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "import torch.nn as nn \n", 71 | "\n", 72 | "class Residual(nn.Module):\n", 73 | " def __init__(self, fn):\n", 74 | " super().__init__()\n", 75 | " self.fn = fn\n", 76 | "\n", 77 | " def forward(self, x):\n", 78 | " return self.fn(x) + x\n", 79 | " \n", 80 | "def build_convmixer(dim, patch_size, kernel_size, depth, num_classes):\n", 81 | " return nn.Sequential(\n", 82 | " nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size),\n", 83 | " nn.GELU(),\n", 84 | " nn.BatchNorm2d(dim),\n", 85 | " *[\n", 86 | " nn.Sequential(\n", 87 | " Residual(\n", 88 | " nn.Sequential(\n", 89 | " nn.Conv2d(\n", 90 | " dim, dim, kernel_size, groups=dim, padding=\"same\"\n", 91 | " ),\n", 92 | " nn.GELU(),\n", 93 | " nn.BatchNorm2d(dim),\n", 94 | " )\n", 95 | " ),\n", 96 | " nn.Conv2d(dim, dim, kernel_size=1),\n", 97 | " nn.GELU(),\n", 98 | " nn.BatchNorm2d(dim),\n", 99 | " )\n", 100 | " for _ in range(depth)\n", 101 | " ],\n", 102 | " nn.AdaptiveAvgPool2d((1, 1)),\n", 103 | " nn.Flatten(),\n", 104 | " nn.Linear(dim, num_classes)\n", 105 | " )" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "## 3. 
Train and test functions" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 4, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "# Training / evaluation helpers reused by every optimizer comparison in section 4.\n", 122 | "def train(net, train_dataset, optimizer, max_epochs, batch_size):\n", 123 | " device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 124 | " net.to(device)\n", 125 | " net.train() # test() puts the net in eval mode; switch back to training mode explicitly\n", 126 | " criterion = nn.CrossEntropyLoss()\n", 127 | " trainloader = DataLoader(train_dataset, \n", 128 | " batch_size=batch_size, \n", 129 | " shuffle=True, \n", 130 | " num_workers=2)\n", 131 | "\n", 132 | " logs = {'train_loss': []}\n", 133 | " for epoch in range(max_epochs): # loop over the dataset multiple times\n", 134 | "\n", 135 | " running_loss = 0.0\n", 136 | " for i, data in enumerate(trainloader, 0):\n", 137 | " # get the inputs; data is a list of [inputs, labels]\n", 138 | " inputs, labels = data\n", 139 | " inputs = inputs.to(device)\n", 140 | " labels = labels.to(device)\n", 141 | " \n", 142 | "\n", 143 | " # zero the parameter gradients\n", 144 | " optimizer.zero_grad()\n", 145 | "\n", 146 | " # forward + backward + optimize\n", 147 | " outputs = net(inputs)\n", 148 | " loss = criterion(outputs, labels)\n", 149 | " loss.backward()\n", 150 | " optimizer.step()\n", 151 | "\n", 152 | " running_loss += loss.item()\n", 153 | " logs['train_loss'].append(loss.item())\n", 154 | " \n", 155 | " print(f\"Epoch {epoch+1}, loss: {running_loss / len(trainloader): .2f}\")\n", 156 | " print('Finished Training')\n", 157 | " return net, logs\n", 158 | "@torch.no_grad() # evaluation needs no autograd graph\n", 159 | "def test(test_dataset, net, batch_size):\n", 160 | " \"\"\"Return (mean per-sample cross-entropy loss, accuracy) of net on test_dataset.\"\"\"\n", 161 | " device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 162 | " net.to(device)\n", 163 | "\n", 164 | " criterion = nn.CrossEntropyLoss()\n", 165 | " testloader = DataLoader(test_dataset, \n", 166 | " batch_size=batch_size, \n", 167 | " shuffle=False, \n", 168 | " num_workers=2,\n", 169 | " drop_last=False,\n", 170 | " )\n", 171 | "\n", 172 |
" net.eval() # BatchNorm layers must use their running statistics, not per-batch ones\n", 173 | " correct_clfs = 0.\n", 174 | " running_loss = 0.\n", 175 | " for i, data in enumerate(testloader, 0):\n", 176 | " # get the inputs; data is a list of [inputs, labels]\n", 177 | " inputs, labels = data\n", 178 | " inputs = inputs.to(device)\n", 179 | " labels = labels.to(device)\n", 180 | " \n", 181 | " # forward \n", 182 | " logits = net(inputs)\n", 183 | " loss = criterion(logits, labels)\n", 184 | " \n", 185 | " running_loss += loss.item() * labels.size(0) # criterion averages over the batch; re-weight before dividing by N_points\n", 186 | "\n", 187 | " correct_clfs += (logits.argmax(dim=1) == labels).float().sum().item()\n", 188 | " N_points = len(test_dataset)\n", 189 | " return running_loss / N_points, correct_clfs / N_points\n" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "## 4. Compare to Adam and SGD" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 14, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "Epoch 1, loss: 0.27\n", 209 | "Epoch 2, loss: 0.10\n", 210 | "Epoch 3, loss: 0.09\n", 211 | "Finished Training\n", 212 | "Test accuracy: 98.08%\n" 213 | ] 214 | } 215 | ], 216 | "source": [ 217 | "from vsgd import VSGD\n", 218 | "\n", 219 | "net = build_convmixer(dim=16, patch_size=2, kernel_size=3, depth=4, num_classes=10)\n", 220 | "vsgd = VSGD(net.parameters(), lr=0.01, ps=1e-7)\n", 221 | "net, logs_vsgd = train(net, train_dataset, vsgd, max_epochs=3, batch_size=32)\n", 222 | "logs_vsgd['test_loss'], logs_vsgd['test_acc'] = test(test_dataset, net, 32)\n", 223 | "\n", 224 | "acc = logs_vsgd['test_acc']*100\n", 225 | "print(f'Test accuracy: {acc:.2f}%')" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 17, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "Epoch 1, loss: 0.57\n", 238 | "Epoch 2, loss: 0.13\n", 239 | "Epoch 3, loss: 0.10\n", 240 |
"Finished Training\n", 241 | "Test accuracy: 97.81%\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "from torch.optim import SGD\n", 247 | "\n", 248 | "sgd_net = build_convmixer(dim=16, patch_size=2, kernel_size=3, depth=4, num_classes=10)\n", 249 | "sgd = SGD(sgd_net.parameters(), lr=0.01, momentum=0.9)\n", 250 | "sgd_net, logs_sgd = train(sgd_net, train_dataset, sgd, max_epochs=3, batch_size=32)\n", 251 | "logs_sgd['test_loss'], logs_sgd['test_acc'] = test(test_dataset, sgd_net, 32)\n", 252 | "\n", 253 | "acc = logs_sgd['test_acc']*100\n", 254 | "print(f'Test accuracy: {acc:.2f}%')" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 15, 260 | "metadata": {}, 261 | "outputs": [ 262 | { 263 | "name": "stdout", 264 | "output_type": "stream", 265 | "text": [ 266 | "Epoch 1, loss: 0.24\n", 267 | "Epoch 2, loss: 0.10\n", 268 | "Epoch 3, loss: 0.09\n", 269 | "Finished Training\n", 270 | "Test accuracy: 97.68%\n" 271 | ] 272 | } 273 | ], 274 | "source": [ 275 | "from torch.optim import Adam\n", 276 | "\n", 277 | "adam_net = build_convmixer(dim=16, patch_size=2, kernel_size=3, depth=4, num_classes=10)\n", 278 | "adam = Adam(adam_net.parameters(), lr=0.01)\n", 279 | "adam_net, logs_adam = train(adam_net, train_dataset, adam, max_epochs=3, batch_size=32)\n", 280 | "logs_adam['test_loss'], logs_adam['test_acc'] = test(test_dataset, adam_net, batch_size=32)\n", 281 | "\n", 282 | "acc = logs_adam['test_acc']*100\n", 283 | "print(f'Test accuracy: {acc:.2f}%')" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 22, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "image/png": "", 294 | "text/plain": [ 295 | "
" 296 | ] 297 | }, 298 | "metadata": { 299 | "needs_background": "light" 300 | }, 301 | "output_type": "display_data" 302 | } 303 | ], 304 | "source": [ 305 | "fig, ax = plt.subplots(1, 2, figsize=(12, 5))\n", 306 | "ax[0].plot(logs_vsgd['train_loss'], label='VSGD', alpha=0.7)\n", 307 | "ax[0].plot(logs_sgd['train_loss'], label='SGD', alpha=0.7)\n", 308 | "ax[0].plot(logs_adam['train_loss'], label='Adam', alpha=0.7)\n", 309 | "ax[0].legend(fontsize=20)\n", 310 | "ax[0].grid()\n", 311 | "ax[0].set_title('Training loss', fontsize=24)\n", 312 | "\n", 313 | "ax[1].bar(['VSGD', 'SGD', 'Adam'], [logs_vsgd['test_acc'], logs_sgd['test_acc'], logs_adam['test_acc']]);\n", 314 | "ax[1].grid()\n", 315 | "ax[1].set_ylim(0.97, 1.0);\n", 316 | "ax[1].set_title('Test accuracy', fontsize=24);" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [] 325 | } 326 | ], 327 | "metadata": { 328 | "kernelspec": { 329 | "display_name": "BASE_ENV", 330 | "language": "python", 331 | "name": "base_env" 332 | }, 333 | "language_info": { 334 | "codemirror_mode": { 335 | "name": "ipython", 336 | "version": 3 337 | }, 338 | "file_extension": ".py", 339 | "mimetype": "text/x-python", 340 | "name": "python", 341 | "nbconvert_exporter": "python", 342 | "pygments_lexer": "ipython3", 343 | "version": "3.6.12" 344 | } 345 | }, 346 | "nbformat": 4, 347 | "nbformat_minor": 4 348 | } 349 | --------------------------------------------------------------------------------