├── pics
├── pgm_vsgd_2.png
└── cifar100_res.png
├── configs
├── train
│ ├── optimizer
│ │ ├── sgd.yaml
│ │ ├── adam.yaml
│ │ ├── adamW.yaml
│ │ └── vsgd.yaml
│ ├── scheduler
│ │ ├── stepLR.yaml
│ │ ├── cosine.yaml
│ │ └── plateau.yaml
│ └── defaults.yaml
├── wandb
│ └── defaults.yaml
├── defaults.yaml
├── model
│ ├── resnext.yaml
│ ├── vgg.yaml
│ ├── convmixer.yaml
│ └── resnet.yaml
├── dataset
│ ├── cifar100.yaml
│ └── tiny_imagenet.yaml
└── experiment
│ ├── cifar100_vgg.yaml
│ ├── tiny_imagenet_vgg.yaml
│ ├── cifar100_convmixer.yaml
│ ├── tiny_imagenet_convmixer.yaml
│ ├── tiny_imagenet_resnext.yaml
│ └── cifar100_resnext.yaml
├── _config.yml
├── src
├── utils
│ ├── wandb.py
│ ├── tester.py
│ └── trainer.py
├── model
│ ├── classifier.py
│ ├── convmixer.py
│ ├── vgg.py
│ └── resnext.py
├── dataset
│ ├── cifar10.py
│ ├── cifar100.py
│ ├── data_module.py
│ └── tiny_imagenet.py
├── run_experiment.py
└── vsgd.py
├── environment.yml
├── LICENSE
├── README.md
├── .gitignore
└── notebooks
└── vsgd_example.ipynb
/pics/pgm_vsgd_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/generativeai-tue/vsgd/HEAD/pics/pgm_vsgd_2.png
--------------------------------------------------------------------------------
/pics/cifar100_res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/generativeai-tue/vsgd/HEAD/pics/cifar100_res.png
--------------------------------------------------------------------------------
/configs/train/optimizer/sgd.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.SGD
2 | params: null
3 | lr: 0.1
4 | momentum: 0.
5 | weight_decay: 0.
--------------------------------------------------------------------------------
/configs/train/scheduler/stepLR.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | _target_: torch.optim.lr_scheduler.StepLR
3 | step_size: 5000
4 | gamma: 0.5
5 | last_epoch: -1
--------------------------------------------------------------------------------
/configs/train/scheduler/cosine.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR
3 | T_max: ${train.max_iter}
4 | eta_min: 0.
5 |
--------------------------------------------------------------------------------
/configs/wandb/defaults.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | setup:
3 | project: vsgd
4 | mode: online
5 | watch:
6 | log: gradients
7 | log_freq: 1000
8 | group: null
--------------------------------------------------------------------------------
/configs/train/optimizer/adam.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.Adam
2 | params: null
3 | lr: 0.1
4 | eps: 1e-08
5 | weight_decay: 0.
6 | betas:
7 | - 0.9
8 | - 0.999
--------------------------------------------------------------------------------
/configs/train/optimizer/adamW.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.AdamW
2 | params: null
3 | lr: 0.1
4 | eps: 1e-08
5 | weight_decay: 0.
6 | betas:
7 | - 0.9
8 | - 0.999
--------------------------------------------------------------------------------
/configs/train/optimizer/vsgd.yaml:
--------------------------------------------------------------------------------
1 | _target_: vsgd.VSGD
2 | params: null
3 | ghattg: 30.0
4 | ps: 1e-8
5 | tau2: 0.9
6 | tau1: 0.81
7 | lr: 0.1
8 | weight_decay: 0.0
--------------------------------------------------------------------------------
/configs/train/scheduler/plateau.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
3 | factor: 0.5
4 | patience: 10
5 | threshold: 1e-3
6 | cooldown: 0
7 | min_lr: 1e-5
--------------------------------------------------------------------------------
/configs/defaults.yaml:
--------------------------------------------------------------------------------
# @package _global_

defaults:
  - _self_
  # NOTE(review): the configs/model directory listed for this repository holds
  # resnet.yaml, resnext.yaml, vgg.yaml and convmixer.yaml but no resnet18.yaml;
  # confirm this default resolves when no `experiment=` override is given.
  - model: resnet18
  # NOTE(review): likewise configs/dataset only lists cifar100.yaml and
  # tiny_imagenet.yaml (no cifar10.yaml) -- confirm before relying on the default.
  - dataset: cifar10
  - train: defaults
  - wandb: defaults
  # No experiment preset by default; select one with `experiment=<name>`.
  - experiment: null
--------------------------------------------------------------------------------
/configs/model/resnext.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | _target_: model.resnext.ResNeXt
3 | cardinality: 8
4 | depth: 29
5 | widen_factor: 4
6 | dropRate: 0
7 | num_classes: ${dataset.num_classes}
8 | name: resnext
9 |
--------------------------------------------------------------------------------
/configs/model/vgg.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | _target_: model.vgg.VGG
3 | cfg_id: A # A, B, D, E for 11, 13, 16, 19 layers of VGG
4 | batch_norm: true
5 | num_classes: ${dataset.num_classes}
6 | name: vgg
7 |
--------------------------------------------------------------------------------
/configs/model/convmixer.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | _target_: model.convmixer.ConvMixer
3 | dim: 256
4 | depth: 8
5 | kernel_size: 5
6 | patch_size: 2
7 | num_classes: ${dataset.num_classes}
8 | name: convmixer
9 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
2 |
3 |
4 | title: "Variational Stochastic Gradient Descent"
5 | description: "Code repository of the paper Variational Stochastic Gradient Descent for Deep Neural Networks"
6 |
7 |
--------------------------------------------------------------------------------
/configs/dataset/cifar100.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | data_module:
3 | _target_: dataset.cifar100.Cifar100
4 | batch_size: 64
5 | test_batch_size: 1024
6 | use_augmentations: true
7 | x_dim: 32
8 | num_classes: 100
9 | name: cifar100
--------------------------------------------------------------------------------
/configs/dataset/tiny_imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | data_module:
3 | _target_: dataset.tiny_imagenet.TinyImagenet
4 | batch_size: 64
5 | test_batch_size: 1024
6 | use_augmentations: true
7 | x_dim: 64
8 | num_classes: 200
9 | name: tiny-imagenet-200
--------------------------------------------------------------------------------
/src/utils/wandb.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import wandb
5 |
6 |
def get_checkpoint(entity: str, project: str, idx: str, device: str = "cpu"):
    """Download the last checkpoint of a wandb run and load it with torch.

    Args:
        entity: wandb entity (user or team) that owns the run.
        project: wandb project the run belongs to.
        idx: id of the run.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object stored in the run's ``last_chpt.pth`` file.

    Raises:
        FileNotFoundError: if the run has no ``last_chpt.pth`` file.
    """
    # A wandb run path is always "entity/project/run_id" with forward slashes;
    # os.path.join would insert backslashes on Windows.
    run_path = "/".join((entity, project, idx))

    # Download the checkpoint from wandb to the local machine.
    file = wandb.restore("last_chpt.pth", run_path=run_path, replace=True)
    if file is None:
        raise FileNotFoundError(
            f"last_chpt.pth was not found in wandb run '{run_path}'"
        )

    # NOTE(review): torch.load unpickles the file -- only load checkpoints from
    # runs you trust.
    chpt = torch.load(file.name, map_location=device)
    return chpt
15 |
--------------------------------------------------------------------------------
/configs/train/defaults.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | defaults:
3 | - optimizer: adamW
4 | - scheduler: null
5 |
6 | seed: 0
7 | resume_id: null
8 | device: cuda
9 | start_iter: 0
10 | max_iter: 10000
11 | grad_clip: 0
12 | grad_skip_thr: 0 # skip the update step if the maximal grad norm is larger than this value (ignored if 0)
13 | save_freq: 1 # how often to save the checkpoint (in iterations)
14 | eval_test_freq: 10000 # how often to run evaluation on test dataset (in iterations)
15 | experiment_name: null
16 | optimizer_log_freq: -1
17 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: vsgd
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - anaconda
6 | - conda-forge
7 | - defaults
8 | - huggingface
9 | dependencies:
10 | - python=3.10.4
11 | - numpy
12 | - scipy
13 | - matplotlib
14 | - wandb
15 | - tqdm
16 | - imageio
17 | - pip
18 | - wheel
19 | - pytorch=1.12.1
20 | - torchvision=0.13.1
21 | - cudatoolkit=11.3
22 | - torchmetrics
23 | - transformers
24 | - pip:
25 | - hydra-core==1.3
26 | - torch-fidelity==0.3.0
27 | - black
28 |
29 |
--------------------------------------------------------------------------------
/configs/experiment/cifar100_vgg.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /dataset: cifar100
4 | - override /model: vgg
5 | - override /train/scheduler: stepLR
6 |
7 | dataset:
8 | data_module:
9 | batch_size: 256
10 | test_batch_size: 256
11 | use_augmentations: true
12 | model:
13 | cfg_id: D
14 | batch_norm: true
15 | train:
16 | experiment_name: test
17 | resume_id: null
18 | seed: 124
19 | device: 'cuda:0'
20 | save_freq: 500
21 | eval_test_freq: 500
22 | grad_clip: 0
23 | grad_skip_thr: 0
24 | max_iter: 30000
25 | optimizer_log_freq: 100
26 | scheduler:
27 | step_size: 10000
28 | gamma: 0.5
29 |
--------------------------------------------------------------------------------
/configs/experiment/tiny_imagenet_vgg.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /dataset: tiny_imagenet
4 | - override /model: vgg
5 | - override /train/scheduler: stepLR
6 |
7 | dataset:
8 | data_module:
9 | batch_size: 128
10 | test_batch_size: 128
11 | use_augmentations: true
12 | model:
13 | cfg_id: E
14 | batch_norm: true
15 | train:
16 | experiment_name: test
17 | resume_id: null
18 | seed: 124
19 | device: 'cuda:0'
20 | save_freq: 500
21 | eval_test_freq: 500
22 | grad_clip: 0
23 | grad_skip_thr: 0
24 | max_iter: 60000
25 | optimizer_log_freq: 100
26 | scheduler:
27 | step_size: 20000
28 | gamma: 0.5
--------------------------------------------------------------------------------
/configs/experiment/cifar100_convmixer.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /dataset: cifar100
4 | - override /model: convmixer
5 | - override /train/scheduler: stepLR
6 |
7 | dataset:
8 | data_module:
9 | batch_size: 256
10 | test_batch_size: 256
11 | use_augmentations: true
12 | model:
13 | dim: 256
14 | depth: 8
15 | kernel_size: 5
16 | patch_size: 2
17 | train:
18 | experiment_name: test
19 | resume_id: null
20 | seed: 124
21 | device: 'cuda:0'
22 | save_freq: 500
23 | eval_test_freq: 500
24 | grad_clip: 0
25 | grad_skip_thr: 0
26 | max_iter: 30000
27 | optimizer_log_freq: 100
28 | scheduler:
29 | step_size: 10000
30 | gamma: 0.5
31 |
--------------------------------------------------------------------------------
/configs/experiment/tiny_imagenet_convmixer.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /dataset: tiny_imagenet
4 | - override /model: convmixer
5 | - override /train/scheduler: stepLR
6 |
7 | dataset:
8 | data_module:
9 | batch_size: 128
10 | test_batch_size: 128
11 | use_augmentations: true
12 | model:
13 | dim: 256
14 | depth: 8
15 | kernel_size: 5
16 | patch_size: 2
17 | train:
18 | experiment_name: test
19 | resume_id: null
20 | seed: 124
21 | device: 'cuda:0'
22 | save_freq: 500
23 | eval_test_freq: 500
24 | grad_clip: 0
25 | grad_skip_thr: 0
26 | max_iter: 60000
27 | optimizer_log_freq: 100
28 | scheduler:
29 | step_size: 20000
30 | gamma: 0.5
--------------------------------------------------------------------------------
/configs/experiment/tiny_imagenet_resnext.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /dataset: tiny_imagenet
4 | - override /model: resnext
5 | - override /train/scheduler: stepLR
6 |
7 | dataset:
8 | data_module:
9 | batch_size: 64
10 | test_batch_size: 64
11 | use_augmentations: true
12 | model:
13 | cardinality: 8
14 | depth: 18
15 | widen_factor: 4
16 | dropRate: 0
17 | train:
18 | experiment_name: test
19 | resume_id: null
20 | seed: 124
21 | device: 'cuda:0'
22 | save_freq: 500
23 | eval_test_freq: 500
24 | grad_clip: 0
25 | grad_skip_thr: 0
26 | max_iter: 60000
27 | optimizer_log_freq: 100
28 | scheduler:
29 | step_size: 20000
30 | gamma: 0.5
--------------------------------------------------------------------------------
/configs/experiment/cifar100_resnext.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /dataset: cifar100
4 | - override /model: resnext
5 | - override /train/scheduler: stepLR
6 |
7 | dataset:
8 | data_module:
9 | batch_size: 128
10 | test_batch_size: 128
11 | use_augmentations: true
12 | model:
13 | cardinality: 8
14 | depth: 18
15 | widen_factor: 4
16 | dropRate: 0
17 | train:
18 | experiment_name: test
19 | resume_id: null
20 | seed: 124
21 | device: 'cuda:0'
22 | save_freq: 500
23 | eval_test_freq: 500
24 | grad_clip: 0
25 | grad_skip_thr: 0
26 | max_iter: 30000
27 | optimizer_log_freq: 100
28 | scheduler:
29 | step_size: 10000
30 | gamma: 0.5
31 |
32 |
33 |
--------------------------------------------------------------------------------
/src/utils/tester.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import wandb
3 | from tqdm import tqdm
4 |
5 |
def test(args, loader, model):
    """Evaluate ``model`` on ``loader`` and log per-sample averages to wandb.

    Args:
        args: config object; only ``args.device`` (a string) is read here.
        loader: iterable of batches; each batch is an indexable, mutable
            sequence of tensors whose first element holds the inputs.
        model: module exposing ``eval()`` and ``test_step(batch=...)``, where
            ``test_step`` returns a dict of metrics *summed* over the batch.

    Every metric is accumulated over all batches, divided by the number of
    samples actually seen, and logged under the key ``test/<metric>``.
    """
    model.eval()
    history = {}
    n_samples = 0
    with torch.no_grad():
        for batch in tqdm(loader):
            if "cuda" in args.device:
                # Move to the configured device (e.g. "cuda:1"); a bare
                # ``.cuda()`` would always target the current default GPU.
                for i in range(len(batch)):
                    batch[i] = batch[i].to(args.device, non_blocking=True)

            n_samples += batch[0].shape[0]
            logs = model.test_step(batch=batch)

            for name, value in logs.items():
                key = f"test/{name}"
                history[key] = history.get(key, 0.0) + value

    # Normalise by the number of evaluated samples so that the averages stay
    # correct even if the loader drops or subsamples part of the dataset.
    for key in history:
        history[key] /= n_samples

    wandb.log(history)
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Generativ/e AI group at the TU/e
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/model/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ClassifierWrapper(nn.Module):
6 | def __init__(self, backbone, loss_fn=nn.CrossEntropyLoss(), **kwargs):
7 | super().__init__()
8 | self.backbone = backbone
9 | self.loss_fn = loss_fn
10 |
11 | def forward(self, batch):
12 | output = self.backbone(pixel_values=batch[0], return_dict=False)
13 | return output[0]
14 |
15 | def train_step(self, batch, scaler=None, device=None):
16 | if scaler is not None:
17 | with torch.autocast(device_type=device, dtype=torch.float16):
18 | logits = self.forward(batch)
19 | loss = self.loss_fn(logits, batch[1])
20 | else:
21 | logits = self.forward(batch)
22 | loss = self.loss_fn(logits, batch[1])
23 |
24 | logs = {
25 | "loss": loss.data,
26 | "accuracy": (logits.argmax(dim=1) == batch[1]).float().mean(),
27 | }
28 | return loss, logs
29 |
30 | def test_step(self, batch):
31 | logits = self.forward(batch)
32 | loss = self.loss_fn(logits, batch[1])
33 |
34 | logs = {
35 | "loss": loss.data * batch[0].shape[0],
36 | "accuracy": (logits.argmax(dim=1) == batch[1]).float().sum(),
37 | }
38 | return logs
39 |
--------------------------------------------------------------------------------
/configs/model/resnet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | _target_: model.classifier.ClassifierWrapper
3 | backbone:
4 | _target_: transformers.ResNetForImageClassification
5 | config:
6 | _target_: transformers.ResNetConfig
7 | num_channels: 3
8 | embedding_size: 64 # Dimensionality (hidden size) for the embedding layer.
9 | hidden_sizes: # (List[int], optional, defaults to [256, 512, 1024, 2048]) — Dimensionality (hidden size) at each stage.
10 | - 64
11 | - 128
12 | - 256
13 | - 512
14 | depths: # (List[int], optional, defaults to [3, 4, 6, 3]) — Depth (number of layers) for each stage.
15 | - 2
16 | - 2
17 | - 2
18 | - 2
19 | layer_type: "basic" # (str, optional, defaults to "bottleneck") — The layer to use, it can be either "basic" (used for smaller models, like resnet-18 or resnet-34) or "bottleneck" (used for larger models like resnet-50 and above).
20 | hidden_act: "relu" # (str, optional, defaults to "relu") — The non-linear activation function in each block. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
21 | downsample_in_first_stage: False # (bool, optional, defaults to False) — If True, the first stage will downsample the inputs using a stride of 2.
22 | downsample_in_bottleneck: False # (bool, optional, defaults to False) — If True, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a stride of 2.
23 | num_labels: ${dataset.num_classes}
24 |
25 | name: resnet
26 | # see https://huggingface.co/docs/transformers/main/en/model_doc/resnet#transformers.ResNetConfig for more details
--------------------------------------------------------------------------------
/src/model/convmixer.py:
--------------------------------------------------------------------------------
1 | """ConvMixer image classifier (patch embedding followed by depthwise/pointwise convolution blocks).
2 | (c) YANG, Wei
3 | """
4 |
5 | import torch.nn as nn
6 |
7 | from model.classifier import ClassifierWrapper
8 |
9 | # adapted from https://github.com/locuslab/convmixer-cifar10/blob/main/train.py
10 |
11 |
class Residual(nn.Module):
    """Skip connection: returns ``fn(x) + x`` for the wrapped callable ``fn``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        transformed = self.fn(x)
        return transformed + x
19 |
20 |
class ConvMixer(ClassifierWrapper):
    """ConvMixer classifier built as a single ``nn.Sequential`` backbone.

    Args:
        dim: number of channels used throughout the network.
        depth: number of mixer blocks.
        kernel_size: kernel size of the depthwise convolution in each block.
        patch_size: kernel size and stride of the patch-embedding convolution.
        num_classes: size of the output layer.
        loss_fn: loss forwarded to ``ClassifierWrapper``.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        dim,
        depth,
        kernel_size=5,
        patch_size=2,
        num_classes=100,
        loss_fn=nn.CrossEntropyLoss(),
        **kwargs
    ):
        # Patch embedding: non-overlapping patches projected to `dim` channels.
        layers = [
            nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ]

        # Mixer blocks: residual depthwise conv followed by a 1x1 conv.
        for _ in range(depth):
            depthwise = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )
            mixer_block = nn.Sequential(
                Residual(depthwise),
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )
            layers.append(mixer_block)

        # Classification head: global average pooling and a linear layer.
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(dim, num_classes))

        super().__init__(backbone=nn.Sequential(*layers), loss_fn=loss_fn)

    def forward(self, batch):
        """Return the logits for the inputs stored in ``batch[0]``."""
        return self.backbone(batch[0])
62 |
--------------------------------------------------------------------------------
/src/dataset/cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from torch.utils.data import random_split
4 | from torchvision import datasets, transforms
5 |
6 | from dataset.data_module import DataModule, ToTensor
7 |
8 |
class Cifar10(DataModule):
    """CIFAR-10 data module with a 256-sample validation split.

    Args:
        batch_size: batch size of the training loader.
        test_batch_size: batch size of the validation and test loaders.
        root: directory the dataset is downloaded to and read from.
        use_augmentations: if true, use random crop + horizontal flip and
            channel normalisation; otherwise only a horizontal flip and the
            project's own ``ToTensor`` conversion are applied.
    """

    def __init__(
        self,
        batch_size,
        test_batch_size,
        root,
        use_augmentations,
    ):
        # The base constructor stores the arguments and already calls
        # ``self.prepare_data()``, so the download is not triggered again here.
        super(Cifar10, self).__init__(
            batch_size=batch_size,
            test_batch_size=test_batch_size,
            root=root,
        )
        # Store the flag explicitly instead of dumping ``locals()`` into
        # ``__dict__`` (which also stored a reference to ``self``).
        self.use_augmentations = use_augmentations

        if use_augmentations:
            mean = (0.4914, 0.4822, 0.4465)
            std = (0.2023, 0.1994, 0.2010)
            self.transforms = transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ]
            )
            self.test_transforms = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ]
            )
        else:
            self.transforms = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    ToTensor(),
                ]
            )
            self.test_transforms = transforms.Compose([ToTensor()])

    def prepare_data(self):
        """Download the train and test archives into ``self.root`` if missing."""
        datasets.CIFAR10(self.root, train=True, download=True)
        datasets.CIFAR10(self.root, train=False, download=True)

    def setup(self):
        """Create ``self.train``, ``self.val`` and ``self.test``.

        The validation set is a random 256-sample subset of the training data
        (drawn with the global torch RNG) and shares the training transforms.
        """
        cifar_full = datasets.CIFAR10(self.root, train=True, transform=self.transforms)
        cifar_full.processed_folder = os.path.join(self.root, cifar_full.base_folder)
        n_total = len(cifar_full)
        self.train, self.val = random_split(cifar_full, [n_total - 256, 256])
        self.test = datasets.CIFAR10(
            self.root, train=False, transform=self.test_transforms
        )
        self.test.processed_folder = os.path.join(self.root, self.test.base_folder)
67 |
--------------------------------------------------------------------------------
/src/dataset/cifar100.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import random_split
6 | from torchvision import datasets, transforms
7 |
8 | from dataset.data_module import DataModule, ToTensor
9 |
10 |
class Cifar100(DataModule):
    """CIFAR-100 data module with a 256-sample validation split.

    Args:
        batch_size: batch size of the training loader.
        test_batch_size: batch size of the validation and test loaders.
        root: directory the dataset is downloaded to and read from.
        use_augmentations: if true, use random crop + horizontal flip and
            channel normalisation; otherwise only a horizontal flip and the
            project's own ``ToTensor`` conversion are applied.
    """

    def __init__(
        self,
        batch_size,
        test_batch_size,
        root,
        use_augmentations,
    ):
        # The base constructor stores the arguments and already calls
        # ``self.prepare_data()``, so the download is not triggered again here.
        super().__init__(
            batch_size=batch_size,
            test_batch_size=test_batch_size,
            root=root,
        )
        # Store the flag explicitly instead of dumping ``locals()`` into
        # ``__dict__`` (which also stored ``self`` and ``__class__``).
        self.use_augmentations = use_augmentations

        if use_augmentations:
            mean = (0.5071, 0.4867, 0.4408)
            std = (0.2675, 0.2565, 0.2761)
            self.transforms = transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ]
            )
            self.test_transforms = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ]
            )
        else:
            self.transforms = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    ToTensor(),
                ]
            )
            self.test_transforms = transforms.Compose([ToTensor()])

    def prepare_data(self):
        """Download the train and test archives into ``self.root`` if missing."""
        datasets.CIFAR100(self.root, train=True, download=True)
        datasets.CIFAR100(self.root, train=False, download=True)

    def setup(self):
        """Create ``self.train``, ``self.val`` and ``self.test``.

        The validation set is a random 256-sample subset of the training data
        (drawn with the global torch RNG) and shares the training transforms.
        """
        cifar_full = datasets.CIFAR100(self.root, train=True, transform=self.transforms)
        cifar_full.processed_folder = os.path.join(self.root, cifar_full.base_folder)
        n_total = len(cifar_full)
        self.train, self.val = random_split(cifar_full, [n_total - 256, 256])
        self.test = datasets.CIFAR100(
            self.root, train=False, transform=self.test_transforms
        )
        self.test.processed_folder = os.path.join(self.root, self.test.base_folder)
70 |
--------------------------------------------------------------------------------
/src/dataset/data_module.py:
--------------------------------------------------------------------------------
1 | from itertools import permutations
2 |
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 | from torch.utils.data import DataLoader
7 | from torchvision import transforms
8 |
9 |
class ToTensor:
    """Convert an HxWxC image into a CxHxW float tensor.

    Values are cast to float32 as they are; no rescaling is applied.
    """

    def __call__(self, x):
        hwc = np.asarray(x, dtype=np.float32)
        chw = torch.FloatTensor(hwc).permute(2, 0, 1)
        return chw
14 |
15 |
class Random90Rotation:
    """With probability 0.5, rotate the input by a random multiple of 90 degrees.

    The input must expose ``rotate(angle)`` (e.g. a PIL image). The multiple is
    drawn as ``ceil(3 * U)`` with ``U`` uniform on ``[0, 1)``.
    """

    def __call__(self, x):
        # Convert the draw to a plain Python int so that ``rotate`` receives an
        # ordinary number rather than a 1-element tensor.
        k = int(torch.ceil(3.0 * torch.rand(1)).item())
        u = torch.rand(1)
        if u < 0.5:
            x = x.rotate(90 * k)
        return x
23 |
24 |
class ChannelSwap:
    """With probability 0.5, permute the three colour channels of an image.

    The input is converted with ``np.array`` and the result is wrapped back
    into an image via ``Image.fromarray``; otherwise the input is returned
    unchanged.
    """

    # All 6 orderings of the three channels (the first one is the identity).
    _PERMUTATIONS = tuple(permutations(range(3), 3))

    def __call__(self, x):
        # ``randint``'s upper bound is exclusive, so use the full length;
        # a bound of 5 would never select the last permutation (2, 1, 0).
        choice = np.random.randint(0, len(self._PERMUTATIONS))
        permutation = list(self._PERMUTATIONS[choice])
        u = torch.rand(1)
        if u < 0.5:
            x = np.array(x)[..., permutation]
            x = Image.fromarray(x)
        return x
33 |
34 |
class DataModule:
    """Base class bundling datasets, transforms and data loaders.

    Subclasses implement ``prepare_data`` (download / preprocessing) and
    ``setup`` (create ``self.train``, ``self.val`` and ``self.test``).

    Args:
        batch_size: batch size of the training loader.
        test_batch_size: batch size of the validation and test loaders.
        root: directory holding the data.
    """

    # Number of worker processes used by every loader; override on a subclass
    # or an instance to change it.
    num_workers = 1

    def __init__(
        self,
        batch_size,
        test_batch_size,
        root="data/",
    ):
        # Explicit attributes instead of ``self.__dict__.update(locals())``,
        # which also stored a reference to ``self`` on the instance.
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.root = root
        self.transforms = transforms.Compose(
            [
                ToTensor(),
            ]
        )
        self.test_transforms = transforms.Compose(
            [
                ToTensor(),
            ]
        )
        self.prepare_data()

    def prepare_data(self) -> None:
        """
        Download the data. Do preprocessing if necessary.
        :return:
        """
        raise NotImplementedError

    def setup(self) -> None:
        """
        Create self.train and self.val, self.test dataset
        :return: None
        """
        raise NotImplementedError

    @staticmethod
    def _cycle(loader):
        """Yield batches from ``loader`` forever, restarting it when exhausted.

        :raises ValueError: if a full pass over the loader produces no batch
            (otherwise the endless loop would spin without ever yielding).
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                raise ValueError(
                    "data loader yields no batches; the batch size is larger "
                    "than the dataset while drop_last is enabled, or the "
                    "dataset is empty"
                )

    def train_dataloader(self):
        """Endless generator over shuffled, full-size training batches."""
        train_loader = DataLoader(
            self.train,
            self.batch_size,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
            num_workers=self.num_workers,
        )
        yield from self._cycle(train_loader)

    def val_dataloader(self):
        """Endless generator over shuffled validation batches.

        The last incomplete batch is kept: with ``drop_last`` a validation set
        smaller than ``test_batch_size`` would produce no batches at all.
        """
        val_loader = DataLoader(
            self.val,
            self.test_batch_size,
            pin_memory=True,
            drop_last=False,
            shuffle=True,
            num_workers=self.num_workers,
        )
        yield from self._cycle(val_loader)

    def test_dataloader(self):
        """Return a single-pass, unshuffled loader over the test set."""
        test_loader = DataLoader(
            self.test,
            self.test_batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
            drop_last=False,
        )
        return test_loader
101 |
--------------------------------------------------------------------------------
/src/model/vgg.py:
--------------------------------------------------------------------------------
1 | """VGG for CIFAR10. FC layers are removed.
2 | (c) YANG, Wei
3 | """
4 | import math
5 |
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from model.classifier import ClassifierWrapper
10 |
11 | # adapted from https://github.com/alecwangcq/KFAC-Pytorch/blob/master/models/cifar/vgg.py
12 |
# Only the VGG class is defined in this module; the per-depth factory names
# (vgg11, vgg11_bn, ...) do not exist here, and listing them would make
# `from model.vgg import *` fail with an AttributeError.
__all__ = ["VGG"]
24 |
25 |
# Layer layouts keyed by configuration id. Integers are the output channels of
# a 3x3 convolution, "M" marks a 2x2 max-pooling layer; one stage per line.
cfg = {
    # vgg11
    "A": [
        64, "M",
        128, "M",
        256, 256, "M",
        512, 512, "M",
        512, 512, "M",
    ],
    # vgg13
    "B": [
        64, 64, "M",
        128, 128, "M",
        256, 256, "M",
        512, 512, "M",
        512, 512, "M",
    ],
    # vgg16
    "D": [
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, "M",
        512, 512, 512, "M",
        512, 512, 512, "M",
    ],
    # vgg19
    "E": [
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, 256, "M",
        512, 512, 512, 512, "M",
        512, 512, 512, 512, "M",
    ],
}
77 |
78 |
class VGG(ClassifierWrapper):
    """VGG feature extractor followed by global average pooling and a linear head.

    Args:
        cfg_id: key into the module-level ``cfg`` table selecting the layout.
        batch_norm: insert a ``BatchNorm2d`` after every convolution.
        num_classes: size of the output layer.
        loss_fn: loss forwarded to ``ClassifierWrapper``.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        cfg_id,
        batch_norm=False,
        num_classes=1000,
        loss_fn=nn.CrossEntropyLoss(),
        **kwargs
    ):
        super().__init__(
            backbone=VGG.make_layers(cfg[cfg_id], batch_norm=batch_norm),
            loss_fn=loss_fn,
        )
        # Every layout in ``cfg`` ends with 512 channels.
        self.classifier = nn.Linear(512, num_classes)
        self._initialize_weights()

    def forward(self, batch):
        """Return the logits for the inputs stored in ``batch[0]``."""
        x = self.backbone(batch[0])
        # Global average pooling over the whole feature map. Pooling over both
        # spatial sizes also covers non-square maps, which a kernel built from
        # the width alone would not reduce to 1x1.
        x = F.avg_pool2d(x, kernel_size=tuple(x.shape[-2:]))
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (kernel area * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    @staticmethod
    def make_layers(cfg, batch_norm=False):
        """Build the convolutional stack described by ``cfg``.

        Args:
            cfg: sequence where an integer adds a 3x3 convolution (padding 1)
                with that many output channels followed by ReLU, and ``"M"``
                adds a 2x2 max-pooling layer.
            batch_norm: insert ``BatchNorm2d`` between convolution and ReLU.

        Returns:
            ``nn.Sequential`` of the resulting layers; the input has 3 channels.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
            if batch_norm:
                layers.append(nn.BatchNorm2d(v))
            layers.append(nn.ReLU(inplace=True))
            in_channels = v
        return nn.Sequential(*layers)
132 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Variational Stochastic Gradient Descent for Deep Neural Networks
2 |
3 | This repository contains the source code accompanying the paper:
4 |
5 | [Variational Stochastic Gradient Descent for Deep Neural Networks](https://openreview.net/forum?id=xu4ATNjcdy)
6 |
[[Demos]](https://github.com/generativeai-tue/vsgd/blob/main/notebooks)
7 |
8 |
**[Anna Kuzina\*](https://akuzina.github.io/), [Haotian Chen\*](https://www.linkedin.com/in/haotian-chen-359b4520b/), [Babak Esmaeili](https://babak0032.github.io), & [Jakub M. Tomczak](https://jmtomczak.github.io/)**.
9 |
10 |
11 | #### Abstract
12 | *Optimizing deep neural networks is one of the main tasks in successful deep learning. Current state-of-the-art optimizers are adaptive gradient-based optimization methods such as Adam. Recently, there has been an increasing interest in formulating gradient-based optimizers in a probabilistic framework for better estimation of gradients and modeling uncertainties. Here, we propose to combine both approaches, resulting in the Variational Stochastic Gradient Descent (VSGD) optimizer. We model gradient updates as a probabilistic model and utilize stochastic variational inference (SVI) to derive an efficient and effective update rule. Further, we show how our VSGD method relates to other adaptive gradient-based optimizers like Adam.
13 | Lastly,
14 | we carry out experiments on two image classification datasets and four deep neural network architectures, where we show that VSGD outperforms Adam and SGD.*
15 |
16 |
17 |
18 |
19 |
20 | ### Repository structure
21 |
22 | #### Folders
23 |
24 | This repository is organized as follows:
25 |
26 | * `src` contains the main PyTorch library
27 | * `configs` contains the default configuration for `src/run_experiment.py`
28 | * `notebooks` contains a demo of using the VSGD optimizer
29 |
30 |
31 | ----
32 | ### Reproduce
33 |
34 | ###### Install conda *(recommended)*
35 |
36 | ```bash
37 | conda env create -f environment.yml
38 | conda activate vsgd
39 | ```
40 |
41 | ###### Login wandb *(recommended)*
42 | ```bash
43 | wandb login
44 | ```
45 |
46 | ###### Download TinyImagenet dataset
47 |
48 | ```bash
49 | mkdir -p data && cd data/
50 | wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
51 | unzip tiny-imagenet-200.zip
52 | ```
53 |
54 | ###### Starting an experiment
55 | All experiments are run with `src/run_experiment.py`. Experiment configuration is handled by [Hydra](https://hydra.cc); the default configuration can be found in the `configs/` folder.
56 |
57 | `configs/experiment/` contains configs for dataset-architecture pairs. For example, to train the VGG model on the CIFAR-100 dataset with the VSGD optimizer, run:
58 | ```bash
59 | PYTHONPATH=src/ python src/run_experiment.py experiment=cifar100_vgg train/optimizer=vsgd
60 | ```
61 |
62 | One can also change any default hyperparameters using the command line:
63 | ```bash
64 | PYTHONPATH=src/ python src/run_experiment.py experiment=cifar100_vgg train/optimizer=vsgd train.optimizer.weight_decay=0.01
65 | ```
66 |
67 |
68 | ----
69 |
70 | ### Cite
71 | If you found this work useful in your research, please consider citing:
72 |
73 | ```
74 | @article{
75 | chen2024variational,
76 | title={Variational Stochastic Gradient Descent for Deep Neural Networks},
77 | author={Chen, Haotian and Kuzina, Anna and Esmaeili, Babak and Tomczak, Jakub},
78 | year={2024},
79 | }
80 | ```
81 |
82 | ### Acknowledgements
83 | *Anna Kuzina is funded by the Hybrid Intelligence Center, a 10-year programme funded by the Dutch Ministry of Education, Culture and Science through the Netherlands Organisation for Scientific Research, https://hybrid-intelligence-centre.nl.
84 | This work was carried out on the Dutch national e-infrastructure with the support of SURF Cooperative.*
85 |
86 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | scripts/
3 | notebooks/
4 | pics/
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/#use-with-ide
114 | .pdm.toml
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
--------------------------------------------------------------------------------
/src/run_experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pprint import pprint
3 |
4 | import hydra.utils
5 | import numpy as np
6 | import omegaconf
7 | import torch
8 | import wandb
9 | from hydra.utils import instantiate
10 |
11 | import utils.tester as tester
12 | import utils.trainer as trainer
13 | from utils.wandb import get_checkpoint
14 |
15 |
def params_to(param, device):
    """Move a tensor, and its gradient if one is attached, onto ``device``.

    The move rebinds ``.data`` so ``param`` stays the same Python object and
    every existing reference to it keeps working.

    Args:
        param: tensor whose data should be relocated.
        device: target device, passed straight to ``Tensor.to``.
    """
    param.data = param.data.to(device)
    grad = param._grad
    if grad is None:
        return
    grad.data = grad.data.to(device)
20 |
21 |
def optimizer_to(optim, device):
    """Relocate every tensor stored in an optimizer's state to ``device``.

    Each state entry is either a tensor or a dict of per-parameter values.
    Only tensor values are moved (via ``params_to``); anything else is left
    untouched.

    Args:
        optim: optimizer whose ``state`` mapping is walked.
        device: target device handed to ``params_to``.
    """
    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            params_to(entry, device)
            continue
        if not isinstance(entry, dict):
            continue
        for value in entry.values():
            if isinstance(value, torch.Tensor):
                params_to(value, device)
30 |
31 |
def load_from_checkpoint(args, model, optimizer=None, scheduler=None):
    """Restore model / optimizer / scheduler state from a stored checkpoint.

    The checkpoint is fetched with ``get_checkpoint`` for the run identified by
    ``args.train.resume_id`` and is loaded onto the CPU first.

    Args:
        args: experiment config. ``args.train.start_iter`` is overwritten with
            the iteration stored in the checkpoint.
        model: module that receives ``model_state_dict``.
        optimizer: optional optimizer. When given, its state is loaded and then
            moved to ``args.train.device``.
        scheduler: optional LR scheduler. When given, its state is loaded if
            the checkpoint contains one.

    Returns:
        The tuple ``(args, model, optimizer, scheduler)``.
    """
    chpt = get_checkpoint(
        args.wandb.setup.entity,
        args.wandb.setup.project,
        args.train.resume_id,
        device="cpu",
    )
    args.train.start_iter = chpt["iteration"]

    # Load model weights
    model.load_state_dict(chpt["model_state_dict"])

    # Load optimizer
    if optimizer is not None:
        optimizer.load_state_dict(chpt["optimizer_state_dict"])
        optimizer_to(optimizer, args.train.device)

    # Load scheduler. Checkpoints written without a scheduler store ``None``
    # under this key; passing that to ``load_state_dict`` would fail, so the
    # freshly constructed scheduler is kept as is in that case.
    if scheduler is not None:
        scheduler_state_dict = chpt["scheduler_state_dict"]
        if scheduler_state_dict is None:
            print("Checkpoint has no scheduler state, scheduler is left at its initial state")
        else:
            scheduler.load_state_dict(scheduler_state_dict)

    return args, model, optimizer, scheduler
55 |
56 |
def init_wandb(args):
    """Open the wandb run for this experiment and declare its metrics.

    When ``args.train.resume_id`` is set, the run with that id is resumed
    (``resume="must"``). Otherwise a new run is created with the fully
    resolved config, a group name and a list of tags.

    Args:
        args: experiment config; the ``wandb``, ``train``, ``dataset`` and
            ``model`` sections are read.
    """
    wandb.require("service")

    tags = [
        args.dataset.name,
        args.model.name,
        args.train.optimizer._target_,
        args.train.experiment_name,
    ]
    resume_id = args.train.resume_id
    if resume_id is None:
        wandb_cfg = omegaconf.OmegaConf.to_container(
            args, resolve=True, throw_on_missing=True
        )
        # Fall back to "<model>_<dataset>" when no group is configured.
        group = args.wandb.group
        if group is None:
            group = f"{args.model.name}_{args.dataset.name}"
        wandb.init(
            **args.wandb.setup,
            config=wandb_cfg,
            group=group,
            tags=tags,
            dir=hydra.utils.get_original_cwd(),
            settings=wandb.Settings(start_method="thread"),
        )
    else:
        wandb.init(
            **args.wandb.setup,
            id=resume_id,
            resume="must",
            settings=wandb.Settings(start_method="thread"),
        )
    pprint(wandb.run.config)

    # "iter" is the custom x axis for all logged metric families
    wandb.define_metric("iter")
    for prefix in ("train", "val", "pic"):
        wandb.define_metric(f"{prefix}/*", step_metric="iter")
    for loss_key in ("val/loss", "test/loss"):
        wandb.define_metric(loss_key, summary="min", step_metric="iter")
94 |
95 |
def compute_params(model, args):
    """Print the number of trainable parameters and store it in the wandb summary.

    Args:
        model: module whose parameters with ``requires_grad`` set are counted.
        args: unused.
    """
    trainable_sizes = [
        weight.numel() for weight in model.parameters() if weight.requires_grad
    ]
    num_param = sum(trainable_sizes)
    print(num_param)
    wandb.run.summary["num_parameters"] = num_param
101 |
102 |
@hydra.main(version_base="1.3", config_path="../configs", config_name="defaults.yaml")
def run(args: omegaconf.DictConfig) -> None:
    """Train a model as described by the Hydra config, then test the saved checkpoint.

    Steps: select the GPU, seed the RNGs, build the data loaders, instantiate
    model / optimizer / optional scheduler, optionally resume from a previous
    run, train, and finally reload the checkpoint of the current wandb run for
    the test evaluation.

    Args:
        args: the composed Hydra configuration.
    """
    # A device string ending in "0" or "1" pins the process to that GPU index.
    gpu_index = args.train.device[-1]
    if gpu_index in ("0", "1"):
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_index
        args.train.device = "cuda"

    # Set the seed
    seed = args.train.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # ------------
    # data
    # ------------
    data_root = os.path.join(hydra.utils.get_original_cwd(), "data/")
    data_module = instantiate(args.dataset.data_module, root=data_root)
    data_module.setup()
    train_loader = data_module.train_dataloader()
    val_loader = data_module.val_dataloader()
    test_loader = data_module.test_dataloader()

    # ------------
    # model & optimizer
    # ------------
    model = instantiate(args.model)
    optimizer = instantiate(args.train.optimizer, params=model.parameters())
    scheduler = (
        instantiate(args.train.scheduler, optimizer=optimizer)
        if hasattr(args.train, "scheduler")
        else None
    )

    resume_id = args.train.resume_id
    if resume_id is not None:
        print(f"Resume training {resume_id}")
        restored = load_from_checkpoint(args, model, optimizer, scheduler)
        args, model, optimizer, scheduler = restored

    model.train()
    model.to(args.train.device)

    # ------------
    # logging
    # ------------
    init_wandb(args)
    wandb.watch(model, **args.wandb.watch)
    compute_params(model, args)

    # ------------
    # training
    # ------------
    if args.train.start_iter < args.train.max_iter:
        trainer.train(
            args.train,
            train_loader,
            val_loader,
            test_loader,
            model,
            optimizer,
            scheduler,
        )

    # ------------
    # testing
    # ------------
    # Evaluate a fresh model loaded from the checkpoint of the current run.
    model = instantiate(args.model)
    with omegaconf.open_dict(args):
        args.train.resume_id = wandb.run.id
    _, model, _, _ = load_from_checkpoint(args, model)
    model.to(args.train.device)

    tester.test(
        args.train,
        test_loader,
        model,
    )
    print("Test finished")
    wandb.finish()
185 |
186 |
# Script entry point: hand control to the Hydra-decorated ``run``.
if __name__ == "__main__":
    run()
189 |
--------------------------------------------------------------------------------
/src/dataset/tiny_imagenet.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 | from dataset.data_module import DataModule, ToTensor
9 |
10 |
class TinyImagenet(DataModule):
    """Data module that builds the Tiny ImageNet splits and their transforms.

    Args:
        batch_size: forwarded to ``DataModule``.
        test_batch_size: forwarded to ``DataModule``.
        root: forwarded to ``DataModule`` and later used as the directory
            handed to ``TinyImageNet``.
        use_augmentations: when true, training images are randomly cropped,
            flipped, affinely transformed and colour-jittered, and both
            pipelines normalise with a fixed per-channel mean / std.
            Otherwise only a random horizontal flip and the project's
            ``ToTensor`` are used.
    """

    def __init__(
        self,
        batch_size,
        test_batch_size,
        root,
        use_augmentations,
    ):
        super().__init__(
            batch_size=batch_size,
            test_batch_size=test_batch_size,
            root=root,
        )
        # Store the constructor arguments on the instance. The arguments are
        # listed explicitly so that ``self`` (a reference cycle) and
        # ``__class__`` do not end up in the instance dict, which is what
        # ``update(locals())`` would do.
        self.__dict__.update(
            batch_size=batch_size,
            test_batch_size=test_batch_size,
            root=root,
            use_augmentations=use_augmentations,
        )
        if use_augmentations:
            mean = (0.485, 0.456, 0.406)
            std = (0.229, 0.224, 0.225)
            self.transforms = transforms.Compose(
                [
                    transforms.RandomCrop(size=64, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomAffine(
                        degrees=45, translate=(0.1, 0.1), scale=(0.9, 1.1)
                    ),
                    transforms.ColorJitter(
                        brightness=0.2, contrast=0.2, saturation=0.2
                    ),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ]
            )
            self.test_transforms = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ]
            )
        else:
            # ``ToTensor`` here is the project's own transform imported from
            # ``dataset.data_module``, not the torchvision one.
            self.transforms = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    ToTensor(),
                ]
            )
            self.test_transforms = transforms.Compose([ToTensor()])

    def prepare_data(self) -> None:
        """No-op: nothing is downloaded or preprocessed here."""
        pass

    def setup(self):
        """Create the train / val / test datasets.

        Both ``self.val`` and ``self.test`` are built from the ``val`` split
        with the test-time transforms.
        """
        self.train = TinyImageNet(self.root, split="train", transform=self.transforms)
        self.val = TinyImageNet(self.root, split="val", transform=self.test_transforms)
        self.test = TinyImageNet(self.root, split="val", transform=self.test_transforms)
63 |
64 |
# File extension (without the dot) used to find the dataset's image files.
EXTENSION = "JPEG"
# Number of file names generated per class when building the train label map.
NUM_IMAGES_PER_CLASS = 500
# File in the dataset root that lists one class id per line.
CLASS_LIST_FILE = "wnids.txt"
# Tab-separated file in the `val` split directory: image file name, class id, ...
VAL_ANNOTATION_FILE = "val_annotations.txt"
69 |
70 |
class TinyImageNet(Dataset):
    """Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`.
    Dataset code adapted from https://github.com/leemengtw/tiny-imagenet/blob/master/TinyImageNet.py

    Parameters
    ----------
    root: string
        Directory that contains the extracted `tiny-imagenet-200` folder,
        which in turn holds the `train`, `test` and `val` subdirectories.
    split: string
        Indicating which split to return as a data set.
        Valid option: [`train`, `test`, `val`]
    transform: torchvision.transforms
        A (series) of valid transformation(s) applied to every image.
    target_transform: callable, optional
        Applied to the integer label of the labelled splits.
    in_memory: bool
        Set to True if there is enough memory (about 5G) and want to minimize disk IO overhead.
        Images are read through `read_image` once at construction time, so
        `transform` is applied once per image in that case.
    """

    def __init__(
        self,
        root,
        split="train",
        transform=None,
        target_transform=None,
        in_memory=False,
    ):
        self.root = os.path.join(os.path.expanduser(root), "tiny-imagenet-200")
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.in_memory = in_memory
        self.split_dir = os.path.join(self.root, self.split)
        self.image_paths = sorted(
            glob.iglob(
                os.path.join(self.split_dir, "**", "*.%s" % EXTENSION), recursive=True
            )
        )
        self.labels = {}  # fname - label number mapping
        self.images = []  # used for in-memory processing

        # build class label - number mapping
        with open(os.path.join(self.root, CLASS_LIST_FILE), "r") as fp:
            self.label_texts = sorted(text.strip() for text in fp)
        self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)}

        if self.split == "train":
            # train files are named "<class id>_<counter>.<extension>"
            for label_text, i in self.label_text_to_number.items():
                for cnt in range(NUM_IMAGES_PER_CLASS):
                    self.labels["%s_%d.%s" % (label_text, cnt, EXTENSION)] = i
        elif self.split == "val":
            with open(os.path.join(self.split_dir, VAL_ANNOTATION_FILE), "r") as fp:
                for line in fp:
                    if not line.strip():
                        # tolerate blank lines (e.g. a trailing newline)
                        continue
                    terms = line.split("\t")
                    file_name, label_text = terms[0], terms[1]
                    self.labels[file_name] = self.label_text_to_number[label_text]

        # read all images into memory to minimize disk IO overhead
        if self.in_memory:
            self.images = [self.read_image(path) for path in self.image_paths]

    def __len__(self):
        """Number of image files found for the split."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return the image for the `test` split, ``(image, label)`` otherwise."""
        file_path = self.image_paths[index]

        if self.in_memory:
            img = self.images[index]
        else:
            img = self.read_image(file_path)
        if self.split == "test":
            return img

        # labels are keyed by bare file name
        label = self.labels[os.path.basename(file_path)]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __repr__(self):
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += " Number of datapoints: {}\n".format(self.__len__())
        tmp = self.split
        fmt_str += " Split: {}\n".format(tmp)
        fmt_str += " Root Location: {}\n".format(self.root)
        tmp = " Transforms (if any): "
        fmt_str += "{0}{1}\n".format(
            tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        tmp = " Target Transforms (if any): "
        fmt_str += "{0}{1}".format(
            tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        return fmt_str

    def read_image(self, path):
        """Load ``path`` as an RGB image and apply ``transform`` if one is set."""
        # The context manager closes the underlying file right after the
        # pixel data has been converted into a new RGB image.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        return self.transform(img) if self.transform else img
167 |
--------------------------------------------------------------------------------
/src/vsgd.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.optim.optimizer import Optimizer, required
6 |
7 |
class VSGD(Optimizer):
    """Variational Stochastic Gradient Descent (VSGD) optimizer.

    State kept per parameter between steps: ``mug``, ``bg`` and ``bhg``
    (tensors shaped like the parameter) plus the scalar tensors ``pa2``,
    ``pbg2``, ``pbhg2`` and ``step``. The update rule itself is implemented by
    the module-level ``vsgd`` function.
    """

    def __init__(
        self,
        params,
        ghattg: float = 30.0,
        ps: float = 1e-8,
        tau1: float = 0.81,
        tau2: float = 0.9,
        lr: float = 0.1,
        weight_decay: float = 0.0,
        eps: float = 1e-8,
    ):
        """
        Args:
            params: iterable of parameters or parameter-group dicts to optimize.
            ghattg: prior variance ratio between ghat and g,
                Var(ghat_t-g_t)/Var(g_t-g_{t-1}).
            ps: prior strength.
            tau1: remember rate for the gamma parameters of g
            tau2: remember rate for the gamma parameter of ghat
            lr: learning rate.
            weight_decay (float): weight decay coefficient (default: 0.0)
            eps: small constant added to the denominator of the parameter
                update (default: 1e-8)

        Raises:
            ValueError: if ``weight_decay`` is negative.
        """

        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            ghattg=ghattg,
            ps=ps,
            tau1=tau1,
            tau2=tau2,
            lr=lr,
            weight_decay=weight_decay,
            eps=eps,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super(VSGD, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        The update runs with autograd disabled; only the optional closure is
        evaluated with gradients enabled.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            mug_list = []
            step_list = []
            pa2_list = []
            pbg2_list = []
            pbhg2_list = []
            bg_list = []
            bhg_list = []

            self._init_group(
                group,
                params_with_grad,
                grads,
                mug_list,
                step_list,
                pa2_list,
                pbg2_list,
                pbhg2_list,
                bg_list,
                bhg_list,
                group["ghattg"],
                group["ps"],
            )

            vsgd(
                params_with_grad,
                grads,
                mug_list,
                step_list,
                pa2_list,
                pbg2_list,
                pbhg2_list,
                bg_list,
                bhg_list,
                group["tau1"],
                group["tau2"],
                group["lr"],
                group["weight_decay"],
                group["eps"],
            )

        return loss

    def _init_group(
        self,
        group,
        params_with_grad: List[Tensor],
        grads: List[Tensor],
        mug_list: List,
        step_list: List,
        pa2_list: List,
        pbg2_list: List,
        pbhg2_list: List,
        bg_list: List,
        bhg_list: List,
        ghattg: float,
        ps: float,
    ):
        """Collect parameters, gradients and state of one group into the given lists.

        Parameters without a gradient are skipped. State is created lazily the
        first time a parameter is seen with a gradient. All list arguments are
        appended to in place, in matching order.
        """
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)

            grads.append(p.grad)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                for k in ["mug", "bg", "bhg"]:
                    # running estimates start at zero
                    state[k] = torch.zeros_like(p, memory_format=torch.preserve_format)
                # initialize 2*a_0 and 2*b_0 as constants
                state["pa2"] = torch.tensor(2.0 * ps + 1.0 + 1e-4)
                state["pbg2"] = torch.tensor(2.0 * ps)
                state["pbhg2"] = torch.tensor(2.0 * ghattg * ps)
                state["step"] = torch.tensor(0.0)

            mug_list.append(state["mug"])
            bg_list.append(state["bg"])
            bhg_list.append(state["bhg"])
            step_list.append(state["step"])
            pa2_list.append(state["pa2"])
            pbg2_list.append(state["pbg2"])
            pbhg2_list.append(state["pbhg2"])

    def get_current_beta1_estimate(self) -> List[Tensor]:
        """Return ``bhg / (bg + bhg)`` for every parameter that has state.

        Parameters that have not been updated yet (no state, e.g. because they
        never received a gradient) are skipped.

        Returns:
            A list with one tensor per parameter with state, each shaped like
            that parameter.
        """
        betas = []
        for group in self.param_groups:
            for p in group["params"]:
                # ``get`` avoids inserting an empty entry into the state
                # defaultdict for parameters that were never stepped.
                state = self.state.get(p)
                if not state:
                    continue
                bg = state["bg"]
                bhg = state["bhg"]
                betas.append((bhg / (bg + bhg)).data)
        return betas
157 |
158 |
def vsgd(
    params_with_grad: List[Tensor],
    grads: List[Tensor],
    mug_list: List[Tensor],
    step_list: List[Tensor],
    pa2_list: List[Tensor],
    pbg2_list: List[Tensor],
    pbhg2_list: List[Tensor],
    bg_list: List[Tensor],
    bhg_list: List[Tensor],
    tau1: float,
    tau2: float,
    lr: float,
    weight_decay: float,
    eps: float,
):
    """Apply one VSGD update, in place, to every parameter in ``params_with_grad``.

    All list arguments are indexed in parallel: entry ``i`` of each list
    belongs to ``params_with_grad[i]``. ``mug``, ``bg``, ``bhg`` and ``step``
    are updated in place; ``pa2``, ``pbg2`` and ``pbhg2`` are only read.

    Args:
        params_with_grad: parameters to update.
        grads: observed gradients (``ghat``).
        mug_list: running gradient estimates.
        step_list: scalar step counters, incremented here.
        pa2_list, pbg2_list, pbhg2_list: scalar prior constants.
        bg_list, bhg_list: running ``2*b`` statistics for g and ghat.
        tau1, tau2: exponents of the step-dependent mixing rates for
            ``bg`` and ``bhg``.
        lr: learning rate.
        weight_decay: decoupled weight-decay coefficient.
        eps: constant added to the denominator of the update.
    """
    for idx in range(len(params_with_grad)):
        param = params_with_grad[idx]

        # weight decay following AdamW
        param.data.mul_(1 - lr * weight_decay)

        ghat = grads[idx]
        mug = mug_list[idx]
        bg = bg_list[idx]
        bhg = bhg_list[idx]
        pa2 = pa2_list[idx]
        pbg2 = pbg2_list[idx]
        pbhg2 = pbhg2_list[idx]

        # previous estimate, needed after ``mug`` is overwritten below
        mug_prev = torch.clone(mug)

        step = step_list[idx]
        step += 1

        # variances of g and ghat: priors on the first step, running
        # statistics afterwards
        if step == 1.0:
            numer_g, numer_hg, denom = pbg2, pbhg2, pa2 - 1.0
        else:
            numer_g, numer_hg, denom = bg, bhg, pa2
        sg = numer_g / denom
        shg = numer_hg / denom

        # update mug and Sigg
        mug.copy_((ghat * sg + mug_prev * shg) / (sg + shg))
        sigg = sg * shg / (sg + shg)

        # update 2*b
        mug_sq = sigg + mug**2
        bg2 = pbg2 + mug_sq - 2.0 * mug * mug_prev + mug_prev**2
        bhg2 = pbhg2 + mug_sq - 2.0 * ghat * mug + ghat**2

        rho1 = step ** (-tau1)
        rho2 = step ** (-tau2)
        bg.mul_(1.0 - rho1).add_(bg2, alpha=rho1)
        bhg.mul_(1.0 - rho2).add_(bhg2, alpha=rho2)

        # update param
        delta = lr / (torch.sqrt(mug_sq) + eps) * mug
        param.data.add_(delta, alpha=-1.0)
212 |
--------------------------------------------------------------------------------
/src/model/resnext.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | """
4 | Creates a ResNeXt Model as defined in:
5 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
6 | Aggregated residual transformations for deep neural networks.
7 | arXiv preprint arXiv:1611.05431.
8 | import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py
9 | """
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 |
14 | from model.classifier import ClassifierWrapper
15 |
# Public API of this module. The list must name objects that exist here,
# otherwise ``from model.resnext import *`` raises AttributeError.
__all__ = ["ResNeXt"]
17 |
18 |
class ResNeXtBottleneck(nn.Module):
    """
    ResNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
    """

    def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor):
        """Constructor
        Args:
            in_channels: input channel dimensionality
            out_channels: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            cardinality: num of convolution groups.
            widen_factor: factor to reduce the input dimensionality before convolution.
        """
        super().__init__()
        # width of the grouped 3x3 convolution
        D = cardinality * out_channels // widen_factor
        self.conv_reduce = nn.Conv2d(
            in_channels, D, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn_reduce = nn.BatchNorm2d(D)
        self.conv_conv = nn.Conv2d(
            D,
            D,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=cardinality,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(D)
        self.conv_expand = nn.Conv2d(
            D, out_channels, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn_expand = nn.BatchNorm2d(out_channels)

        # identity shortcut unless the channel count changes, in which case a
        # strided 1x1 conv + batch norm projects the input
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut.add_module(
                "shortcut_conv",
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    padding=0,
                    bias=False,
                ),
            )
            self.shortcut.add_module("shortcut_bn", nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Return ``relu(shortcut(x) + bottleneck(x))``.

        Submodules are invoked through ``__call__`` (not ``.forward``) so that
        hooks registered on them are executed.
        """
        bottleneck = self.conv_reduce(x)
        bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.relu(self.bn(bottleneck), inplace=True)
        bottleneck = self.conv_expand(bottleneck)
        bottleneck = self.bn_expand(bottleneck)
        residual = self.shortcut(x)
        return F.relu(residual + bottleneck, inplace=True)
78 |
79 |
class ResNeXt(ClassifierWrapper):
    """
    ResNext optimized for the Cifar dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """

    def __init__(
        self,
        cardinality,
        depth,
        num_classes,
        widen_factor=4,
        dropRate=0,
        loss_fn=nn.CrossEntropyLoss(),
        **kwargs
    ):
        """Constructor
        Args:
            cardinality: number of convolution groups.
            depth: number of layers.
            num_classes: number of classes
            widen_factor: factor to adjust the channel dimensionality
            dropRate: accepted but not used by this implementation.
            loss_fn: loss module handed to ``ClassifierWrapper``.
            **kwargs: accepted and ignored.
        """
        super().__init__(backbone=nn.Identity(), loss_fn=loss_fn)
        self.cardinality = cardinality
        self.depth = depth
        # bottlenecks per stage: three stages with three convs per bottleneck
        self.block_depth = (self.depth - 2) // 9
        self.widen_factor = widen_factor
        self.num_classes = num_classes
        self.output_size = 64
        self.stages = [
            64,
            64 * self.widen_factor,
            128 * self.widen_factor,
            256 * self.widen_factor,
        ]

        self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn_1 = nn.BatchNorm2d(64)
        self.stage_1 = self.block("stage_1", self.stages[0], self.stages[1], 1)
        self.stage_2 = self.block("stage_2", self.stages[1], self.stages[2], 2)
        self.stage_3 = self.block("stage_3", self.stages[2], self.stages[3], 2)
        # The classifier input width follows the last stage (1024 for the
        # default widen_factor=4) instead of being hard-coded.
        self.classifier = nn.Linear(self.stages[3], num_classes)
        init.kaiming_normal_(self.classifier.weight)

        # state_dict tensors share storage with the parameters / buffers, so
        # in-place writes below initialise the model itself
        state = self.state_dict()
        for key in state:
            suffix = key.split(".")[-1]
            if suffix == "weight":
                if "conv" in key:
                    init.kaiming_normal_(state[key], mode="fan_out")
                if "bn" in key:
                    state[key][...] = 1
            elif suffix == "bias":
                state[key][...] = 0

    def block(self, name, in_channels, out_channels, pool_stride=2):
        """Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            name: string name of the current block.
            in_channels: number of input channels
            out_channels: number of output channels
            pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        block = nn.Sequential()
        for bottleneck in range(self.block_depth):
            name_ = "%s_bottleneck_%d" % (name, bottleneck)
            if bottleneck == 0:
                # only the first bottleneck changes channels / resolution
                block.add_module(
                    name_,
                    ResNeXtBottleneck(
                        in_channels,
                        out_channels,
                        pool_stride,
                        self.cardinality,
                        self.widen_factor,
                    ),
                )
            else:
                block.add_module(
                    name_,
                    ResNeXtBottleneck(
                        out_channels,
                        out_channels,
                        1,
                        self.cardinality,
                        self.widen_factor,
                    ),
                )
        return block

    def forward(self, batch):
        """Compute class logits for ``batch[0]`` (the image tensor of the batch)."""
        x = self.conv_1_3x3(batch[0])
        x = F.relu(self.bn_1(x), inplace=True)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        # global average pooling over the full spatial extent
        x = F.avg_pool2d(x, kernel_size=x.shape[-1], stride=1)
        x = x.view(-1, self.stages[3])
        return self.classifier(x)
179 |
--------------------------------------------------------------------------------
/src/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import time
4 |
5 | import torch
6 | import wandb
7 |
8 | from utils.tester import test
9 |
10 |
def save_chpt(args, iteration, model, optimizer, scheduler, loss, name="last_chpt"):
    """Write a checkpoint into the wandb run directory and upload it.

    Args:
        args: unused.
        iteration: iteration number stored in the checkpoint.
        model: module whose ``state_dict`` is saved.
        optimizer: optimizer whose ``state_dict`` is saved.
        scheduler: LR scheduler, or ``None`` (then ``None`` is stored).
        loss: loss value stored alongside the states.
        name: checkpoint file name without the ``.pth`` extension.
    """
    chpt = {"iteration": iteration}
    chpt["model_state_dict"] = model.state_dict()
    chpt["optimizer_state_dict"] = optimizer.state_dict()
    chpt["scheduler_state_dict"] = (
        scheduler.state_dict() if scheduler is not None else None
    )
    chpt["loss"] = loss

    run_dir = wandb.run.dir
    chpt_path = os.path.join(run_dir, f"{name}.pth")
    torch.save(chpt, chpt_path)
    wandb.save(chpt_path, base_path=run_dir)
    print("->model saved<-\n")
22 |
23 |
def _monitored_loss(history):
    """Return the validation loss if logged, else the training loss (``None`` if neither)."""
    if "val/loss" in history:
        return history["val/loss"]
    return history.get("train/loss")


def train(
    args,
    train_loader,
    val_loader,
    test_loader,
    model,
    optimizer,
    scheduler,
):
    """Run the training loop from ``args.start_iter`` up to ``args.max_iter``.

    Every iteration performs one training step, one validation step (when a
    validation loader is given), one scheduler step (when a scheduler is
    given) and logs the collected metrics to wandb. Checkpoints are written
    every ``args.save_freq`` iterations and once more after the loop; the test
    set is evaluated every ``args.eval_test_freq`` iterations. Training stops
    early when the validation loss becomes NaN.

    Args:
        args: the ``train`` section of the experiment config.
        train_loader: iterator yielding training batches (consumed with ``next``).
        val_loader: iterator yielding validation batches, or ``None``.
        test_loader: loader handed to ``test``.
        model: model to train.
        optimizer: optimizer used for the training steps.
        scheduler: LR scheduler or ``None``. ``ReduceLROnPlateau`` is stepped
            with the validation loss, falling back to the training loss when
            no validation loss is available.
    """
    if val_loader is not None:
        with torch.no_grad():
            # compute metrics on initialization
            batch = next(val_loader)
            history_val = run_iter(
                args=args,
                iteration=args.start_iter,
                batch=batch,
                model=model,
                optimizer=None,
                mode="val",
            )
            wandb.log({**history_val, "iter": args.start_iter})

    hist = None
    for iteration in range(args.start_iter, args.max_iter):
        batch = next(train_loader)

        time_start = time.time()
        history_train = run_iter(
            args,
            iteration=iteration,
            batch=batch,
            model=model,
            optimizer=optimizer,
            mode="train",
        )

        train_elapsed = time.time() - time_start
        time_start = time.time()
        history_val = {}

        if val_loader is not None:
            batch = next(val_loader)
            with torch.no_grad():
                history_val = run_iter(
                    args,
                    iteration=iteration + 1,
                    batch=batch,
                    model=model,
                    optimizer=None,
                    mode="val",
                )

        if scheduler is not None:
            if scheduler.__class__.__name__ == "ReduceLROnPlateau":
                # plateau schedulers need a metric; without a validation
                # loss the training loss is monitored instead
                plateau_metric = _monitored_loss({**history_train, **history_val})
                if plateau_metric is not None:
                    scheduler.step(plateau_metric)
            else:
                scheduler.step()

        val_elapsed = time.time() - time_start
        hist = {
            **history_train,
            **history_val,
            "train_time": train_elapsed,
            "val_time": val_elapsed,
        }

        # save metrics to wandb
        wandb.log(hist)
        # save checkpoint (the final one is written after the loop)
        if iteration % args.save_freq == 0:
            save_chpt(
                args,
                iteration,
                model,
                optimizer,
                scheduler,
                _monitored_loss(hist),
            )

        if iteration % 100 == 0:
            print(
                "Iteration: {}/{}, Time elapsed: {:.2f}s\n"
                "* Train loss: {:.2f} \n".format(
                    iteration + 1,
                    args.max_iter,
                    val_elapsed + train_elapsed,
                    hist["train/loss"],
                )
            )

        # stop as soon as the validation loss diverges
        if "val/loss" in hist and math.isnan(hist["val/loss"]):
            print("Nan loss, stopping training")
            break

        # run test eval to track the performance
        if (iteration + 1) % args.eval_test_freq == 0 and (
            iteration + 1
        ) < args.max_iter:
            print("Run test evaluation...")
            with torch.no_grad():
                test(
                    args=args,
                    loader=test_loader,
                    model=model,
                )

    if hist is None:
        # the loop body never ran, so there is nothing new to checkpoint
        print("No training iterations to run")
        return

    print("Save last checkpoint")
    save_chpt(args, args.max_iter, model, optimizer, scheduler, _monitored_loss(hist))
138 |
139 |
def run_iter(args, iteration, batch, model, optimizer, mode="train"):
    """Run a single training or validation iteration and collect its metrics.

    Args:
        args: Run configuration. This function reads ``args.device`` and, in
            train mode, ``args.optimizer_log_freq``; ``args`` is also passed
            on to ``optimize``.
        iteration: Zero-based iteration index (reported as ``iteration + 1``).
        batch: Mutable sequence of batch elements. When ``args.device``
            contains ``"cuda"`` each element is replaced in place by the
            result of its ``.cuda(non_blocking=True)`` call.
        model: Object providing ``train()``, ``eval()`` and
            ``train_step(batch, device=...)`` that returns ``(loss, logs)``.
        optimizer: Optimizer stepped in train mode. It is not used in val mode.
        mode: ``"train"`` or ``"val"``.

    Returns:
        dict: Metrics of this iteration. Keys of ``logs`` that contain no
        ``"/"`` are prefixed with ``"<mode>/"``; values whose key contains
        ``"hist"`` are wrapped in ``wandb.Histogram``. In train mode the dict
        also holds ``"lr"``, ``"iter"`` and, every ``optimizer_log_freq``
        iterations, beta1 statistics of the optimizer.

    Raises:
        ValueError: If ``mode`` is neither ``"train"`` nor ``"val"``.
    """
    # Fail early with a clear message; previously an unknown mode only
    # surfaced as an UnboundLocalError when ``history`` was returned.
    if mode not in ("train", "val"):
        raise ValueError(f'Unknown mode "{mode}", expected "train" or "val"')

    is_train = mode == "train"

    if is_train:
        model.train()
        try:
            lr = optimizer.param_groups[0]["lr"]
        except (AttributeError, IndexError, KeyError, TypeError):
            # Best effort: optimizers without a readable "lr" report 0.0.
            lr = 0.0
        history = {"lr": lr, "iter": iteration + 1}

        log_optimizer_stats = (
            iteration > 0
            and args.optimizer_log_freq > 0
            and iteration % args.optimizer_log_freq == 0
        )
        if log_optimizer_stats:
            if hasattr(optimizer, "get_current_beta1_estimate"):
                # The optimizer exposes its own beta1 estimates; flatten
                # them into a single (1, N) CPU tensor for logging.
                vals = optimizer.get_current_beta1_estimate()
                vals = torch.cat([x.reshape(-1) for x in vals]).reshape(1, -1).cpu()
                history["beta1"] = wandb.Histogram(vals)
                history["beta1_median"] = vals.median()
                history["beta1_mean"] = vals.mean()
            elif "betas" in optimizer.param_groups[0]:
                # Fixed beta1 rescaled by the term 1 - beta1**iteration.
                beta1 = optimizer.param_groups[0]["betas"][0]
                bias_correction = 1 - beta1**iteration
                history["beta1_median"] = beta1 / bias_correction
                history["beta1_mean"] = beta1 / bias_correction
            elif "momentum" in optimizer.param_groups[0]:
                beta1 = optimizer.param_groups[0]["momentum"]
                history["beta1_median"] = beta1
                history["beta1_mean"] = beta1
    else:
        model.eval()
        history = {}

    if "cuda" in args.device:
        for i in range(len(batch)):
            batch[i] = batch[i].cuda(non_blocking=True)

    # Loss
    if is_train:
        loss, logs = model.train_step(batch, device=args.device)
        optimize(args, loss, model, optimizer)
    else:
        with torch.no_grad():
            _, logs = model.train_step(batch, device=args.device)

    # Get the history
    for key, value in logs.items():
        h_key = key if "/" in key else f"{mode}/{key}"
        history[h_key] = wandb.Histogram(value) if "hist" in key else value

    return history
195 |
196 |
def optim_step(params, optimizer, grad_clip_val, grad_skip_val):
    """Clip the gradients and apply one optimizer step unless it is skipped.

    The step is skipped when the gradient norm is not finite (NaN or inf), or
    when ``grad_skip_val`` is non-zero and the norm is not below it.

    Args:
        params: Parameters whose gradients are clipped in place.
        optimizer: Object with a ``step()`` method.
        grad_clip_val: Maximum gradient norm passed to ``clip_grad_norm_``.
        grad_skip_val: Norm threshold at or above which the step is skipped;
            ``0`` disables the threshold.

    Returns:
        float: The value returned by ``clip_grad_norm_``, converted with
        ``.item()``.
    """
    # clip gradient
    total_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip_val)
    grad_norm = total_norm.item()

    # With a non-finite norm, clipping has already scaled the gradients by a
    # non-finite (or zero) coefficient, so stepping would write NaNs into the
    # parameters. Skip regardless of whether the threshold is enabled.
    norm_is_finite = bool(torch.isfinite(total_norm))
    within_threshold = grad_skip_val == 0 or grad_norm < grad_skip_val

    if norm_is_finite and within_threshold:
        optimizer.step()
    return grad_norm
204 |
205 |
def optimize(args, loss, model, optimizer):
    """Backpropagate ``loss`` and take one optimizer step, then log to wandb.

    If ``loss`` contains any NaN, nothing is backpropagated and only
    ``{"skipped_steps": 1}`` is logged. Otherwise gradients are zeroed, the
    loss is backpropagated, ``optim_step`` is called on all parameters with
    ``requires_grad`` set, and ``skipped_steps`` (0) plus the returned
    ``grad_norm`` are logged.

    Args:
        args: Configuration providing ``grad_clip`` and ``grad_skip_thr``.
        loss: Loss tensor to backpropagate.
        model: Module whose ``named_parameters()`` are optimized.
        optimizer: Optimizer passed through to ``optim_step``.
    """
    # A non-positive grad_clip falls back to a max-norm of 1e6 (presumably
    # large enough to act as "no clipping" while still yielding the norm).
    max_norm = args.grad_clip if args.grad_clip > 0 else 1e6

    nan_count = torch.isnan(loss).sum().item()
    if nan_count != 0:
        logs = {"skipped_steps": 1}
    else:
        # backprop through the main loss
        optimizer.zero_grad()
        loss.backward()
        trainable = [p for _, p in model.named_parameters() if p.requires_grad]

        logs = {
            "skipped_steps": 0,
            "grad_norm": optim_step(
                params=trainable,
                optimizer=optimizer,
                grad_clip_val=max_norm,
                grad_skip_val=args.grad_skip_thr,
            ),
        }

    wandb.log(logs)
230 |
--------------------------------------------------------------------------------
/notebooks/vsgd_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import matplotlib as mpl\n",
10 | "import matplotlib\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "\n",
13 | "import numpy as np\n",
14 | "import os\n",
15 | "import torch\n",
16 | "\n",
17 | "mpl.rcParams['text.usetex'] = True\n",
18 | "mpl.rcParams['text.latex.preamble'] = r'\\usepackage{amsmath}'\n",
19 | "plt.rcParams['figure.figsize'] = [9, 7]\n",
20 | "\n",
21 | "\n",
22 | "# Append ../src to path\n",
23 | "import sys\n",
24 | "source_path = os.path.join(os.getcwd(), '../src')\n",
25 | "if source_path not in sys.path:\n",
26 | " sys.path.append(source_path)\n",
27 | " "
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "## 1. Get the datasets"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 2,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "from torchvision.datasets import MNIST\n",
44 | "from torch.utils.data import DataLoader\n",
45 | "from torchvision import transforms\n",
46 | "from torchvision.transforms import ToTensor\n",
47 | "\n",
48 | "train_dataset = MNIST(root='../data/', download=True, train=True, transform=ToTensor())\n",
49 | "test_dataset = MNIST(root='../data/', download=True, train=False, transform=ToTensor())\n"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "## 2. Define NN\n",
57 | "\n",
58 | "In this example, we use the ConvMixer architecture.\n",
59 | "\n",
60 | "\n",
61 | "code source: https://github.com/locuslab/convmixer-cifar10/blob/main/train.py"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 3,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "import torch.nn as nn \n",
71 | "\n",
72 | "class Residual(nn.Module):\n",
73 | " def __init__(self, fn):\n",
74 | " super().__init__()\n",
75 | " self.fn = fn\n",
76 | "\n",
77 | " def forward(self, x):\n",
78 | " return self.fn(x) + x\n",
79 | " \n",
80 | "def build_convmixer(dim, patch_size, kernel_size, depth, num_classes):\n",
81 | " return nn.Sequential(\n",
82 | " nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size),\n",
83 | " nn.GELU(),\n",
84 | " nn.BatchNorm2d(dim),\n",
85 | " *[\n",
86 | " nn.Sequential(\n",
87 | " Residual(\n",
88 | " nn.Sequential(\n",
89 | " nn.Conv2d(\n",
90 | " dim, dim, kernel_size, groups=dim, padding=\"same\"\n",
91 | " ),\n",
92 | " nn.GELU(),\n",
93 | " nn.BatchNorm2d(dim),\n",
94 | " )\n",
95 | " ),\n",
96 | " nn.Conv2d(dim, dim, kernel_size=1),\n",
97 | " nn.GELU(),\n",
98 | " nn.BatchNorm2d(dim),\n",
99 | " )\n",
100 | " for _ in range(depth)\n",
101 | " ],\n",
102 | " nn.AdaptiveAvgPool2d((1, 1)),\n",
103 | " nn.Flatten(),\n",
104 | " nn.Linear(dim, num_classes)\n",
105 | " )"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "## 3. Train and test functions"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 4,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "\n",
122 | "def train(net, train_dataset, optimizer, max_epochs, batch_size):\n",
123 | " device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
124 | " net.to(device)\n",
125 | "\n",
126 | " criterion = nn.CrossEntropyLoss()\n",
127 | " trainloader = DataLoader(train_dataset, \n",
128 | " batch_size=batch_size, \n",
129 | " shuffle=True, \n",
130 | " num_workers=2)\n",
131 | "\n",
132 | " logs = {'train_loss': []}\n",
133 | " for epoch in range(max_epochs): # loop over the dataset multiple times\n",
134 | "\n",
135 | " running_loss = 0.0\n",
136 | " for i, data in enumerate(trainloader, 0):\n",
137 | " # get the inputs; data is a list of [inputs, labels]\n",
138 | " inputs, labels = data\n",
139 | " inputs = inputs.to(device)\n",
140 | " labels = labels.to(device)\n",
141 | " \n",
142 | "\n",
143 | " # zero the parameter gradients\n",
144 | " optimizer.zero_grad()\n",
145 | "\n",
146 | " # forward + backward + optimize\n",
147 | " outputs = net(inputs)\n",
148 | " loss = criterion(outputs, labels)\n",
149 | " loss.backward()\n",
150 | " optimizer.step()\n",
151 | "\n",
152 | " running_loss += loss.item()\n",
153 | " logs['train_loss'].append(loss.item())\n",
154 | " \n",
155 | " print(f\"Epoch {epoch+1}, loss: {running_loss / len(trainloader): .2f}\")\n",
156 | " print('Finished Training')\n",
157 | " return net, logs\n",
158 | "\n",
159 | "def test(test_dataset, net, batch_size):\n",
160 | " \n",
161 | " device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
162 | " net.to(device)\n",
163 | "\n",
164 | " criterion = nn.CrossEntropyLoss()\n",
165 | " testloader = DataLoader(test_dataset, \n",
166 | " batch_size=batch_size, \n",
167 | " shuffle=False, \n",
168 | " num_workers=2,\n",
169 | " drop_last=False,\n",
170 | " )\n",
171 | "\n",
172 | " # logs = {'test_accuracy': []}\n",
173 | " correct_clfs = 0.\n",
174 | " running_loss = 0.\n",
175 | " for i, data in enumerate(testloader, 0):\n",
176 | " # get the inputs; data is a list of [inputs, labels]\n",
177 | " inputs, labels = data\n",
178 | " inputs = inputs.to(device)\n",
179 | " labels = labels.to(device)\n",
180 | " \n",
181 | " # forward \n",
182 | " logits = net(inputs)\n",
183 | " loss = criterion(logits, labels)\n",
184 | " \n",
185 | " running_loss += loss.item()\n",
186 | "\n",
187 | " correct_clfs += (logits.argmax(dim=1) == labels).float().sum().item()\n",
188 | " N_points = len(test_dataset)\n",
189 | " return running_loss / N_points, correct_clfs / N_points\n"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "metadata": {},
195 | "source": [
196 | "## 4. Compare to Adam and SGD"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 14,
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "name": "stdout",
206 | "output_type": "stream",
207 | "text": [
208 | "Epoch 1, loss: 0.27\n",
209 | "Epoch 2, loss: 0.10\n",
210 | "Epoch 3, loss: 0.09\n",
211 | "Finished Training\n",
212 | "Test accuracy: 98.08%\n"
213 | ]
214 | }
215 | ],
216 | "source": [
217 | "from vsgd import VSGD\n",
218 | "\n",
219 | "net = build_convmixer(dim=16, patch_size=2, kernel_size=3, depth=4, num_classes=10)\n",
220 | "vsgd = VSGD(net.parameters(), lr=0.01, ps=1e-7)\n",
221 | "net, logs_vsgd = train(net, train_dataset, vsgd, max_epochs=3, batch_size=32)\n",
222 | "logs_vsgd['test_loss'], logs_vsgd['test_acc'] = test(test_dataset, net, 32)\n",
223 | "\n",
224 | "acc = logs_vsgd['test_acc']*100\n",
225 | "print(f'Test accuracy: {acc:.2f}%')"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 17,
231 | "metadata": {},
232 | "outputs": [
233 | {
234 | "name": "stdout",
235 | "output_type": "stream",
236 | "text": [
237 | "Epoch 1, loss: 0.57\n",
238 | "Epoch 2, loss: 0.13\n",
239 | "Epoch 3, loss: 0.10\n",
240 | "Finished Training\n",
241 | "Test accuracy: 97.81%\n"
242 | ]
243 | }
244 | ],
245 | "source": [
246 | "from torch.optim import SGD\n",
247 | "\n",
248 | "sgd_net = build_convmixer(dim=16, patch_size=2, kernel_size=3, depth=4, num_classes=10)\n",
249 | "sgd = SGD(sgd_net.parameters(), lr=0.01, momentum=0.9)\n",
250 | "sgd_net, logs_sgd = train(sgd_net, train_dataset, sgd, max_epochs=3, batch_size=32)\n",
251 | "logs_sgd['test_loss'], logs_sgd['test_acc'] = test(test_dataset, sgd_net, 32)\n",
252 | "\n",
253 | "acc = logs_sgd['test_acc']*100\n",
254 | "print(f'Test accuracy: {acc:.2f}%')"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": 15,
260 | "metadata": {},
261 | "outputs": [
262 | {
263 | "name": "stdout",
264 | "output_type": "stream",
265 | "text": [
266 | "Epoch 1, loss: 0.24\n",
267 | "Epoch 2, loss: 0.10\n",
268 | "Epoch 3, loss: 0.09\n",
269 | "Finished Training\n",
270 | "Test accuracy: 97.68%\n"
271 | ]
272 | }
273 | ],
274 | "source": [
275 | "from torch.optim import Adam\n",
276 | "\n",
277 | "adam_net = build_convmixer(dim=16, patch_size=2, kernel_size=3, depth=4, num_classes=10)\n",
278 | "adam = Adam(adam_net.parameters(), lr=0.01)\n",
279 | "adam_net, logs_adam = train(adam_net, train_dataset, adam, max_epochs=3, batch_size=32)\n",
280 | "logs_adam['test_loss'], logs_adam['test_acc'] = test(test_dataset, adam_net, batch_size=32)\n",
281 | "\n",
282 | "acc = logs_adam['test_acc']*100\n",
283 | "print(f'Test accuracy: {acc:.2f}%')"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": 22,
289 | "metadata": {},
290 | "outputs": [
291 | {
292 | "data": {
293 | "image/png": "",
294 | "text/plain": [
295 | ""
296 | ]
297 | },
298 | "metadata": {
299 | "needs_background": "light"
300 | },
301 | "output_type": "display_data"
302 | }
303 | ],
304 | "source": [
305 | "fig, ax = plt.subplots(1, 2, figsize=(12, 5))\n",
306 | "ax[0].plot(logs_vsgd['train_loss'], label='VSGD', alpha=0.7)\n",
307 | "ax[0].plot(logs_sgd['train_loss'], label='SGD', alpha=0.7)\n",
308 | "ax[0].plot(logs_adam['train_loss'], label='Adam', alpha=0.7)\n",
309 | "ax[0].legend(fontsize=20)\n",
310 | "ax[0].grid()\n",
311 | "ax[0].set_title('Training loss', fontsize=24)\n",
312 | "\n",
313 | "ax[1].bar(['VSGD', 'SGD', 'Adam'], [logs_vsgd['test_acc'], logs_sgd['test_acc'], logs_adam['test_acc']]);\n",
314 | "ax[1].grid()\n",
315 | "ax[1].set_ylim(0.97, 1.0);\n",
316 | "ax[1].set_title('Test accuracy', fontsize=24);"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "metadata": {},
323 | "outputs": [],
324 | "source": []
325 | }
326 | ],
327 | "metadata": {
328 | "kernelspec": {
329 | "display_name": "BASE_ENV",
330 | "language": "python",
331 | "name": "base_env"
332 | },
333 | "language_info": {
334 | "codemirror_mode": {
335 | "name": "ipython",
336 | "version": 3
337 | },
338 | "file_extension": ".py",
339 | "mimetype": "text/x-python",
340 | "name": "python",
341 | "nbconvert_exporter": "python",
342 | "pygments_lexer": "ipython3",
343 | "version": "3.6.12"
344 | }
345 | },
346 | "nbformat": 4,
347 | "nbformat_minor": 4
348 | }
349 |
--------------------------------------------------------------------------------