├── .gitignore ├── README.md ├── classification ├── README.md ├── configs │ └── cifar10 │ │ ├── lenet_adam.json │ │ ├── lenet_kfac.json │ │ ├── lenet_noisykfac.json │ │ └── lenet_vogn.json ├── main.py └── models │ ├── __init__.py │ ├── alexnet.py │ ├── lenet.py │ ├── mlp.py │ ├── resnet.py │ └── vgg.py ├── distributed ├── README.md └── classification │ ├── README.md │ ├── configs │ ├── cifar10 │ │ ├── alexnet_adam_bs256_8gpu.json │ │ ├── alexnet_kfac_bs256_8gpu.json │ │ ├── alexnet_noisykfac_bs256_8gpu.json │ │ ├── alexnet_sgd_bs256_8gpu.json │ │ ├── alexnet_vogn_bs256_8gpu.json │ │ ├── lenet_adam_bs128_4gpu.json │ │ ├── lenet_kfac_bs128_4gpu.json │ │ ├── lenet_sgd_bs128_4gpu.json │ │ ├── lenet_vogn_bs128_4gpu.json │ │ ├── resnet18_adam_bs256_8gpu.json │ │ ├── resnet18_sgd_bs256_8gpu.json │ │ ├── resnet18_vogn_bs256_8gpu.json │ │ ├── vgg19_adam_bs256_8gpu.json │ │ └── vgg19_vogn_bs256_8gpu.json │ ├── cifar100 │ │ ├── alexnet_adam_bs256_8gpu.json │ │ ├── alexnet_sgd_bs256_8gpu.json │ │ ├── alexnet_vogn_bs256_8gpu.json │ │ ├── resnet18_adam_bs256_8gpu.json │ │ └── resnet18_vogn_bs256_8gpu.json │ └── imagenet │ │ ├── resnet18_adam_bs4k_128gpu.json │ │ ├── resnet18_kfac_bs4k_128gpu.json │ │ ├── resnet18_kfac_bs4k_4gpu.json │ │ ├── resnet18_noisykfac_bs4k_128gpu.json │ │ ├── resnet18_noisykfac_bs4k_4gpu.json │ │ ├── resnet18_sgd_bs4k_128gpu.json │ │ ├── resnet18_vogn_bs4k_128gpu.json │ │ └── resnet18_vogn_bs4k_4gpu.json │ ├── main.py │ └── models │ ├── __init__.py │ ├── alexnet.py │ ├── lenet.py │ ├── resnet.py │ ├── resnet_b.py │ ├── resnext.py │ └── vgg.py ├── docs ├── boundary.gif ├── curves.png └── distributed_vi.png ├── neurips2019_poster.pdf └── toy_example ├── README.md └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | 
develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # PyCharm 107 | .idea 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Practical Deep Learning with Bayesian Principles 2 | This repository contains code that demonstrates 3 | practical applications of Bayesian principles to Deep Learning.
4 | Our implementation contains an Adam-like optimizer, called 5 | [VOGN](http://proceedings.mlr.press/v80/khan18a.html), 6 | to obtain uncertainty in Deep Learning. 7 | 8 | - 2D-binary classification (see [toy example](./toy_example)) 9 | - Image classification ([MNIST](./classification), 10 | [CIFAR-10/100](./classification), 11 | and [ImageNet](./distributed/classification)) 12 | - Continual learning for image classification (permuted MNIST) 13 | - Per-pixel semantic labeling & segmentation (Cityscapes) 14 | 15 | ## Setup 16 | This repository uses [PyTorch-SSO](https://github.com/cybertronai/pytorch-sso), a PyTorch extension for second-order optimization, variational inference, and distributed training. 17 | 18 | ```bash 19 | $ git clone git@github.com:cybertronai/pytorch-sso.git 20 | $ cd pytorch-sso 21 | $ python setup.py install 22 | ``` 23 | Please follow the 24 | [Installation](https://github.com/cybertronai/pytorch-sso#installation) 25 | of PyTorch-SSO for CUDA/MPI support. 26 | 27 | 28 | ## Bayesian Uncertainty Estimation 29 | Decision boundary and entropy plots on 2D-binary classification by MLPs trained 30 | with Adam and VOGN. 31 | ![](./docs/boundary.gif) 32 | VOGN optimizes the posterior distribution of each weight (i.e., mean and variance of the Gaussian). 33 | A model with the mean weights draws the red boundary, and models with the MC samples from the posterior distribution draw light red boundaries. 34 | VOGN converges to a similar solution as Adam while keeping uncertainty in its predictions. 
35 | 36 | With PyTorch-SSO (`torchsso`), you can run VOGN training by changing a line in your train script: 37 | ```diff 38 | import torch 39 | +import torchsso 40 | 41 | train_loader = torch.utils.data.DataLoader(train_dataset) 42 | model = MLP() 43 | 44 | -optimizer = torch.optim.Adam(model.parameters()) 45 | +optimizer = torchsso.optim.VOGN(model, dataset_size=len(train_loader.dataset)) 46 | 47 | for data, target in train_loader: 48 | 49 | def closure(): 50 | optimizer.zero_grad() 51 | output = model(data) 52 | loss = F.binary_cross_entropy_with_logits(output, target) 53 | loss.backward() 54 | return loss, output 55 | 56 | loss, output = optimizer.step(closure) 57 | 58 | ``` 59 | 60 | To train the MLPs with VOGN and Adam and create the GIF: 61 | ```bash 62 | $ cd toy_example 63 | $ python main.py 64 | ``` 65 | For details, please see [VOGN implementation in PyTorch-SSO](https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/vi.py). 66 | 67 | ## Bayes for Image Classification 68 | This repository contains code for the NeurIPS 2019 paper "[Practical Deep Learning with Bayesian Principles](https://arxiv.org/abs/1906.02506)," 69 | [[poster](./neurips2019_poster.pdf)] 70 | which includes the results of **Large-scale Variational Inference on ImageNet classification**. 71 | 72 | ![](./docs/curves.png) 73 | VOGN achieves similar performance in about the same number of epochs as Adam and SGD. 74 | Importantly, the benefits of Bayesian principles are preserved: predictive probabilities are well-calibrated (rightmost figure), 75 | uncertainties on out-of-distribution data are improved (please refer to the paper), 76 | and continual-learning performance is boosted (please refer to the paper; an example is in preparation). 77 | 78 | See [classification](./classification) (single CPU/GPU) or [distributed/classification](./distributed/classification) (multiple GPUs) for example scripts.
79 | 80 | 81 | ## Citation 82 | NeurIPS 2019 paper 83 | ``` 84 | @article{osawa2019practical, 85 | title = {Practical Deep Learning with Bayesian Principles}, 86 | author = {Osawa, Kazuki and Swaroop, Siddharth and Jain, Anirudh and Eschenhagen, Runa and Turner, Richard E. and Yokota, Rio and Khan, Mohammad Emtiyaz}, 87 | journal = {arXiv preprint arXiv:1906.02506}, 88 | year = {2019} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /classification/README.md: -------------------------------------------------------------------------------- 1 | To train LeNet-5 on CIFAR-10, run `main.py` with one of the config files listed in the table below: 2 | ```bash 3 | python main.py --config path/to/config.json --download 4 | ``` 5 | | optimizer | dataset | architecture | config file path | 6 | | --- | --- | --- | --- | 7 | | [Adam](https://arxiv.org/abs/1412.6980) | CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_adam.json](./configs/cifar10/lenet_adam.json) | 8 | | [K-FAC](https://arxiv.org/abs/1503.05671)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_kfac.json](./configs/cifar10/lenet_kfac.json) | 9 | | [Noisy K-FAC](https://arxiv.org/abs/1712.02390)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_noisykfac.json](./configs/cifar10/lenet_noisykfac.json) | 10 | | [VOGN](https://arxiv.org/abs/1806.04854)| CIFAR-10 | LeNet-5 + BatchNorm | [configs/cifar10/lenet_vogn.json](./configs/cifar10/lenet_vogn.json) | 11 | -------------------------------------------------------------------------------- /classification/configs/cifar10/lenet_adam.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 100, 4 | "batch_size": 128, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 0.01 16 | }
| } -------------------------------------------------------------------------------- /classification/configs/cifar10/lenet_kfac.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 50, 4 | "batch_size": 128, 5 | "val_batch_size": 5000, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "SecondOrderOptimizer", 12 | "optim_args": { 13 | "curv_type":"Fisher", 14 | "curv_shapes": { 15 | "Conv2d": "Kron", 16 | "Linear": "Kron", 17 | "BatchNorm1d": "Diag", 18 | "BatchNorm2d": "Diag" 19 | }, 20 | "lr": 1e-3, 21 | "momentum": 0.9, 22 | "momentum_type": "raw", 23 | "l2_reg": 1e-3, 24 | "acc_steps": 1 25 | }, 26 | "curv_args": { 27 | "damping": 1e-3, 28 | "ema_decay": 0.999, 29 | "pi_type": "tracenorm" 30 | }, 31 | "fisher_args": { 32 | "approx_type": "mc", 33 | "num_mc": 1 34 | }, 35 | "scheduler_name": "ExponentialLR", 36 | "scheduler_args": { 37 | "gamma": 0.9 38 | }, 39 | "log_interval": 64, 40 | "no_cuda": false 41 | } -------------------------------------------------------------------------------- /classification/configs/cifar10/lenet_noisykfac.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 15, 4 | "batch_size": 64, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": false, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "VIOptimizer", 12 | "optim_args": { 13 | "curv_type": "Fisher", 14 | "curv_shapes": { 15 | "Conv2d": "Kron", 16 | "Linear": "Kron" 17 | }, 18 | "lr": 4e-3, 19 | "momentum": 0.9, 20 | "momentum_type": "preconditioned", 21 | "weight_decay": 0.1, 22 | "num_mc_samples": 4, 23 | "val_num_mc_samples": 0, 24 | "kl_weighting": 0.2, 25 | "prior_variance": 1 26 | }, 27 | 
"curv_args": { 28 | "damping": 1e-4, 29 | "ema_decay": 0.333, 30 | "pi_type": "tracenorm" 31 | }, 32 | "fisher_args": { 33 | "approx_type": "mc", 34 | "num_mc": 1 35 | }, 36 | "scheduler_name": "ExponentialLR", 37 | "scheduler_args": { 38 | "gamma": 0.9 39 | }, 40 | "no_cuda": false 41 | } 42 | -------------------------------------------------------------------------------- /classification/configs/cifar10/lenet_vogn.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 30, 4 | "batch_size": 128, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5BatchNorm", 11 | "arch_args": { 12 | "affine": true 13 | }, 14 | "optim_name": "VIOptimizer", 15 | "optim_args": { 16 | "curv_type": "Cov", 17 | "curv_shapes": { 18 | "Conv2d": "Diag", 19 | "Linear": "Diag", 20 | "BatchNorm1d": "Diag", 21 | "BatchNorm2d": "Diag" 22 | }, 23 | "lr": 0.01, 24 | "grad_ema_decay": 0.1, 25 | "grad_ema_type": "raw", 26 | "num_mc_samples": 10, 27 | "val_num_mc_samples": 0, 28 | "kl_weighting": 1, 29 | "init_precision": 8e-3, 30 | "prior_variance": 1, 31 | "acc_steps": 1 32 | }, 33 | "curv_args": { 34 | "damping": 0, 35 | "ema_decay": 0.001 36 | }, 37 | "scheduler_name": "ExponentialLR", 38 | "scheduler_args": { 39 | "gamma": 0.9 40 | }, 41 | "no_cuda": false 42 | } 43 | -------------------------------------------------------------------------------- /classification/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from importlib import import_module 4 | import shutil 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torchvision import datasets, transforms, models 10 | import torchsso 11 | from torchsso.optim import SecondOrderOptimizer, VIOptimizer 12 | from torchsso.utils import 
Logger 13 | 14 | DATASET_CIFAR10 = 'CIFAR-10' 15 | DATASET_CIFAR100 = 'CIFAR-100' 16 | DATASET_MNIST = 'MNIST' 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | # Data 22 | parser.add_argument('--dataset', type=str, 23 | choices=[DATASET_CIFAR10, DATASET_CIFAR100, DATASET_MNIST], default=DATASET_CIFAR10, 24 | help='name of dataset') 25 | parser.add_argument('--root', type=str, default='./data', 26 | help='root of dataset') 27 | parser.add_argument('--epochs', type=int, default=10, 28 | help='number of epochs to train') 29 | parser.add_argument('--batch_size', type=int, default=128, 30 | help='input batch size for training') 31 | parser.add_argument('--val_batch_size', type=int, default=128, 32 | help='input batch size for valing') 33 | parser.add_argument('--normalizing_data', action='store_true', 34 | help='[data pre processing] normalizing data') 35 | parser.add_argument('--random_crop', action='store_true', 36 | help='[data augmentation] random crop') 37 | parser.add_argument('--random_horizontal_flip', action='store_true', 38 | help='[data augmentation] random horizontal flip') 39 | # Training Settings 40 | parser.add_argument('--arch_file', type=str, default=None, 41 | help='name of file which defines the architecture') 42 | parser.add_argument('--arch_name', type=str, default='LeNet5', 43 | help='name of the architecture') 44 | parser.add_argument('--arch_args', type=json.loads, default=None, 45 | help='[JSON] arguments for the architecture') 46 | parser.add_argument('--optim_name', type=str, default=SecondOrderOptimizer.__name__, 47 | help='name of the optimizer') 48 | parser.add_argument('--optim_args', type=json.loads, default=None, 49 | help='[JSON] arguments for the optimizer') 50 | parser.add_argument('--curv_args', type=json.loads, default=dict(), 51 | help='[JSON] arguments for the curvature') 52 | parser.add_argument('--fisher_args', type=json.loads, default=dict(), 53 | help='[JSON] arguments for the fisher') 54 | 
parser.add_argument('--scheduler_name', type=str, default=None, 55 | help='name of the learning rate scheduler') 56 | parser.add_argument('--scheduler_args', type=json.loads, default=None, 57 | help='[JSON] arguments for the scheduler') 58 | # Options 59 | parser.add_argument('--download', action='store_true', default=False, 60 | help='if True, downloads the dataset (CIFAR-10 or 100) from the internet') 61 | parser.add_argument('--create_graph', action='store_true', default=False, 62 | help='create graph of the derivative') 63 | parser.add_argument('--no_cuda', action='store_true', default=False, 64 | help='disables CUDA training') 65 | parser.add_argument('--seed', type=int, default=1, 66 | help='random seed') 67 | parser.add_argument('--num_workers', type=int, default=0, 68 | help='number of sub processes for data loading') 69 | parser.add_argument('--log_interval', type=int, default=50, 70 | help='how many batches to wait before logging training status') 71 | parser.add_argument('--log_file_name', type=str, default='log', 72 | help='log file name') 73 | parser.add_argument('--checkpoint_interval', type=int, default=50, 74 | help='how many epochs to wait before logging training status') 75 | parser.add_argument('--resume', type=str, default=None, 76 | help='checkpoint path for resume training') 77 | parser.add_argument('--out', type=str, default='result', 78 | help='dir to save output files') 79 | parser.add_argument('--config', default='configs/cifar10/lenet_kfac.json', 80 | help='config file path') 81 | 82 | args = parser.parse_args() 83 | dict_args = vars(args) 84 | 85 | # Load config file 86 | if args.config is not None: 87 | with open(args.config) as f: 88 | config = json.load(f) 89 | dict_args.update(config) 90 | 91 | # Set device 92 | use_cuda = not args.no_cuda and torch.cuda.is_available() 93 | device = torch.device('cuda' if use_cuda else 'cpu') 94 | 95 | # Set random seed 96 | torch.manual_seed(args.seed) 97 | 98 | # Setup data augmentation & data pre 
processing 99 | train_transforms, val_transforms = [], [] 100 | if args.random_crop: 101 | train_transforms.append(transforms.RandomCrop(32, padding=4)) 102 | 103 | if args.random_horizontal_flip: 104 | train_transforms.append(transforms.RandomHorizontalFlip()) 105 | 106 | train_transforms.append(transforms.ToTensor()) 107 | val_transforms.append(transforms.ToTensor()) 108 | 109 | if args.normalizing_data: 110 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 111 | train_transforms.append(normalize) 112 | val_transforms.append(normalize) 113 | 114 | train_transform = transforms.Compose(train_transforms) 115 | val_transform = transforms.Compose(val_transforms) 116 | 117 | # Setup data loader 118 | if args.dataset == DATASET_CIFAR10: 119 | # CIFAR-10 120 | num_classes = 10 121 | dataset_class = datasets.CIFAR10 122 | elif args.dataset == DATASET_CIFAR100: 123 | # CIFAR-100 124 | num_classes = 100 125 | dataset_class = datasets.CIFAR100 126 | elif args.dataset == DATASET_MNIST: 127 | num_classes = 10 128 | dataset_class = datasets.MNIST 129 | else: 130 | assert False, f'unknown dataset {args.dataset}' 131 | 132 | train_dataset = dataset_class( 133 | root=args.root, train=True, download=args.download, transform=train_transform) 134 | val_dataset = dataset_class( 135 | root=args.root, train=False, download=args.download, transform=val_transform) 136 | 137 | train_loader = torch.utils.data.DataLoader( 138 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 139 | val_loader = torch.utils.data.DataLoader( 140 | val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers) 141 | 142 | # Setup model 143 | if args.arch_file is None: 144 | arch_class = getattr(models, args.arch_name) 145 | else: 146 | _, ext = os.path.splitext(args.arch_file) 147 | dirname = os.path.dirname(args.arch_file) 148 | 149 | if dirname == '': 150 | module_path = args.arch_file.replace(ext, '') 151 | 
elif dirname == '.': 152 | module_path = os.path.basename(args.arch_file).replace(ext, '') 153 | else: 154 | module_path = '.'.join(os.path.split(args.arch_file)).replace(ext, '') 155 | 156 | module = import_module(module_path) 157 | arch_class = getattr(module, args.arch_name) 158 | 159 | arch_kwargs = {} if args.arch_args is None else args.arch_args 160 | arch_kwargs['num_classes'] = num_classes 161 | 162 | model = arch_class(**arch_kwargs) 163 | setattr(model, 'num_classes', num_classes) 164 | model = model.to(device) 165 | 166 | optim_kwargs = {} if args.optim_args is None else args.optim_args 167 | 168 | # Setup optimizer 169 | if args.optim_name == SecondOrderOptimizer.__name__: 170 | optimizer = SecondOrderOptimizer(model, **optim_kwargs, curv_kwargs=args.curv_args) 171 | elif args.optim_name == VIOptimizer.__name__: 172 | optimizer = VIOptimizer(model, dataset_size=len(train_loader.dataset), seed=args.seed, 173 | **optim_kwargs, curv_kwargs=args.curv_args) 174 | else: 175 | optim_class = getattr(torch.optim, args.optim_name) 176 | optimizer = optim_class(model.parameters(), **optim_kwargs) 177 | 178 | # Setup lr scheduler 179 | if args.scheduler_name is None: 180 | scheduler = None 181 | else: 182 | scheduler_class = getattr(torchsso.optim.lr_scheduler, args.scheduler_name, None) 183 | if scheduler_class is None: 184 | scheduler_class = getattr(torch.optim.lr_scheduler, args.scheduler_name) 185 | scheduler_kwargs = {} if args.scheduler_args is None else args.scheduler_args 186 | scheduler = scheduler_class(optimizer, **scheduler_kwargs) 187 | 188 | start_epoch = 1 189 | 190 | # Load checkpoint 191 | if args.resume is not None: 192 | print('==> Resuming from checkpoint..') 193 | assert os.path.exists(args.resume), 'Error: no checkpoint file found' 194 | checkpoint = torch.load(args.resume) 195 | model.load_state_dict(checkpoint['model']) 196 | start_epoch = checkpoint['epoch'] 197 | 198 | # All config 199 | print('===========================') 200 | for key, 
val in vars(args).items(): 201 | if key == 'dataset': 202 | print('{}: {}'.format(key, val)) 203 | print('train data size: {}'.format(len(train_loader.dataset))) 204 | print('val data size: {}'.format(len(val_loader.dataset))) 205 | else: 206 | print('{}: {}'.format(key, val)) 207 | print('===========================') 208 | 209 | # Copy this file & config to args.out 210 | if not os.path.isdir(args.out): 211 | os.makedirs(args.out) 212 | shutil.copy(os.path.realpath(__file__), args.out) 213 | 214 | if args.config is not None: 215 | shutil.copy(args.config, args.out) 216 | if args.arch_file is not None: 217 | shutil.copy(args.arch_file, args.out) 218 | 219 | # Setup logger 220 | logger = Logger(args.out, args.log_file_name) 221 | logger.start() 222 | 223 | # Run training 224 | for epoch in range(start_epoch, args.epochs + 1): 225 | 226 | # train 227 | accuracy, loss, confidence = train(model, device, train_loader, optimizer, scheduler, epoch, args, logger) 228 | 229 | # val 230 | val_accuracy, val_loss = validate(model, device, val_loader, optimizer) 231 | 232 | # save log 233 | iteration = epoch * len(train_loader) 234 | log = {'epoch': epoch, 'iteration': iteration, 235 | 'accuracy': accuracy, 'loss': loss, 'confidence': confidence, 236 | 'val_accuracy': val_accuracy, 'val_loss': val_loss, 237 | 'lr': optimizer.param_groups[0]['lr'], 238 | 'momentum': optimizer.param_groups[0].get('momentum', 0)} 239 | logger.write(log) 240 | 241 | # save checkpoint 242 | if epoch % args.checkpoint_interval == 0 or epoch == args.epochs: 243 | path = os.path.join(args.out, 'epoch{}.ckpt'.format(epoch)) 244 | data = { 245 | 'model': model.state_dict(), 246 | 'optimizer': optimizer.state_dict(), 247 | 'epoch': epoch 248 | } 249 | torch.save(data, path) 250 | 251 | 252 | def train(model, device, train_loader, optimizer, scheduler, epoch, args, logger): 253 | 254 | def scheduler_type(_scheduler): 255 | if _scheduler is None: 256 | return 'none' 257 | return getattr(_scheduler, 
'scheduler_type', 'epoch') 258 | 259 | if scheduler_type(scheduler) == 'epoch': 260 | scheduler.step(epoch - 1) 261 | 262 | model.train() 263 | 264 | total_correct = 0 265 | loss = None 266 | confidence = {'top1': 0, 'top1_true': 0, 'top1_false': 0, 'true': 0, 'false': 0} 267 | total_data_size = 0 268 | epoch_size = len(train_loader.dataset) 269 | num_iters_in_epoch = len(train_loader) 270 | base_num_iter = (epoch - 1) * num_iters_in_epoch 271 | 272 | for batch_idx, (data, target) in enumerate(train_loader): 273 | data, target = data.to(device), target.to(device) 274 | 275 | if scheduler_type(scheduler) == 'iter': 276 | scheduler.step() 277 | 278 | for name, param in model.named_parameters(): 279 | attr = 'p_pre_{}'.format(name) 280 | setattr(model, attr, param.detach().clone()) 281 | 282 | # update params 283 | def closure(): 284 | optimizer.zero_grad() 285 | output = model(data) 286 | loss = F.cross_entropy(output, target) 287 | loss.backward(create_graph=args.create_graph) 288 | 289 | return loss, output 290 | 291 | if isinstance(optimizer, SecondOrderOptimizer) and optimizer.curv_type == 'Fisher': 292 | closure = torchsso.get_closure_for_fisher(optimizer, model, data, target, **args.fisher_args) 293 | 294 | loss, output = optimizer.step(closure=closure) 295 | 296 | pred = output.argmax(dim=1, keepdim=True) 297 | correct = pred.eq(target.view_as(pred)).sum().item() 298 | 299 | loss = loss.item() 300 | total_correct += correct 301 | 302 | prob = F.softmax(output, dim=1) 303 | for p, idx in zip(prob, target): 304 | confidence['top1'] += torch.max(p).item() 305 | top1 = torch.argmax(p).item() 306 | if top1 == idx: 307 | confidence['top1_true'] += p[top1].item() 308 | else: 309 | confidence['top1_false'] += p[top1].item() 310 | confidence['true'] += p[idx].item() 311 | confidence['false'] += (1 - p[idx].item()) 312 | 313 | iteration = base_num_iter + batch_idx + 1 314 | total_data_size += len(data) 315 | 316 | if batch_idx % args.log_interval == 0: 317 | accuracy = 
100. * total_correct / total_data_size 318 | elapsed_time = logger.elapsed_time 319 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, ' 320 | 'Accuracy: {:.0f}/{} ({:.2f}%), ' 321 | 'Elapsed Time: {:.1f}s'.format( 322 | epoch, total_data_size, epoch_size, 100. * (batch_idx + 1) / num_iters_in_epoch, 323 | loss, total_correct, total_data_size, accuracy, elapsed_time)) 324 | 325 | # save log 326 | lr = optimizer.param_groups[0]['lr'] 327 | log = {'epoch': epoch, 'iteration': iteration, 'elapsed_time': elapsed_time, 328 | 'accuracy': accuracy, 'loss': loss, 'lr': lr} 329 | 330 | for name, param in model.named_parameters(): 331 | attr = 'p_pre_{}'.format(name) 332 | p_pre = getattr(model, attr) 333 | p_norm = param.norm().item() 334 | p_shape = list(param.size()) 335 | p_pre_norm = p_pre.norm().item() 336 | g_norm = param.grad.norm().item() 337 | upd_norm = param.sub(p_pre).norm().item() 338 | noise_scale = getattr(param, 'noise_scale', 0) 339 | 340 | p_log = {'p_shape': p_shape, 'p_norm': p_norm, 'p_pre_norm': p_pre_norm, 341 | 'g_norm': g_norm, 'upd_norm': upd_norm, 'noise_scale': noise_scale} 342 | log[name] = p_log 343 | 344 | logger.write(log) 345 | 346 | accuracy = 100. 
* total_correct / epoch_size 347 | confidence['top1'] /= epoch_size 348 | confidence['top1_true'] /= total_correct 349 | confidence['top1_false'] /= (epoch_size - total_correct) 350 | confidence['true'] /= epoch_size 351 | confidence['false'] /= (epoch_size * (model.num_classes - 1)) 352 | 353 | return accuracy, loss, confidence 354 | 355 | 356 | def validate(model, device, val_loader, optimizer): 357 | model.eval() 358 | val_loss = 0 359 | correct = 0 360 | 361 | with torch.no_grad(): 362 | for data, target in val_loader: 363 | 364 | data, target = data.to(device), target.to(device) 365 | 366 | if isinstance(optimizer, VIOptimizer): 367 | output = optimizer.prediction(data) 368 | else: 369 | output = model(data) 370 | 371 | val_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss 372 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 373 | correct += pred.eq(target.view_as(pred)).sum().item() 374 | 375 | val_loss /= len(val_loader.dataset) 376 | val_accuracy = 100. * correct / len(val_loader.dataset) 377 | 378 | print('\nEval: Average loss: {:.4f}, Accuracy: {:.0f}/{} ({:.2f}%)\n'.format( 379 | val_loss, correct, len(val_loader.dataset), val_accuracy)) 380 | 381 | return val_accuracy, val_loss 382 | 383 | 384 | if __name__ == '__main__': 385 | main() 386 | -------------------------------------------------------------------------------- /classification/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .lenet import * 3 | from .resnet import * 4 | from .alexnet import * 5 | from .mlp import * 6 | -------------------------------------------------------------------------------- /classification/models/alexnet.py: -------------------------------------------------------------------------------- 1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. 
2 | Without BN, the start learning rate should be 0.01 3 | (c) YANG, Wei 4 | ''' 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchsso.utils.accumulator import TensorAccumulator 8 | 9 | 10 | __all__ = ['alexnet', 'alexnet_mcdropout'] 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=10): 16 | super().__init__() 17 | self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5) 18 | self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2) 19 | self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) 20 | self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) 21 | self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 22 | self.fc = nn.Linear(256, num_classes) 23 | 24 | def forward(self, x): 25 | x = F.relu(self.conv1(x)) 26 | x = F.max_pool2d(x, kernel_size=2, stride=2) 27 | x = F.relu(self.conv2(x)) 28 | x = F.max_pool2d(x, kernel_size=2, stride=2) 29 | x = F.relu(self.conv3(x)) 30 | x = F.relu(self.conv4(x)) 31 | x = F.relu(self.conv5(x)) 32 | x = F.max_pool2d(x, kernel_size=2, stride=2) 33 | x = x.view(x.size(0), -1) 34 | x = self.fc(x) 35 | return x 36 | 37 | 38 | class AlexNetMCDropout(AlexNet): 39 | 40 | mc_dropout = True 41 | 42 | def __init__(self, num_classes=10, dropout_ratio=0.5, val_mc=10): 43 | super(AlexNetMCDropout, self).__init__(num_classes) 44 | self.dropout_ratio = dropout_ratio 45 | self.val_mc = val_mc 46 | 47 | def forward(self, x): 48 | dropout_ratio = self.dropout_ratio 49 | x = F.relu(F.dropout(self.conv1(x), p=dropout_ratio)) 50 | x = F.max_pool2d(x, kernel_size=2, stride=2) 51 | x = F.relu(F.dropout(self.conv2(x), p=dropout_ratio)) 52 | x = F.max_pool2d(x, kernel_size=2, stride=2) 53 | x = F.relu(F.dropout(self.conv3(x), p=dropout_ratio)) 54 | x = F.relu(F.dropout(self.conv4(x), p=dropout_ratio)) 55 | x = F.relu(F.dropout(self.conv5(x), p=dropout_ratio)) 56 | x = F.max_pool2d(x, kernel_size=2, stride=2) 57 | x = x.view(x.size(0), -1) 58 | x = self.fc(x) 59 | return x 
60 | 61 | def prediction(self, x): 62 | 63 | acc_prob = TensorAccumulator() 64 | m = self.val_mc 65 | 66 | for _ in range(m): 67 | output = self.forward(x) 68 | prob = F.softmax(output, dim=1) 69 | acc_prob.update(prob, scale=1/m) 70 | 71 | prob = acc_prob.get() 72 | 73 | return prob 74 | 75 | 76 | def alexnet(**kwargs): 77 | r"""AlexNet model architecture from the 78 | `"One weird trick..." `_ paper. 79 | """ 80 | model = AlexNet(**kwargs) 81 | return model 82 | 83 | 84 | def alexnet_mcdropout(**kwargs): 85 | model = AlexNetMCDropout(**kwargs) 86 | return model 87 | 88 | -------------------------------------------------------------------------------- /classification/models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchsso.utils.accumulator import TensorAccumulator 4 | 5 | 6 | class LeNet5(nn.Module): 7 | 8 | def __init__(self, num_classes=10): 9 | super().__init__() 10 | self.conv1 = nn.Conv2d(3, 6, 5) 11 | self.conv2 = nn.Conv2d(6, 16, 5) 12 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 13 | self.fc2 = nn.Linear(120, 84) 14 | self.fc3 = nn.Linear(84, num_classes) 15 | 16 | def forward(self, x): 17 | out = F.relu(self.conv1(x)) 18 | out = F.max_pool2d(out, 2) 19 | out = F.relu(self.conv2(out)) 20 | out = F.max_pool2d(out, 2) 21 | out = out.view(out.size(0), -1) 22 | out = F.relu(self.fc1(out)) 23 | out = F.relu(self.fc2(out)) 24 | out = self.fc3(out) 25 | return out 26 | 27 | 28 | class LeNet5MCDropout(LeNet5): 29 | 30 | def __init__(self, num_classes=10, dropout_ratio=0.1, val_mc=10): 31 | super(LeNet5MCDropout, self).__init__(num_classes=num_classes) 32 | self.dropout_ratio = dropout_ratio 33 | self.val_mc = val_mc 34 | 35 | def forward(self, x): 36 | p = self.dropout_ratio 37 | out = F.relu(F.dropout(self.conv1(x), p)) 38 | out = F.max_pool2d(out, 2) 39 | out = F.relu(F.dropout(self.conv2(out), p)) 40 | out = F.max_pool2d(out, 2) 41 | out = 
out.view(out.size(0), -1) 42 | out = F.relu(F.dropout(self.fc1(out), p)) 43 | out = F.relu(F.dropout(self.fc2(out), p)) 44 | out = F.dropout(self.fc2(out), p) 45 | return out 46 | 47 | def mc_prediction(self, x): 48 | 49 | acc_prob = TensorAccumulator() 50 | m = self.val_mc 51 | 52 | for _ in range(m): 53 | output = self.forward(x) 54 | prob = F.softmax(output, dim=1) 55 | acc_prob.update(prob, scale=1/m) 56 | 57 | prob = acc_prob.get() 58 | 59 | return prob 60 | 61 | 62 | class LeNet5BatchNorm(nn.Module): 63 | def __init__(self, num_classes=10, affine=False): 64 | super().__init__() 65 | self.conv1 = nn.Conv2d(3, 6, 5) 66 | self.bn1 = nn.BatchNorm2d(6, affine=affine) 67 | self.conv2 = nn.Conv2d(6, 16, 5) 68 | self.bn2 = nn.BatchNorm2d(16, affine=affine) 69 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 70 | self.bn3 = nn.BatchNorm1d(120, affine=affine) 71 | self.fc2 = nn.Linear(120, 84) 72 | self.bn4 = nn.BatchNorm1d(84, affine=affine) 73 | self.fc3 = nn.Linear(84, num_classes) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = F.max_pool2d(out, 2) 78 | out = F.relu(self.bn2(self.conv2(out))) 79 | out = F.max_pool2d(out, 2) 80 | out = out.view(out.size(0), -1) 81 | out = F.relu(self.bn3(self.fc1(out))) 82 | out = F.relu(self.bn4(self.fc2(out))) 83 | out = self.fc3(out) 84 | return out 85 | -------------------------------------------------------------------------------- /classification/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | __all__ = ['mlp'] 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, num_classes=10): 11 | super().__init__() 12 | n_hid = 1000 13 | n_out = 10 14 | self.l1 = nn.Linear(28*28, n_hid) 15 | self.l2 = nn.Linear(n_hid, n_hid) 16 | self.l3 = nn.Linear(n_hid, n_out) 17 | 18 | def forward(self, x: torch.Tensor): 19 | x = x.view([-1, 28*28]) 20 | x = F.relu(self.l1(x)) 21 | 
x = F.relu(self.l2(x)) 22 | x = self.l3(x) 23 | return x 24 | 25 | 26 | def mlp(**kwargs): 27 | model = MLP(**kwargs) 28 | return model 29 | 30 | -------------------------------------------------------------------------------- /classification/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, 
bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(ResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | 100 | def 
ResNet18(num_classes=10): 101 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 102 | 103 | def ResNet34(num_classes=10): 104 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 105 | 106 | def ResNet50(num_classes=10): 107 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 108 | 109 | def ResNet101(num_classes=10): 110 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 111 | 112 | def ResNet152(num_classes=10): 113 | return ResNet(Bottleneck, [3,8,36,3], num_classes) 114 | 115 | 116 | def test(): 117 | net = ResNet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /classification/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name='VGG19'): 16 | super(VGG, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.classifier = nn.Linear(512, 10) 19 | 20 | def forward(self, x): 21 | out = self.features(x) 22 | out = out.view(out.size(0), -1) 23 | out = self.classifier(out) 24 | return out 25 | 26 | def _make_layers(self, cfg): 27 | layers = [] 28 | in_channels = 3 29 | for x in cfg: 30 | if x == 'M': 31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 32 | else: 33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(x), 35 | nn.ReLU(inplace=True)] 36 | in_channels = x 37 | layers += 
[nn.AvgPool2d(kernel_size=1, stride=1)] 38 | return nn.Sequential(*layers) 39 | 40 | 41 | def test(): 42 | net = VGG('VGG11') 43 | x = torch.randn(2,3,32,32) 44 | y = net(x) 45 | print(y.size()) 46 | 47 | # test() 48 | -------------------------------------------------------------------------------- /distributed/README.md: -------------------------------------------------------------------------------- 1 | # Distributed training 2 | PyTorch-SSO supports data parallelism and MC samples parallelism (for VI) for distributed training among multiple processes (GPUs). 3 | 4 | ![](../docs/distributed_vi.png) 5 | 6 | -------------------------------------------------------------------------------- /distributed/classification/README.md: -------------------------------------------------------------------------------- 1 | To run training on CIFAR-10/100 with multiple GPUs 2 | 3 | ```bash 4 | mpirun -np python main.py \ 5 | --dist_init_method \ 6 | --download \ 7 | --config 8 | ``` 9 | 10 | To run training on ImageNet with multiple GPUs 11 | 12 | ```bash 13 | mpirun -np python main.py \ 14 | --train_root \ 15 | --val_root \ 16 | --dist_init_method \ 17 | --config 18 | ``` 19 | For `init_method`, refer the [PyTorch tutorial for distirubted applications](https://pytorch.org/tutorials/intermediate/dist_tuto.html). 
20 | 21 | | optimizer | dataset | architecture | GPUs | config file path | 22 | | --- | --- | --- | --- | --- | 23 | | [Adam](https://arxiv.org/abs/1412.6980) | ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_adam_bs4k_128gpu.json](./configs/imagenet/resnet18_adam_bs4k_128gpu.json) | 24 | | [K-FAC](https://arxiv.org/abs/1503.05671) | ImageNet | ResNet-18 | 4 | [configs/imagenet/resnet18_kfac_bs4k_4gpu.json](./configs/imagenet/resnet18_kfac_bs4k_4gpu.json) | 25 | | [K-FAC](https://arxiv.org/abs/1503.05671)| ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_kfac_bs4k_128gpu.json](./configs/imagenet/resnet18_kfac_bs4k_128gpu.json) | 26 | | [Noisy K-FAC](https://arxiv.org/abs/1712.02390)| ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json](./configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json) | 27 | | [VOGN](https://arxiv.org/abs/1806.04854)| ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_vogn_bs4k_128gpu.json](./configs/imagenet/resnet18_vogn_bs4k_128gpu.json) | 28 | 29 | - NOTE: 30 | - You need to run with `N` GPUs when you use a `*{N}gpu.json` config file. 31 | - You need to set `--acc_steps` (or `"acc_steps"` in the json config) to run with a limited number of GPUs, as below: 32 | - Mini-batch size (bs) = {examples per GPU} x {# GPUs} x {acc_steps} 33 | - e.g., 4096 (4k) = 32 x 8 x 16 34 | - The gradients of the loss and the curvature are accumulated over `acc_steps` steps to emulate the larger (pseudo) mini-batch size. 35 | 36 | Visit [configs](./configs) for other architectures, datasets, optimizers, and numbers of GPUs.
37 | -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/alexnet_adam_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 0.001, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 1e-4 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/alexnet_kfac_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 61, 4 | "batch_size": 32, 5 | "val_batch_size": 1250, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "DistributedSecondOrderOptimizer", 12 | "optim_args": { 13 | "curv_type": "Fisher", 14 | "curv_shapes": { 15 | "Conv2d": "Kron", 16 | "Linear": "Kron", 17 | "BatchNorm1d": "Diag", 18 | "BatchNorm2d": "Diag" 19 | }, 20 | "lr": 1e-2, 21 | "l2_reg": 1e-3, 22 | "momentum": 0.9, 23 | "momentum_type": "raw" 24 | }, 25 | "curv_args": { 26 | "damping": 1e-3, 27 | "ema_decay": 0.999, 28 | "pi_type": "tracenorm" 29 | }, 30 | "fisher_args": { 31 | "approx_type": "mc", 32 | "num_mc": 1 33 | }, 34 | "scheduler_name": "MultiStepLR", 35 | "scheduler_args": { 36 | "milestones": [30, 50], 37 | "gamma": 0.1 38 | } 39 | } -------------------------------------------------------------------------------- 
/distributed/classification/configs/cifar10/alexnet_noisykfac_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 81, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "DistributedVIOptimizer", 12 | "optim_args": { 13 | "curv_type": "Fisher", 14 | "curv_shapes": { 15 | "Conv2d": "Kron", 16 | "Linear": "Kron", 17 | "BatchNorm1d": "Diag", 18 | "BatchNorm2d": "Diag" 19 | }, 20 | "lr": 1e-3, 21 | "momentum": 0.9, 22 | "momentum_type": "raw", 23 | "num_mc_samples": 3, 24 | "val_num_mc_samples": 10, 25 | "kl_weighting": 1, 26 | "warmup_kl_weighting_init": 0.01, 27 | "warmup_kl_weighting_steps": 15821, 28 | "prior_variance": 2 29 | }, 30 | "curv_args": { 31 | "damping": 1e-3, 32 | "ema_decay": 0.999, 33 | "pi_type": "tracenorm" 34 | }, 35 | "fisher_args": { 36 | "approx_type": "mc", 37 | "num_mc": 1 38 | }, 39 | "scheduler_name": "MultiStepLR", 40 | "scheduler_args": { 41 | "milestones": [30, 50], 42 | "gamma": 0.1 43 | }, 44 | "num_mc_groups": 8 45 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/alexnet_sgd_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "SGD", 12 | "optim_args": { 13 | "lr": 0.1, 14 | "momentum": 0.9, 15 | "weight_decay": 1e-4 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } 
-------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/alexnet_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/alexnet.py", 11 | "arch_name": "AlexNet", 12 | "optim_name": "DistributedVIOptimizer", 13 | "optim_args": { 14 | "curv_type": "Cov", 15 | "curv_shapes": { 16 | "Conv2d": "Diag", 17 | "Linear": "Kron", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "lr": 1e-4, 22 | "momentum": 0.9, 23 | "momentum_type": "raw", 24 | "num_mc_samples": 3, 25 | "val_num_mc_samples": 10, 26 | "kl_weighting": 1, 27 | "warmup_kl_weighting_init": 0.5, 28 | "warmup_kl_weighting_steps": 1954, 29 | "prior_variance": 2 30 | }, 31 | "curv_args": { 32 | "damping": 1e-3, 33 | "ema_decay": 0.999 34 | }, 35 | "scheduler_name": "MultiStepLR", 36 | "scheduler_args": { 37 | "milestones": [80, 120], 38 | "gamma": 0.1 39 | }, 40 | "num_mc_groups": 8 41 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/lenet_adam_bs128_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 150, 4 | "batch_size": 32, 5 | "val_batch_size": 1250, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 1e-2 16 | } 17 | } -------------------------------------------------------------------------------- 
/distributed/classification/configs/cifar10/lenet_kfac_bs128_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 50, 4 | "batch_size": 32, 5 | "val_batch_size": 1250, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "DistributedSecondOrderOptimizer", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "curv_type": "Fisher", 15 | "curv_shapes": { 16 | "Conv2d": "Kron", 17 | "Linear": "Kron", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "momentum": 0.9, 22 | "momentum_type": "raw", 23 | "l2_reg": 1e-3, 24 | "acc_steps": 1 25 | }, 26 | "curv_args": { 27 | "damping": 1e-3, 28 | "ema_decay": 0.999, 29 | "pi_type": "tracenorm" 30 | }, 31 | "fisher_args": { 32 | "approx_type": "mc", 33 | "num_mc": 1 34 | }, 35 | "scheduler_name": "ExponentialLR", 36 | "scheduler_args": { 37 | "gamma": 0.9 38 | }, 39 | "log_interval": 64, 40 | "no_cuda": false 41 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/lenet_sgd_bs128_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 150, 4 | "batch_size": 32, 5 | "val_batch_size": 1250, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "SGD", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "momentum": 0.9, 15 | "weight_decay": 1e-2 16 | } 17 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/lenet_vogn_bs128_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 211, 4 | 
"batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 1, 10 | "arch_file": "models/lenet.py", 11 | "arch_name": "LeNet5", 12 | "optim_name": "DistributedVIOptimizer", 13 | "optim_args": { 14 | "curv_type": "Cov", 15 | "curv_shapes": { 16 | "Conv2d": "Diag", 17 | "Linear": "Diag", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "lr": 1e-4, 22 | "momentum": 0.9, 23 | "momentum_type": "raw", 24 | "num_mc_samples": 6, 25 | "val_num_mc_samples": 10, 26 | "kl_weighting": 1, 27 | "warmup_kl_weighting_init": 0.1, 28 | "warmup_kl_weighting_steps": 11719, 29 | "prior_variance": 1e-2 30 | }, 31 | "curv_args": { 32 | "damping": 1e-3, 33 | "ema_decay": 0.999 34 | }, 35 | "num_mc_groups": 4 36 | } 37 | -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/resnet18_adam_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/resnet_b.py", 10 | "arch_name": "resnet18", 11 | "arch_args": { 12 | "zero_init_residual": true, 13 | "norm_stat_momentum": 0.1 14 | }, 15 | "optim_name": "Adam", 16 | "optim_args": { 17 | "lr": 1e-3, 18 | "betas": [0.9, 0.999], 19 | "weight_decay": 5e-4 20 | }, 21 | "non_wd_for_bn": true, 22 | "scheduler_name": "MultiStepLR", 23 | "scheduler_args": { 24 | "milestones": [80, 120], 25 | "gamma": 0.1 26 | } 27 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/resnet18_sgd_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | 
"batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/resnet_b.py", 10 | "arch_name": "resnet18", 11 | "arch_args": { 12 | "zero_init_residual": true, 13 | "norm_stat_momentum": 0.1 14 | }, 15 | "optim_name": "SGD", 16 | "optim_args": { 17 | "lr": 0.1, 18 | "momentum": 0.9, 19 | "weight_decay": 1e-4 20 | }, 21 | "momentum_correction": true, 22 | "non_wd_for_bn": true, 23 | "scheduler_name": "MultiStepLR", 24 | "scheduler_args": { 25 | "milestones": [80, 120], 26 | "gamma": 0.1 27 | } 28 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/resnet18_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedVIOptimizer", 17 | "optim_args": { 18 | "curv_type": "Cov", 19 | "curv_shapes": { 20 | "Conv2d": "Diag", 21 | "Linear": "Diag", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1e-4, 26 | "momentum": 0.9, 27 | "momentum_type": "raw", 28 | "num_mc_samples": 5, 29 | "val_num_mc_samples": 20, 30 | "kl_weighting": 1, 31 | "prior_variance": 0.02, 32 | "non_reg_for_bn": true 33 | }, 34 | "curv_args": { 35 | "damping": 1e-3, 36 | "ema_decay": 0.999 37 | }, 38 | "momentum_correction": true, 39 | "scheduler_name": "MultiStepLR", 40 | "scheduler_args": { 41 | "milestones": [80, 120], 42 | "gamma": 0.1 43 | }, 44 | "num_mc_groups": 8 45 | } -------------------------------------------------------------------------------- 
/distributed/classification/configs/cifar10/vgg19_adam_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/vgg.py", 10 | "arch_name": "VGG19", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 1e-4 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar10/vgg19_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/vgg.py", 11 | "arch_name": "VGG19", 12 | "optim_name": "DistributedVIOptimizer", 13 | "optim_args": { 14 | "curv_type": "Cov", 15 | "curv_shapes": { 16 | "Conv2d": "Diag", 17 | "Linear": "Diag", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "lr": 1e-4, 22 | "momentum": 0.9, 23 | "momentum_type": "raw", 24 | "num_mc_samples": 5, 25 | "val_num_mc_samples": 10, 26 | "kl_weighting": 1, 27 | "prior_variance": 2 28 | }, 29 | "curv_args": { 30 | "damping": 1e-3, 31 | "ema_decay": 0.999 32 | }, 33 | "scheduler_name": "MultiStepLR", 34 | "scheduler_args": { 35 | "milestones": [80, 120], 36 | "gamma": 0.1 37 | }, 38 | "num_mc_groups": 8 39 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar100/alexnet_adam_bs256_8gpu.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 0.001, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 1e-2 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar100/alexnet_sgd_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "SGD", 12 | "optim_args": { 13 | "lr": 0.1, 14 | "momentum": 0.9, 15 | "weight_decay": 1e-4 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar100/alexnet_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/alexnet.py", 11 | "arch_name": "AlexNet", 12 | "optim_name": "DistributedVIOptimizer", 13 | "optim_args": { 14 | "curv_type": "Cov", 15 | "curv_shapes": { 16 | "Conv2d": "Diag", 
17 | "Linear": "Diag", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "lr": 1e-4, 22 | "momentum": 0.9, 23 | "momentum_type": "raw", 24 | "num_mc_samples": 10, 25 | "val_num_mc_samples": 100, 26 | "kl_weighting": 1, 27 | "warmup_kl_weighting_init": 0.5, 28 | "warmup_kl_weighting_steps": 1954, 29 | "prior_variance": 0.02 30 | }, 31 | "curv_args": { 32 | "damping": 1e-3, 33 | "ema_decay": 0.999 34 | }, 35 | "scheduler_name": "MultiStepLR", 36 | "scheduler_args": { 37 | "milestones": [80, 120], 38 | "gamma": 0.1 39 | }, 40 | "num_mc_groups": 8 41 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar100/resnet18_adam_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/resnet_b.py", 10 | "arch_name": "resnet18", 11 | "arch_args": { 12 | "zero_init_residual": true, 13 | "norm_stat_momentum": 0.1 14 | }, 15 | "optim_name": "Adam", 16 | "optim_args": { 17 | "lr": 1e-3, 18 | "betas": [0.9, 0.999], 19 | "weight_decay": 1e-2 20 | }, 21 | "non_wd_for_bn": true, 22 | "scheduler_name": "MultiStepLR", 23 | "scheduler_args": { 24 | "milestones": [80, 120], 25 | "gamma": 0.1 26 | } 27 | } -------------------------------------------------------------------------------- /distributed/classification/configs/cifar100/resnet18_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 
13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedVIOptimizer", 17 | "optim_args": { 18 | "curv_type": "Cov", 19 | "curv_shapes": { 20 | "Conv2d": "Diag", 21 | "Linear": "Diag", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1e-4, 26 | "momentum": 0.9, 27 | "momentum_type": "raw", 28 | "num_mc_samples": 5, 29 | "val_num_mc_samples": 20, 30 | "kl_weighting": 1, 31 | "warmup_kl_weighting_init": 0.5, 32 | "warmup_kl_weighting_steps": 1954, 33 | "prior_variance": 0.02, 34 | "non_reg_for_bn": true 35 | }, 36 | "curv_args": { 37 | "damping": 1e-3, 38 | "ema_decay": 0.999 39 | }, 40 | "momentum_correction": true, 41 | "scheduler_name": "MultiStepLR", 42 | "scheduler_args": { 43 | "milestones": [80, 120], 44 | "gamma": 0.1 45 | }, 46 | "num_mc_groups": 8 47 | } -------------------------------------------------------------------------------- /distributed/classification/configs/imagenet/resnet18_adam_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 91, 4 | "batch_size": 32, 5 | "val_batch_size": 391, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": true, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "Adam", 17 | "optim_args": { 18 | "lr": 1.6e-3, 19 | "weight_decay": 1e-4, 20 | "betas": [0.9, 0.999] 21 | }, 22 | "non_wd_for_bn": true, 23 | "scheduler_name": "MultiStepLR", 24 | "scheduler_args": { 25 | "milestones": [30, 60, 80], 26 | "gamma": 0.1 27 | }, 28 | "warmup_epochs": 5, 29 | "warmup_scheduler_name": "GradualWarmupIterLR", 30 | "warmup_scheduler_args": { 31 | "initial_lr": 1.25e-5, 32 | "max_count": 1565 33 | } 34 | } 35 | 
-------------------------------------------------------------------------------- /distributed/classification/configs/imagenet/resnet18_kfac_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 61, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedSecondOrderOptimizer", 17 | "optim_args": { 18 | "curv_type": "Fisher", 19 | "curv_shapes": { 20 | "Conv2d": "Kron", 21 | "Linear": "Kron", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1.6e-3, 26 | "l2_reg": 1e-4, 27 | "momentum": 0.9, 28 | "momentum_type": "raw", 29 | "non_reg_for_bn": true, 30 | "acc_steps": 1 31 | }, 32 | "curv_args": { 33 | "damping": 1e-4, 34 | "ema_decay": 1, 35 | "pi_type": "tracenorm" 36 | }, 37 | "fisher_args": { 38 | "approx_type": "mc", 39 | "num_mc": 1 40 | }, 41 | "momentum_correction": true, 42 | "scheduler_name": "MultiStepLR", 43 | "scheduler_args": { 44 | "milestones": [15, 30, 45], 45 | "gamma": 0.1 46 | }, 47 | "warmup_epochs": 5, 48 | "warmup_scheduler_name": "GradualWarmupIterLR", 49 | "warmup_scheduler_args": { 50 | "initial_lr": 1.25e-5, 51 | "max_count": 1565 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /distributed/classification/configs/imagenet/resnet18_kfac_bs4k_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet10", 3 | "epochs": 30, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": 
"models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedSecondOrderOptimizer", 17 | "optim_args": { 18 | "curv_type": "Fisher", 19 | "curv_shapes": { 20 | "Conv2d": "Kron", 21 | "Linear": "Kron", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1.6e-3, 26 | "momentum": 0.9, 27 | "momentum_type": "raw", 28 | "l2_reg": 1e-4, 29 | "non_reg_for_bn": true, 30 | "acc_steps": 32 31 | }, 32 | "curv_args": { 33 | "damping": 1e-4, 34 | "ema_decay": 1, 35 | "pi_type": "tracenorm" 36 | }, 37 | "fisher_args": { 38 | "approx_type": "mc", 39 | "num_mc": 1 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 61, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "dataset_size_scale": 5, 10 | "normalizing_data": true, 11 | "arch_file": "models/resnet_b.py", 12 | "arch_name": "resnet18", 13 | "arch_args": { 14 | "zero_init_residual": false, 15 | "norm_stat_momentum": 0.1 16 | }, 17 | "optim_name": "DistributedVIOptimizer", 18 | "optim_args": { 19 | "curv_type": "Fisher", 20 | "curv_shapes": { 21 | "Conv2d": "Kron", 22 | "Linear": "Kron", 23 | "BatchNorm1d": "Diag", 24 | "BatchNorm2d": "Diag" 25 | }, 26 | "lr": 1.6e-3, 27 | "momentum": 0.9, 28 | "momentum_type": "raw", 29 | "num_mc_samples": 1, 30 | "val_num_mc_samples": 10, 31 | "kl_weighting": 1, 32 | "prior_variance": 7.5e-3, 33 | "non_reg_for_bn": true, 34 | "acc_steps": 1 35 | }, 36 | "curv_args": { 37 | "damping": 1e-4, 38 | "ema_decay": 0.9, 39 | "pi_type": "tracenorm" 40 | }, 41 | "fisher_args": { 42 | "approx_type": "mc", 43 | "num_mc": 
1 44 | }, 45 | "momentum_correction": true, 46 | "scheduler_name": "MultiStepLR", 47 | "scheduler_args": { 48 | "milestones": [15, 30, 45], 49 | "gamma": 0.1 50 | }, 51 | "warmup_epochs": 5, 52 | "warmup_scheduler_name": "GradualWarmupIterLR", 53 | "warmup_scheduler_args": { 54 | "initial_lr": 1.25e-5, 55 | "max_count": 1565 56 | }, 57 | "num_mc_groups": 128 58 | } 59 | -------------------------------------------------------------------------------- /distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet10", 3 | "epochs": 30, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedVIOptimizer", 17 | "optim_args": { 18 | "curv_type": "Fisher", 19 | "curv_shapes": { 20 | "Conv2d": "Kron", 21 | "Linear": "Kron", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1.6e-3, 26 | "momentum": 0.9, 27 | "momentum_type": "raw", 28 | "num_mc_samples": 1, 29 | "val_num_mc_samples": 0, 30 | "kl_weighting": 1, 31 | "prior_variance": 0.75, 32 | "non_reg_for_bn": true, 33 | "acc_steps": 32 34 | }, 35 | "curv_args": { 36 | "damping": 1e-4, 37 | "ema_decay": 1, 38 | "pi_type": "tracenorm" 39 | }, 40 | "fisher_args": { 41 | "approx_type": "mc", 42 | "num_mc": 1 43 | }, 44 | "num_mc_groups": 1 45 | } 46 | -------------------------------------------------------------------------------- /distributed/classification/configs/imagenet/resnet18_sgd_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 91, 4 | "batch_size": 32, 5 | 
"val_batch_size": 391, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": true, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "SGD", 17 | "optim_args": { 18 | "lr": 1.6e-1, 19 | "momentum": 0.9, 20 | "weight_decay": 1e-4 21 | }, 22 | "momentum_correction": true, 23 | "non_wd_for_bn": true, 24 | "scheduler_name": "MultiStepLR", 25 | "scheduler_args": { 26 | "milestones": [30, 60, 80], 27 | "gamma": 0.1 28 | }, 29 | "warmup_epochs": 5, 30 | "warmup_scheduler_name": "GradualWarmupIterLR", 31 | "warmup_scheduler_args": { 32 | "initial_lr": 1.25e-3, 33 | "max_count": 1565 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /distributed/classification/configs/imagenet/resnet18_vogn_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 91, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": false, 8 | "random_horizontal_flip": false, 9 | "dataset_size_scale": 5, 10 | "normalizing_data": true, 11 | "arch_file": "models/resnet_b.py", 12 | "arch_name": "resnet18", 13 | "arch_args": { 14 | "zero_init_residual": false, 15 | "norm_stat_momentum": 0.1 16 | }, 17 | "optim_name": "DistributedVIOptimizer", 18 | "optim_args": { 19 | "curv_type": "Cov", 20 | "curv_shapes": { 21 | "Conv2d": "Diag", 22 | "Linear": "Diag", 23 | "BatchNorm1d": "Diag", 24 | "BatchNorm2d": "Diag" 25 | }, 26 | "lr": 1.6e-3, 27 | "momentum": 0.9, 28 | "momentum_type": "raw", 29 | "num_mc_samples": 1, 30 | "val_num_mc_samples": 10, 31 | "kl_weighting": 1, 32 | "prior_variance": 7.5e-3, 33 | "non_reg_for_bn": true 34 | }, 35 | "curv_args": { 36 | "damping": 1e-4, 37 | "ema_decay": 0.9 38 | }, 39 | "momentum_correction": true, 
40 | "scheduler_name": "MultiStepLR", 41 | "scheduler_args": { 42 | "milestones": [30, 60, 80], 43 | "gamma": 0.1 44 | }, 45 | "warmup_epochs": 5, 46 | "warmup_scheduler_name": "GradualWarmupIterLR", 47 | "warmup_scheduler_args": { 48 | "initial_lr": 1.25e-5, 49 | "max_count": 1565 50 | }, 51 | "num_mc_groups": 128 52 | } 53 | -------------------------------------------------------------------------------- /distributed/classification/configs/imagenet/resnet18_vogn_bs4k_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 90, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": true, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedVIOptimizer", 17 | "optim_args": { 18 | "curv_type": "Cov", 19 | "curv_shapes": { 20 | "Conv2d": "Diag", 21 | "Linear": "Diag", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1.6e-2, 26 | "grad_ema_decay": 0.1, 27 | "grad_ema_type": "raw", 28 | "bias_correction": true, 29 | "non_reg_for_bn": true, 30 | "num_mc_samples": 1, 31 | "val_num_mc_samples": 10, 32 | "kl_weighting": 1, 33 | "prior_variance": 7.5e-4, 34 | "weight_decay": 0, 35 | "lars": false 36 | }, 37 | "curv_args": { 38 | "damping": 1e-4, 39 | "ema_decay": 1e-3 40 | }, 41 | "momentum_correction": true, 42 | "scheduler_name": "MultiStepLR", 43 | "scheduler_args": { 44 | "milestones": [30, 60, 80], 45 | "gamma": 0.1 46 | }, 47 | "warmup_epochs": 5, 48 | "warmup_scheduler_name": "GradualWarmupIterLR", 49 | "warmup_scheduler_args": { 50 | "initial_lr": 1.25e-4, 51 | "max_count": 1565 52 | }, 53 | "num_mc_groups": 4 54 | } 55 | -------------------------------------------------------------------------------- 
/distributed/classification/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from importlib import import_module 4 | import shutil 5 | import json 6 | import math 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torchvision import datasets, transforms, models 12 | import torchsso 13 | from torchsso.optim import DistributedFirstOrderOptimizer, DistributedSecondOrderOptimizer, DistributedVIOptimizer 14 | from torchsso.optim.lr_scheduler import MomentumCorrectionLR 15 | from torchsso.utils import Logger 16 | 17 | from mpi4py import MPI 18 | import torch.distributed as dist 19 | 20 | DATASET_CIFAR10 = 'CIFAR-10' 21 | DATASET_CIFAR100 = 'CIFAR-100' 22 | DATASET_IMAGENET = 'ImageNet' 23 | 24 | 25 | def main(): 26 | parser = argparse.ArgumentParser() 27 | # Data 28 | parser.add_argument('--dataset', type=str, 29 | choices=[DATASET_CIFAR10, DATASET_CIFAR100, DATASET_IMAGENET], 30 | default=DATASET_CIFAR10, 31 | help='name of dataset') 32 | parser.add_argument('--root', type=str, default='./data', 33 | help='root of dataset') 34 | parser.add_argument('--train_root', type=str, default=None, 35 | help='root of train dataset') 36 | parser.add_argument('--val_root', type=str, default=None, 37 | help='root of validate dataset') 38 | parser.add_argument('--epochs', type=int, default=10, 39 | help='number of epochs to train') 40 | parser.add_argument('--batch_size', type=int, default=128, 41 | help='input batch size for training') 42 | parser.add_argument('--val_batch_size', type=int, default=128, 43 | help='input batch size for valing') 44 | parser.add_argument('--normalizing_data', action='store_true', 45 | help='[data pre processing] normalizing data') 46 | parser.add_argument('--random_crop', action='store_true', 47 | help='[data augmentation] random crop') 48 | parser.add_argument('--random_resized_crop', action='store_true', 49 | help='[data 
augmentation] random resised crop') 50 | parser.add_argument('--random_horizontal_flip', action='store_true', 51 | help='[data augmentation] random horizontal flip') 52 | parser.add_argument('--dataset_size_scale', type=float, default=1., 53 | help='ratio multiplied to the actual dataset size') 54 | # Training Settings 55 | parser.add_argument('--arch_file', type=str, default=None, 56 | help='name of file which defines the architecture') 57 | parser.add_argument('--arch_name', type=str, default='LeNet5', 58 | help='name of the architecture') 59 | parser.add_argument('--arch_args', type=json.loads, default=None, 60 | help='[JSON] arguments for the architecture') 61 | parser.add_argument('--optim_name', type=str, default=DistributedSecondOrderOptimizer.__name__, 62 | help='name of the optimizer') 63 | parser.add_argument('--optim_args', type=json.loads, default=None, 64 | help='[JSON] arguments for the optimizer') 65 | parser.add_argument('--curv_args', type=json.loads, default=dict(), 66 | help='[JSON] arguments for the curvature') 67 | parser.add_argument('--fisher_args', type=json.loads, default=dict(), 68 | help='[JSON] arguments for the fisher') 69 | parser.add_argument('--scheduler_name', type=str, default=None, 70 | help='name of the learning rate scheduler') 71 | parser.add_argument('--scheduler_args', type=json.loads, default=None, 72 | help='[JSON] arguments for the scheduler') 73 | parser.add_argument('--warmup_epochs', type=int, default=0, 74 | help='number of epochs for warmup lr') 75 | parser.add_argument('--warmup_scheduler_name', type=str, default=None, 76 | help='name of the learning rate scheduler') 77 | parser.add_argument('--warmup_scheduler_args', type=json.loads, default=None, 78 | help='[JSON] arguments for the wamup scheduler') 79 | parser.add_argument('--momentum_correction', action='store_true', 80 | help='if True, momentum/LR ratio is kept to be constant') 81 | parser.add_argument('--non_wd_for_bn', action='store_true', 82 | 
help='(FirstOrderOptimizer only) if True, weight decay is not applied for BatchNorm') 83 | parser.add_argument('--lars', action='store_true', 84 | help='if True, LARS is applied for first-order optimizer') 85 | # Options 86 | parser.add_argument('--download', action='store_true', default=False, 87 | help='if True, downloads the dataset (CIFAR-10 or 100) from the internet') 88 | parser.add_argument('--seed', type=int, default=1, 89 | help='random seed') 90 | parser.add_argument('--num_workers', type=int, default=0, 91 | help='number of sub processes for data loading') 92 | parser.add_argument('--log_interval', type=int, default=50, 93 | help='how many batches to wait before logging training status') 94 | parser.add_argument('--log_file_name', type=str, default='log', 95 | help='log file name') 96 | parser.add_argument('--checkpoint_interval', type=int, default=50, 97 | help='how many epochs to wait before logging training status') 98 | parser.add_argument('--resume', type=str, default=None, 99 | help='checkpoint path for resume training') 100 | parser.add_argument('--out', type=str, default='result', 101 | help='dir to save output files') 102 | parser.add_argument('--config', default=None, 103 | help='config file path') 104 | # [COMM] 105 | parser.add_argument('--dist_init_method', type=str, 106 | help='torch.distributed init_method') 107 | parser.add_argument('--size_data_group', type=int, default=1, 108 | help='size of the process groups in which input data are shared') 109 | parser.add_argument('--num_mc_groups', type=int, default=1, 110 | help='number of the process groups in which mc sampled params are shared') 111 | 112 | args = parser.parse_args() 113 | dict_args = vars(args) 114 | 115 | # Load config file 116 | if args.config is not None: 117 | with open(args.config) as f: 118 | config = json.load(f) 119 | dict_args.update(config) 120 | 121 | # Set random seed 122 | torch.manual_seed(args.seed) 123 | 124 | # [COMM] Initialize process group 125 | comm = 
MPI.COMM_WORLD 126 | size = comm.Get_size() 127 | ranks = list(range(size)) 128 | rank = comm.Get_rank() 129 | n_per_node = torch.cuda.device_count() 130 | device = rank % n_per_node 131 | torch.cuda.set_device(device) 132 | init_method = 'tcp://{}:23456'.format(args.dist_init_method) 133 | dist.init_process_group('nccl', init_method=init_method, world_size=size, rank=rank) 134 | 135 | # [COMM] Setup process group for MC sample parallel 136 | size_data_group = args.size_data_group 137 | assert size % size_data_group == 0 138 | num_mc_groups = args.num_mc_groups 139 | assert size % num_mc_groups == 0 140 | 141 | if size_data_group > 1: 142 | num_data_group = size / size_data_group 143 | data_group_id = rank % num_data_group 144 | data_group_ranks = ranks[data_group_id:size:num_data_group] 145 | data_group = dist.new_group(data_group_ranks) 146 | 147 | master_ranks = ranks[0:num_data_group] 148 | master_group = dist.new_group(master_ranks) 149 | else: 150 | num_data_group = size 151 | data_group_id = rank 152 | data_group = None 153 | master_group = dist.new_group(ranks) 154 | 155 | if num_mc_groups > 1: 156 | size_mc_group = int(size / num_mc_groups) 157 | mc_group_id = int(rank/size_mc_group) 158 | else: 159 | size_mc_group = size 160 | mc_group_id = 0 161 | 162 | # Setup data augmentation & data pre processing 163 | train_transforms, val_transforms = [], [] 164 | 165 | if args.dataset in [DATASET_CIFAR10, DATASET_CIFAR100]: 166 | # CIFAR-10/100 167 | if args.random_crop: 168 | train_transforms.append(transforms.RandomCrop(32, padding=4)) 169 | 170 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 171 | else: 172 | # ImageNet 173 | if args.random_resized_crop: 174 | train_transforms.append(transforms.RandomResizedCrop(224)) 175 | else: 176 | train_transforms.append(transforms.Resize(256)) 177 | if args.random_crop: 178 | train_transforms.append(transforms.RandomCrop(224)) 179 | else: 180 | 
train_transforms.append(transforms.CenterCrop(224)) 181 | 182 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 183 | 184 | val_transforms.append(transforms.Resize(256)) 185 | val_transforms.append(transforms.CenterCrop(224)) 186 | 187 | if args.random_horizontal_flip: 188 | train_transforms.append(transforms.RandomHorizontalFlip()) 189 | 190 | train_transforms.append(transforms.ToTensor()) 191 | val_transforms.append(transforms.ToTensor()) 192 | 193 | if args.normalizing_data: 194 | train_transforms.append(normalize) 195 | val_transforms.append(normalize) 196 | 197 | train_transform = transforms.Compose(train_transforms) 198 | val_transform = transforms.Compose(val_transforms) 199 | 200 | # Setup data loader 201 | if args.dataset == DATASET_IMAGENET: 202 | # ImageNet 203 | num_classes = 1000 204 | 205 | train_root = args.root if args.train_root is None else args.train_root 206 | val_root = args.root if args.val_root is None else args.val_root 207 | train_dataset = datasets.ImageFolder(root=train_root, transform=train_transform) 208 | val_dataset = datasets.ImageFolder(root=val_root, transform=val_transform) 209 | else: 210 | if args.dataset == DATASET_CIFAR10: 211 | # CIFAR-10 212 | num_classes = 10 213 | dataset_class = datasets.CIFAR10 214 | else: 215 | # CIFAR-100 216 | num_classes = 100 217 | dataset_class = datasets.CIFAR100 218 | 219 | train_dataset = dataset_class( 220 | root=args.root, train=True, download=args.download, transform=train_transform) 221 | val_dataset = dataset_class( 222 | root=args.root, train=False, download=args.download, transform=val_transform) 223 | 224 | # [COMM] Setup distributed sampler for data parallel & MC sample parallel 225 | train_sampler = torch.utils.data.distributed.DistributedSampler( 226 | train_dataset, num_replicas=num_data_group, rank=data_group_id) 227 | train_loader = torch.utils.data.DataLoader( 228 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 229 | 
pin_memory=True, sampler=train_sampler, num_workers=args.num_workers) 230 | 231 | # [COMM] Setup distributed sampler for data parallel 232 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 233 | val_loader = torch.utils.data.DataLoader( 234 | val_dataset, batch_size=args.val_batch_size, shuffle=False, 235 | sampler=val_sampler, num_workers=args.num_workers) 236 | 237 | # Setup model 238 | if args.arch_file is None: 239 | arch_class = getattr(models, args.arch_name) 240 | else: 241 | _, ext = os.path.splitext(args.arch_file) 242 | dirname = os.path.dirname(args.arch_file) 243 | 244 | if dirname == '': 245 | module_path = args.arch_file.replace(ext, '') 246 | elif dirname == '.': 247 | module_path = os.path.basename(args.arch_file).replace(ext, '') 248 | else: 249 | module_path = '.'.join(os.path.split(args.arch_file)).replace(ext, '') 250 | 251 | module = import_module(module_path) 252 | arch_class = getattr(module, args.arch_name) 253 | 254 | arch_kwargs = {} if args.arch_args is None else args.arch_args 255 | arch_kwargs['num_classes'] = num_classes 256 | 257 | model = arch_class(**arch_kwargs) 258 | setattr(model, 'num_classes', num_classes) 259 | model = model.to(device) 260 | 261 | # [COMM] Broadcast model parameters 262 | for param in list(model.parameters()): 263 | dist.broadcast(param.data, src=0) 264 | 265 | # Setup optimizer 266 | optim_kwargs = {} if args.optim_args is None else args.optim_args 267 | acc_steps = optim_kwargs.get('acc_steps', 1) 268 | global_batch_size = num_data_group * args.batch_size * acc_steps 269 | total_steps = math.ceil(args.epochs * len(train_loader.dataset) / global_batch_size) 270 | 271 | # Setup optimizer 272 | if args.optim_name == DistributedVIOptimizer.__name__: 273 | optimizer = DistributedVIOptimizer(model, 274 | mc_group_id=mc_group_id, 275 | dataset_size=len(train_loader.dataset) * args.dataset_size_scale, 276 | total_steps=total_steps, 277 | seed=args.seed, 278 | **optim_kwargs, 
curv_kwargs=args.curv_args) 279 | else: 280 | assert args.num_mc_groups == 1, 'You cannot use MC sample groups with non-VI optimizers.' 281 | if args.optim_name == DistributedSecondOrderOptimizer.__name__: 282 | optimizer = DistributedSecondOrderOptimizer(model, **optim_kwargs, curv_kwargs=args.curv_args) 283 | else: 284 | if args.non_wd_for_bn: 285 | group, group_non_wd = {'params': []}, {'params': [], 'non_wd': True} 286 | for m in model.children(): 287 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 288 | group_non_wd['params'].extend(m.parameters()) 289 | else: 290 | group['params'].extend(m.parameters()) 291 | 292 | params = [group, group_non_wd] 293 | else: 294 | params = model.parameters() 295 | 296 | optim_class = getattr(torch.optim, args.optim_name) 297 | optimizer = optim_class(params, **optim_kwargs) 298 | 299 | for group in optimizer.param_groups: 300 | if group.get('non_wd', False): 301 | group['weight_decay'] = 0 302 | 303 | optimizer = DistributedFirstOrderOptimizer(optimizer, model, dist, lars=args.lars) 304 | 305 | # Setup lr scheduler 306 | def get_scheduler(name, kwargs): 307 | scheduler_class = getattr(torchsso.optim.lr_scheduler, name, None) 308 | if scheduler_class is None: 309 | scheduler_class = getattr(torch.optim.lr_scheduler, name) 310 | scheduler_kwargs = {} if kwargs is None else kwargs 311 | _scheduler = scheduler_class(optimizer, **scheduler_kwargs) 312 | if args.momentum_correction: 313 | _scheduler = MomentumCorrectionLR(_scheduler) 314 | return _scheduler 315 | 316 | if args.scheduler_name is None: 317 | main_scheduler = None 318 | else: 319 | main_scheduler = get_scheduler(args.scheduler_name, args.scheduler_args) 320 | 321 | if args.warmup_scheduler_name is None: 322 | warmup_scheduler = main_scheduler 323 | else: 324 | warmup_scheduler = get_scheduler(args.warmup_scheduler_name, args.warmup_scheduler_args) 325 | 326 | logger = None 327 | start_epoch = 1 328 | 329 | # Load checkpoint 330 | if args.resume is 
not None: 331 | print('==> Resuming from checkpoint..') 332 | assert os.path.exists(args.resume), 'Error: no checkpoint file found' 333 | checkpoint = torch.load(args.resume) 334 | model.load_state_dict(checkpoint['model']) 335 | optimizer.load_state_dict(checkpoint['optimizer']) 336 | start_epoch = checkpoint['epoch'] 337 | 338 | if rank == 0: 339 | 340 | # All config 341 | print('===========================') 342 | print('dataset: {}'.format(vars(args)['dataset'])) 343 | print('train data size: {}'.format(len(train_loader.dataset))) 344 | print('val data size: {}'.format(len(val_loader.dataset))) 345 | 346 | print('MPI.COMM_WORLD size: {}'.format(size)) 347 | print('global mini-batch size: {}'.format(global_batch_size)) 348 | print('steps/epoch: {}'.format(math.ceil(len(train_loader.dataset) / global_batch_size))) 349 | 350 | num_mc_samples = optim_kwargs.get('num_mc_samples', None) 351 | if num_mc_samples is not None: 352 | print('global num MC samples: {}'.format(num_mc_groups * num_mc_samples)) 353 | print('MC sample group: {} processes/group x {} group'.format(size_mc_group, num_mc_groups)) 354 | print('data group: {} processes/group x {} group'.format(size_data_group, num_data_group)) 355 | 356 | if hasattr(optimizer, 'indices'): 357 | print('layer assignment: {}'.format(optimizer.indices)) 358 | 359 | print('---------------------------') 360 | 361 | for key, val in vars(args).items(): 362 | if key == 'dataset': 363 | continue 364 | else: 365 | print('{}: {}'.format(key, val)) 366 | print('===========================') 367 | 368 | # Copy this file & config to args.out 369 | if not os.path.isdir(args.out): 370 | os.makedirs(args.out) 371 | try: 372 | shutil.copy(os.path.realpath(__file__), args.out) 373 | except shutil.SameFileError: 374 | pass 375 | if args.config is not None: 376 | try: 377 | shutil.copy(args.config, args.out) 378 | except shutil.SameFileError: 379 | pass 380 | if args.arch_file is not None: 381 | try: 382 | shutil.copy(args.arch_file, 
args.out) 383 | except shutil.SameFileError: 384 | pass 385 | 386 | # Setup logger 387 | logger = Logger(args.out, args.log_file_name) 388 | logger.start() 389 | 390 | # Run training 391 | for epoch in range(start_epoch, args.epochs + 1): 392 | 393 | scheduler = main_scheduler if epoch > args.warmup_epochs else warmup_scheduler 394 | 395 | # train 396 | accuracy, loss = train(rank, epoch, model, device, train_loader, optimizer, scheduler, 397 | args, master_group, data_group_id, data_group, logger) 398 | # val 399 | val_accuracy, val_loss = validate(rank, model, val_loader, device, optimizer) 400 | 401 | if rank == 0: 402 | # write to log 403 | iteration = epoch * len(train_loader) 404 | elapsed_time = logger.elapsed_time 405 | log = {'epoch': epoch, 'iteration': iteration, 'elapsed_time': elapsed_time, 406 | 'accuracy': accuracy, 'loss': loss, 407 | 'val_accuracy': val_accuracy, 'val_loss': val_loss, 408 | 'lr': optimizer.param_groups[0]['lr'], 409 | 'momentum': optimizer.param_groups[0].get('momentum', 0), 410 | } 411 | logger.write(log) 412 | 413 | # save checkpoint 414 | if epoch % args.checkpoint_interval == 0 or epoch > args.epochs - 3: 415 | path = os.path.join(args.out, 'epoch{}.ckpt'.format(epoch)) 416 | data = { 417 | 'model': model.state_dict(), 418 | 'optimizer': optimizer.state_dict(), 419 | 'epoch': epoch 420 | } 421 | torch.save(data, path) 422 | 423 | 424 | def train(rank, epoch, model, device, train_loader, optimizer, scheduler, 425 | args, master_group, data_group_id=0, data_group=None, logger=None): 426 | 427 | def scheduler_type(_scheduler): 428 | if _scheduler is None: 429 | return 'none' 430 | return getattr(_scheduler, 'scheduler_type', 'epoch') 431 | 432 | if scheduler_type(scheduler) == 'epoch': 433 | scheduler.step(epoch - 1) 434 | 435 | model.train() 436 | 437 | total_correct = 0 438 | loss = None 439 | total_data_size = 0 440 | epoch_size = len(train_loader.dataset) 441 | num_iters_in_epoch = len(train_loader) 442 | base_num_iter = 
(epoch - 1) * num_iters_in_epoch 443 | 444 | for batch_idx, (data, target) in enumerate(train_loader): 445 | data, target = data.to(device), target.to(device) 446 | 447 | if scheduler_type(scheduler) == 'iter': 448 | scheduler.step() 449 | 450 | for name, param in model.named_parameters(): 451 | attr = 'p_pre_{}'.format(name) 452 | setattr(model, attr, param.detach().clone()) 453 | 454 | # update params 455 | def closure(): 456 | optimizer.zero_grad() 457 | output = model(data) 458 | loss = F.cross_entropy(output, target) 459 | loss.backward() 460 | 461 | return loss, output 462 | 463 | if isinstance(optimizer, DistributedSecondOrderOptimizer) \ 464 | and optimizer.curv_type == 'Fisher': 465 | closure = torchsso.get_closure_for_fisher(optimizer, model, data, target, **args.fisher_args) 466 | 467 | loss, output = optimizer.step(closure=closure) 468 | data_size = torch.tensor(len(data)).to(device) 469 | 470 | # [COMM] reduce across the all processes 471 | dist.reduce(loss, dst=0) 472 | 473 | # [COMM] reduce across the processes in a data group 474 | if data_group is not None: 475 | dist.reduce(output, dst=data_group_id, group=data_group) 476 | 477 | pred = output.argmax(dim=1, keepdim=True) 478 | correct = pred.eq(target.view_as(pred)).sum().data 479 | 480 | # [COMM] reduce across the processes in the master MC sample group 481 | if dist.get_world_size(master_group) > 1: 482 | dist.reduce(correct, dst=0, group=master_group) 483 | dist.reduce(data_size, dst=0, group=master_group) 484 | 485 | # refresh results 486 | if rank == 0: 487 | loss = loss.item() / dist.get_world_size() 488 | 489 | correct = correct.item() 490 | data_size = data_size.item() 491 | 492 | total_correct += correct 493 | 494 | iteration = base_num_iter + batch_idx + 1 495 | total_data_size += data_size 496 | 497 | is_log_timing = (epoch == 1 and batch_idx == 0) or \ 498 | (batch_idx + 1) % args.log_interval == 0 499 | 500 | # save log 501 | if logger is not None and is_log_timing: 502 | accuracy = 
100. * total_correct / total_data_size 503 | elapsed_time = logger.elapsed_time 504 | print('epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.6f}, ' 505 | 'accuracy: {:.0f}/{} ({:.2f}%), ' 506 | 'elapsed: {:.1f}s, iters/sec: {:.2f}'.format( 507 | epoch, total_data_size, epoch_size, 100. * (batch_idx + 1) / num_iters_in_epoch, 508 | loss, total_correct, total_data_size, accuracy, elapsed_time, iteration/elapsed_time)) 509 | 510 | lr = optimizer.param_groups[0]['lr'] 511 | m = optimizer.param_groups[0].get('momentum', 0) 512 | log = {'epoch': epoch, 'iteration': iteration, 'elapsed_time': elapsed_time, 513 | 'accuracy': accuracy, 'loss': loss, 'lr': lr, 'momentum': m} 514 | 515 | for name, param in model.named_parameters(): 516 | attr = 'p_pre_{}'.format(name) 517 | p_pre = getattr(model, attr) 518 | p_norm = param.norm().item() 519 | p_shape = list(param.size()) 520 | p_pre_norm = p_pre.norm().item() 521 | g_norm = param.grad.norm().item() 522 | upd_norm = param.sub(p_pre).norm().item() 523 | noise_scale = getattr(param, 'noise_scale', 0) 524 | 525 | p_log = {'p_shape': p_shape, 'p_norm': p_norm, 'p_pre_norm': p_pre_norm, 526 | 'g_norm': g_norm, 'upd_norm': upd_norm, 'noise_scale': noise_scale} 527 | log[name] = p_log 528 | 529 | logger.write(log) 530 | 531 | accuracy = 100. 
* total_correct / epoch_size 532 | 533 | return accuracy, loss 534 | 535 | 536 | def validate(rank, model, val_loader, device, optimizer): 537 | model.eval() 538 | val_loss = 0 539 | correct = 0 540 | 541 | with torch.no_grad(): 542 | for data, target in val_loader: 543 | data, target = data.to(device), target.to(device) 544 | if isinstance(optimizer, DistributedVIOptimizer): 545 | prob = optimizer.prediction(data) 546 | val_loss += F.nll_loss(torch.log(prob), target, reduction='sum') 547 | pred = prob.argmax(dim=1, keepdim=True) # get the index of the max log-probability 548 | elif hasattr(model, 'mc_prediction'): 549 | prob = model.mc_prediction(data) 550 | val_loss += F.nll_loss(torch.log(prob), target, reduction='sum') 551 | pred = prob.argmax(dim=1, keepdim=True) # get the index of the max log-probability 552 | else: 553 | output = model(data) 554 | val_loss += F.cross_entropy(output, target, reduction='sum') # sum up batch loss 555 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 556 | 557 | correct += pred.eq(target.view_as(pred)).sum() 558 | 559 | dist.reduce(val_loss, dst=0) 560 | dist.reduce(correct, dst=0) 561 | 562 | val_loss = val_loss.item() / len(val_loader.dataset) 563 | val_accuracy = 100. 
* correct.item() / len(val_loader.dataset) 564 | 565 | if rank == 0: 566 | print('\nEval: average loss: {:.4f}, accuracy: {:.0f}/{} ({:.2f}%)\n'.format( 567 | val_loss, correct, len(val_loader.dataset), val_accuracy)) 568 | 569 | return val_accuracy, val_loss 570 | 571 | 572 | if __name__ == '__main__': 573 | main() 574 | -------------------------------------------------------------------------------- /distributed/classification/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .lenet import * 3 | from .resnet import * 4 | from .resnext import * 5 | from .alexnet import * 6 | -------------------------------------------------------------------------------- /distributed/classification/models/alexnet.py: -------------------------------------------------------------------------------- 1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. 2 | Without BN, the start learning rate should be 0.01 3 | (c) YANG, Wei 4 | ''' 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchsso.utils.accumulator import TensorAccumulator 8 | 9 | 10 | __all__ = ['alexnet', 'alexnet_mcdropout'] 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=10): 16 | super().__init__() 17 | self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5) 18 | self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2) 19 | self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) 20 | self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) 21 | self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 22 | self.fc = nn.Linear(256, num_classes) 23 | 24 | def forward(self, x): 25 | x = F.relu(self.conv1(x), inplace=True) 26 | x = F.max_pool2d(x, kernel_size=2, stride=2) 27 | x = F.relu(self.conv2(x), inplace=True) 28 | x = F.max_pool2d(x, kernel_size=2, stride=2) 29 | x = F.relu(self.conv3(x), inplace=True) 30 | x = F.relu(self.conv4(x), 
inplace=True) 31 | x = F.relu(self.conv5(x), inplace=True) 32 | x = F.max_pool2d(x, kernel_size=2, stride=2) 33 | x = x.view(x.size(0), -1) 34 | x = self.fc(x) 35 | return x 36 | 37 | 38 | class AlexNet2(nn.Module): 39 | 40 | def __init__(self, num_classes=10): 41 | super(AlexNet2, self).__init__() 42 | self.features = nn.Sequential( 43 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 44 | nn.ReLU(inplace=True), 45 | nn.MaxPool2d(kernel_size=2, stride=2), 46 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 47 | nn.ReLU(inplace=True), 48 | nn.MaxPool2d(kernel_size=2, stride=2), 49 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 54 | nn.ReLU(inplace=True), 55 | nn.MaxPool2d(kernel_size=2, stride=2), 56 | ) 57 | self.classifier = nn.Linear(256, num_classes) 58 | 59 | def forward(self, x): 60 | x = self.features(x) 61 | x = x.view(x.size(0), -1) 62 | x = self.classifier(x) 63 | return x 64 | 65 | 66 | class AlexNetMCDropout(AlexNet): 67 | 68 | def __init__(self, num_classes=10, dropout_ratio=0.5, val_mc=10): 69 | super(AlexNetMCDropout, self).__init__(num_classes) 70 | self.dropout_ratio = dropout_ratio 71 | self.val_mc = val_mc 72 | 73 | def forward(self, x): 74 | dropout_ratio = self.dropout_ratio 75 | x = F.relu(F.dropout(self.conv1(x), p=dropout_ratio), inplace=True) 76 | x = F.max_pool2d(x, kernel_size=2, stride=2) 77 | x = F.relu(F.dropout(self.conv2(x), p=dropout_ratio), inplace=True) 78 | x = F.max_pool2d(x, kernel_size=2, stride=2) 79 | x = F.relu(F.dropout(self.conv3(x), p=dropout_ratio), inplace=True) 80 | x = F.relu(F.dropout(self.conv4(x), p=dropout_ratio), inplace=True) 81 | x = F.relu(F.dropout(self.conv5(x), p=dropout_ratio), inplace=True) 82 | x = F.max_pool2d(x, kernel_size=2, stride=2) 83 | x = x.view(x.size(0), -1) 84 | x = self.fc(x) 85 | return x 86 | 87 | def mc_prediction(self, 
x): 88 | 89 | acc_prob = TensorAccumulator() 90 | m = self.val_mc 91 | 92 | for _ in range(m): 93 | output = self.forward(x) 94 | prob = F.softmax(output, dim=1) 95 | acc_prob.update(prob, scale=1/m) 96 | 97 | prob = acc_prob.get() 98 | 99 | return prob 100 | 101 | 102 | def alexnet(**kwargs): 103 | r"""AlexNet model architecture from the 104 | `"One weird trick..." `_ paper. 105 | """ 106 | model = AlexNet(**kwargs) 107 | return model 108 | 109 | 110 | def alexnet_mcdropout(**kwargs): 111 | model = AlexNetMCDropout(**kwargs) 112 | return model 113 | 114 | -------------------------------------------------------------------------------- /distributed/classification/models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchsso.utils.accumulator import TensorAccumulator 4 | 5 | 6 | class LeNet5(nn.Module): 7 | 8 | def __init__(self, num_classes=10): 9 | super().__init__() 10 | self.conv1 = nn.Conv2d(3, 6, 5) 11 | self.conv2 = nn.Conv2d(6, 16, 5) 12 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 13 | self.fc2 = nn.Linear(120, 84) 14 | self.fc3 = nn.Linear(84, num_classes) 15 | 16 | def forward(self, x): 17 | out = F.relu(self.conv1(x)) 18 | out = F.max_pool2d(out, 2) 19 | out = F.relu(self.conv2(out)) 20 | out = F.max_pool2d(out, 2) 21 | out = out.view(out.size(0), -1) 22 | out = F.relu(self.fc1(out)) 23 | out = F.relu(self.fc2(out)) 24 | out = self.fc3(out) 25 | return out 26 | 27 | 28 | class LeNet5MCDropout(LeNet5): 29 | 30 | def __init__(self, num_classes=10, dropout_ratio=0.1, val_mc=10): 31 | super(LeNet5MCDropout, self).__init__(num_classes=num_classes) 32 | self.dropout_ratio = dropout_ratio 33 | self.val_mc = val_mc 34 | 35 | def forward(self, x): 36 | p = self.dropout_ratio 37 | out = F.relu(F.dropout(self.conv1(x), p)) 38 | out = F.max_pool2d(out, 2) 39 | out = F.relu(F.dropout(self.conv2(out), p)) 40 | out = F.max_pool2d(out, 2) 41 | out = 
out.view(out.size(0), -1) 42 | out = F.relu(F.dropout(self.fc1(out), p)) 43 | out = F.relu(F.dropout(self.fc2(out), p)) 44 | out = F.dropout(self.fc3(out), p) 45 | return out 46 | 47 | def mc_prediction(self, x): 48 | 49 | acc_prob = TensorAccumulator() 50 | m = self.val_mc 51 | 52 | for _ in range(m): 53 | output = self.forward(x) 54 | prob = F.softmax(output, dim=1) 55 | acc_prob.update(prob, scale=1/m) 56 | 57 | prob = acc_prob.get() 58 | 59 | return prob 60 | 61 | 62 | class LeNet5BatchNorm(nn.Module): 63 | def __init__(self, num_classes=10, affine=False): 64 | super().__init__() 65 | self.conv1 = nn.Conv2d(3, 6, 5) 66 | self.bn1 = nn.BatchNorm2d(6, affine=affine) 67 | self.conv2 = nn.Conv2d(6, 16, 5) 68 | self.bn2 = nn.BatchNorm2d(16, affine=affine) 69 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 70 | self.bn3 = nn.BatchNorm1d(120, affine=affine) 71 | self.fc2 = nn.Linear(120, 84) 72 | self.bn4 = nn.BatchNorm1d(84, affine=affine) 73 | self.fc3 = nn.Linear(84, num_classes) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = F.max_pool2d(out, 2) 78 | out = F.relu(self.bn2(self.conv2(out))) 79 | out = F.max_pool2d(out, 2) 80 | out = out.view(out.size(0), -1) 81 | out = F.relu(self.bn3(self.fc1(out))) 82 | out = F.relu(self.bn4(self.fc2(out))) 83 | out = self.fc3(out) 84 | return out 85 | -------------------------------------------------------------------------------- /distributed/classification/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = 
F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(ResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | 100 | def ResNet18(num_classes=10): 101 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 102 | 103 | def ResNet34(num_classes=10): 104 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 105 | 106 | def ResNet50(num_classes=10): 107 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 108 | 109 | def ResNet101(num_classes=10): 110 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 111 | 112 | def ResNet152(num_classe=10): 113 | return ResNet(Bottleneck, [3,8,36,3], num_classes) 114 | 115 | 116 | def test(): 117 | net = ResNet18() 118 | y = 
net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /distributed/classification/models/resnet_b.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, groups=groups, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 33 | base_width=64, norm_layer=None, norm_stat_momentum=0.1): 34 | super(BasicBlock, self).__init__() 35 | if norm_layer is None: 36 | norm_layer = nn.BatchNorm2d 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = norm_layer(planes, momentum=norm_stat_momentum) 42 | self.relu = nn.ReLU(inplace=True) 
43 | self.conv2 = conv3x3(planes, planes) 44 | self.bn2 = norm_layer(planes, momentum=norm_stat_momentum) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | identity = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | 61 | out += identity 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 71 | base_width=64, norm_layer=None, norm_stat_momentum=0.1): 72 | super(Bottleneck, self).__init__() 73 | if norm_layer is None: 74 | norm_layer = nn.BatchNorm2d 75 | width = int(planes * (base_width / 64.)) * groups 76 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv1x1(inplanes, width) 78 | self.bn1 = norm_layer(width, momentum=norm_stat_momentum) 79 | self.conv2 = conv3x3(width, width, stride, groups) 80 | self.bn2 = norm_layer(width, momentum=norm_stat_momentum) 81 | self.conv3 = conv1x1(width, planes * self.expansion) 82 | self.bn3 = norm_layer(planes * self.expansion, momentum=norm_stat_momentum) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | identity = self.downsample(x) 103 | 104 | out += identity 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet(nn.Module): 111 | 112 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 113 | groups=1, 
width_per_group=64, norm_layer=None, norm_stat_momentum=0.1): 114 | super(ResNet, self).__init__() 115 | if norm_layer is None: 116 | norm_layer = nn.BatchNorm2d 117 | 118 | self.inplanes = 64 119 | self.groups = groups 120 | self.base_width = width_per_group 121 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 122 | bias=False) 123 | self.bn1 = norm_layer(self.inplanes, momentum=norm_stat_momentum) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0], 127 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) 128 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 129 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) 130 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 131 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) 132 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 133 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) 134 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 135 | self.fc = nn.Linear(512 * block.expansion, num_classes) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 140 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 141 | nn.init.constant_(m.weight, 1) 142 | nn.init.constant_(m.bias, 0) 143 | 144 | # Zero-initialize the last BN in each residual branch, 145 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
146 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 147 | if zero_init_residual: 148 | for m in self.modules(): 149 | if isinstance(m, Bottleneck): 150 | nn.init.constant_(m.bn3.weight, 0) 151 | elif isinstance(m, BasicBlock): 152 | nn.init.constant_(m.bn2.weight, 0) 153 | 154 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None, norm_stat_momentum=0.1): 155 | if norm_layer is None: 156 | norm_layer = nn.BatchNorm2d 157 | downsample = None 158 | if stride != 1 or self.inplanes != planes * block.expansion: 159 | downsample = nn.Sequential( 160 | conv1x1(self.inplanes, planes * block.expansion, stride), 161 | norm_layer(planes * block.expansion, momentum=norm_stat_momentum), 162 | ) 163 | 164 | layers = [] 165 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 166 | self.base_width, norm_layer, norm_stat_momentum)) 167 | self.inplanes = planes * block.expansion 168 | for _ in range(1, blocks): 169 | layers.append(block(self.inplanes, planes, groups=self.groups, 170 | base_width=self.base_width, 171 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum)) 172 | 173 | return nn.Sequential(*layers) 174 | 175 | def forward(self, x): 176 | x = self.conv1(x) 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | x = self.maxpool(x) 180 | 181 | x = self.layer1(x) 182 | x = self.layer2(x) 183 | x = self.layer3(x) 184 | x = self.layer4(x) 185 | 186 | x = self.avgpool(x) 187 | x = x.view(x.size(0), -1) 188 | x = self.fc(x) 189 | 190 | return x 191 | 192 | 193 | def resnet18(pretrained=False, **kwargs): 194 | """Constructs a ResNet-18 model. 
195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 202 | return model 203 | 204 | 205 | def resnet34(pretrained=False, **kwargs): 206 | """Constructs a ResNet-34 model. 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 214 | return model 215 | 216 | 217 | def resnet50(pretrained=False, **kwargs): 218 | """Constructs a ResNet-50 model. 219 | 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | """ 223 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 224 | if pretrained: 225 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 226 | return model 227 | 228 | 229 | def resnet101(pretrained=False, **kwargs): 230 | """Constructs a ResNet-101 model. 231 | 232 | Args: 233 | pretrained (bool): If True, returns a model pre-trained on ImageNet 234 | """ 235 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 236 | if pretrained: 237 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 238 | return model 239 | 240 | 241 | def resnet152(pretrained=False, **kwargs): 242 | """Constructs a ResNet-152 model. 
243 | 244 | Args: 245 | pretrained (bool): If True, returns a model pre-trained on ImageNet 246 | """ 247 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 248 | if pretrained: 249 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 250 | return model 251 | 252 | 253 | def resnext50_32x4d(pretrained=False, **kwargs): 254 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs) 255 | # if pretrained: 256 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) 257 | return model 258 | 259 | 260 | def resnext101_32x8d(pretrained=False, **kwargs): 261 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) 262 | # if pretrained: 263 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) 264 | return model 265 | -------------------------------------------------------------------------------- /distributed/classification/models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 42 | super(ResNeXt, self).__init__() 43 | self.cardinality = cardinality 44 | self.bottleneck_width = bottleneck_width 45 | self.in_planes = 64 46 | 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(64) 49 | self.layer1 = self._make_layer(num_blocks[0], 1) 50 | self.layer2 = self._make_layer(num_blocks[1], 2) 51 | self.layer3 = self._make_layer(num_blocks[2], 2) 52 | # self.layer4 = self._make_layer(num_blocks[3], 2) 53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 54 | 55 | def 
_make_layer(self, num_blocks, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for stride in strides: 59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 61 | # Increase bottleneck_width by 2 after each stage. 62 | self.bottleneck_width *= 2 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.layer1(out) 68 | out = self.layer2(out) 69 | out = self.layer3(out) 70 | # out = self.layer4(out) 71 | out = F.avg_pool2d(out, 8) 72 | out = out.view(out.size(0), -1) 73 | out = self.linear(out) 74 | return out 75 | 76 | 77 | def ResNeXt29_2x64d(num_classes=10): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64, num_classes=num_classes) 79 | 80 | def ResNeXt29_4x64d(num_classes=10): 81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64, num_classes=num_classes) 82 | 83 | def ResNeXt29_8x64d(num_classes=10): 84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64, num_classes=num_classes) 85 | 86 | def ResNeXt29_32x4d(num_classes=10): 87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4, num_classes=num_classes) 88 | 89 | def test_resnext(): 90 | net = ResNeXt29_2x64d() 91 | x = torch.randn(1,3,32,32) 92 | y = net(x) 93 | print(y.size()) 94 | 95 | # test_resnext() 96 | -------------------------------------------------------------------------------- /distributed/classification/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchsso.utils.accumulator import TensorAccumulator 5 | 6 | 7 | cfg = { 8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 
'M', 512, 512, 'M', 512, 512, 'M'], 10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 12 | } 13 | 14 | 15 | class VGG(nn.Module): 16 | def __init__(self, num_classes=10, vgg_name='VGG19'): 17 | super(VGG, self).__init__() 18 | self.features = self._make_layers(cfg[vgg_name]) 19 | self.classifier = nn.Linear(512, num_classes) 20 | 21 | def forward(self, x): 22 | out = self.features(x) 23 | out = out.view(out.size(0), -1) 24 | out = self.classifier(out) 25 | return out 26 | 27 | def _make_layers(self, cfg): 28 | layers = [] 29 | in_channels = 3 30 | for x in cfg: 31 | if x == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(x), 36 | nn.ReLU(inplace=True)] 37 | in_channels = x 38 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 39 | return nn.Sequential(*layers) 40 | 41 | 42 | class VGG19(nn.Module): 43 | 44 | def __init__(self, num_classes=10): 45 | super(VGG19, self).__init__() 46 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 47 | self.bn1_1 = nn.BatchNorm2d(64) 48 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 49 | self.bn1_2 = nn.BatchNorm2d(64) 50 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 51 | self.bn2_1 = nn.BatchNorm2d(128) 52 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 53 | self.bn2_2 = nn.BatchNorm2d(128) 54 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 55 | self.bn3_1 = nn.BatchNorm2d(256) 56 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 57 | self.bn3_2 = nn.BatchNorm2d(256) 58 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 59 | self.bn3_3 = nn.BatchNorm2d(256) 60 | self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) 61 | self.bn3_4 = nn.BatchNorm2d(256) 62 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 63 | self.bn4_1 = nn.BatchNorm2d(512) 64 
| self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 65 | self.bn4_2 = nn.BatchNorm2d(512) 66 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 67 | self.bn4_3 = nn.BatchNorm2d(512) 68 | self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) 69 | self.bn4_4 = nn.BatchNorm2d(512) 70 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 71 | self.bn5_1 = nn.BatchNorm2d(512) 72 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 73 | self.bn5_2 = nn.BatchNorm2d(512) 74 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 75 | self.bn5_3 = nn.BatchNorm2d(512) 76 | self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) 77 | self.bn5_4 = nn.BatchNorm2d(512) 78 | self.fc = nn.Linear(512, num_classes) 79 | 80 | def forward(self, x): 81 | h = F.relu(self.bn1_1(self.conv1_1(x)), inplace=True) 82 | h = F.relu(self.bn1_2(self.conv1_2(h)), inplace=True) 83 | h = F.max_pool2d(h, kernel_size=2, stride=2) 84 | h = F.relu(self.bn2_1(self.conv2_1(h)), inplace=True) 85 | h = F.relu(self.bn2_2(self.conv2_2(h)), inplace=True) 86 | h = F.max_pool2d(h, kernel_size=2, stride=2) 87 | h = F.relu(self.bn3_1(self.conv3_1(h)), inplace=True) 88 | h = F.relu(self.bn3_2(self.conv3_2(h)), inplace=True) 89 | h = F.relu(self.bn3_3(self.conv3_3(h)), inplace=True) 90 | h = F.relu(self.bn3_4(self.conv3_4(h)), inplace=True) 91 | h = F.max_pool2d(h, kernel_size=2, stride=2) 92 | h = F.relu(self.bn4_1(self.conv4_1(h)), inplace=True) 93 | h = F.relu(self.bn4_2(self.conv4_2(h)), inplace=True) 94 | h = F.relu(self.bn4_3(self.conv4_3(h)), inplace=True) 95 | h = F.relu(self.bn4_4(self.conv4_4(h)), inplace=True) 96 | h = F.max_pool2d(h, kernel_size=2, stride=2) 97 | h = F.relu(self.bn5_1(self.conv5_1(h)), inplace=True) 98 | h = F.relu(self.bn5_2(self.conv5_2(h)), inplace=True) 99 | h = F.relu(self.bn5_3(self.conv5_3(h)), inplace=True) 100 | h = F.relu(self.bn5_4(self.conv5_4(h)), inplace=True) 101 | h = F.max_pool2d(h, kernel_size=2, stride=2) 102 | h = h.view(h.size(0), -1) 103 | out = self.fc(h) 104 | return out 105 | 106 | 107 
| class VGG19MCDropout(VGG19): 108 | 109 | def __init__(self, num_classes=10, dropout_ratio=0.1, val_mc=10): 110 | super(VGG19MCDropout, self).__init__(num_classes) 111 | self.dropout_ratio = dropout_ratio 112 | self.val_mc = val_mc 113 | 114 | def forward(self, x): 115 | p = self.dropout_ratio 116 | h = F.relu(self.bn1_1(F.dropout(self.conv1_1(x), p)), inplace=True) 117 | h = F.relu(self.bn1_2(F.dropout(self.conv1_2(h), p)), inplace=True) 118 | h = F.max_pool2d(h, kernel_size=2, stride=2) 119 | h = F.relu(self.bn2_1(F.dropout(self.conv2_1(h), p)), inplace=True) 120 | h = F.relu(self.bn2_2(F.dropout(self.conv2_2(h), p)), inplace=True) 121 | h = F.max_pool2d(h, kernel_size=2, stride=2) 122 | h = F.relu(self.bn3_1(F.dropout(self.conv3_1(h), p)), inplace=True) 123 | h = F.relu(self.bn3_2(F.dropout(self.conv3_2(h), p)), inplace=True) 124 | h = F.relu(self.bn3_3(F.dropout(self.conv3_3(h), p)), inplace=True) 125 | h = F.relu(self.bn3_4(F.dropout(self.conv3_4(h), p)), inplace=True) 126 | h = F.max_pool2d(h, kernel_size=2, stride=2) 127 | h = F.relu(self.bn4_1(F.dropout(self.conv4_1(h), p)), inplace=True) 128 | h = F.relu(self.bn4_2(F.dropout(self.conv4_2(h), p)), inplace=True) 129 | h = F.relu(self.bn4_3(F.dropout(self.conv4_3(h), p)), inplace=True) 130 | h = F.relu(self.bn4_4(F.dropout(self.conv4_4(h), p)), inplace=True) 131 | h = F.max_pool2d(h, kernel_size=2, stride=2) 132 | h = F.relu(self.bn5_1(F.dropout(self.conv5_1(h), p)), inplace=True) 133 | h = F.relu(self.bn5_2(F.dropout(self.conv5_2(h), p)), inplace=True) 134 | h = F.relu(self.bn5_3(F.dropout(self.conv5_3(h), p)), inplace=True) 135 | h = F.relu(self.bn5_4(F.dropout(self.conv5_4(h), p)), inplace=True) 136 | h = F.max_pool2d(h, kernel_size=2, stride=2) 137 | h = h.view(h.size(0), -1) 138 | out = F.dropout(self.fc(h), p) 139 | return out 140 | 141 | def mc_prediction(self, x): 142 | 143 | acc_prob = TensorAccumulator() 144 | m = self.val_mc 145 | 146 | for _ in range(m): 147 | output = self.forward(x) 148 | prob 
= F.softmax(output, dim=1) 149 | acc_prob.update(prob, scale=1/m) 150 | 151 | prob = acc_prob.get() 152 | 153 | return prob 154 | -------------------------------------------------------------------------------- /docs/boundary.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/team-approx-bayes/dl-with-bayes/fbf9b0ee185346bc269a4c40b3904384bf3c4338/docs/boundary.gif -------------------------------------------------------------------------------- /docs/curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/team-approx-bayes/dl-with-bayes/fbf9b0ee185346bc269a4c40b3904384bf3c4338/docs/curves.png -------------------------------------------------------------------------------- /docs/distributed_vi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/team-approx-bayes/dl-with-bayes/fbf9b0ee185346bc269a4c40b3904384bf3c4338/docs/distributed_vi.png -------------------------------------------------------------------------------- /neurips2019_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/team-approx-bayes/dl-with-bayes/fbf9b0ee185346bc269a4c40b3904384bf3c4338/neurips2019_poster.pdf -------------------------------------------------------------------------------- /toy_example/README.md: -------------------------------------------------------------------------------- 1 | # Toy Example 2 | Training MLP on 2D-binary classification. 3 | ```bash 4 | $ python main.py 5 | ``` 6 | This script creates following GIF. 7 | 8 | Decision boundary and entropy plots on 2D-binary classification by MLPs trained with Adam and VOGN. 9 | ![](../docs/boundary.gif) 10 | 11 | VOGN optimizes the posterior distribution of each weight (i.e., mean and variance of the Gaussian). 
12 | A model with the mean weights draws the red boundary, and models with the MC samples from the posterior distribution draw light red boundaries. 13 | VOGN converges to a similar solution as Adam while keeping uncertainty in its predictions. 14 | -------------------------------------------------------------------------------- /toy_example/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import inspect 4 | 5 | import imageio 6 | import pickle 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import TensorDataset, DataLoader 12 | import torchsso 13 | 14 | from sklearn.datasets import make_blobs 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | plt.rcParams['font.size'] = 18 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | # Data 23 | parser.add_argument('--n_samples', type=int, default=100, 24 | help='number of samples') 25 | parser.add_argument('--centers', type=int, default=5, 26 | help='number of clusters') 27 | parser.add_argument('--random_state', type=int, default=5, 28 | help='random seed for data creation') 29 | # Training 30 | parser.add_argument('--epochs', type=int, default=50, 31 | help='number of epochs to train') 32 | parser.add_argument('--batch_size', type=int, default=10, 33 | help='input batch size for training') 34 | parser.add_argument('--plot_interval', type=int, default=50, 35 | help='interval iterations to plot decision boundary') 36 | # Options 37 | parser.add_argument('--n_samples_for_mcplot', type=int, default=20, 38 | help='number of MC samples for plotting boundaries by VOGN') 39 | parser.add_argument('--no_cuda', action='store_true', default=False, 40 | help='disables CUDA training') 41 | parser.add_argument('--log_file_name', type=str, default='log', 42 | help='log file name') 43 | parser.add_argument('--fig_dir', type=str, default='tmp', 44 | help='directory 
to keep tmp figures') 45 | parser.add_argument('--keep_figures', action='store_true', default=False, 46 | help='whether keep tmp figures after creating gif') 47 | parser.add_argument('--out', type=str, default='boundary.gif', 48 | help='output gif file') 49 | 50 | args = parser.parse_args() 51 | 52 | # Set device 53 | use_cuda = not args.no_cuda and torch.cuda.is_available() 54 | device = torch.device('cuda' if use_cuda else 'cpu') 55 | 56 | # Generate a dataset 57 | n_samples = args.n_samples 58 | centers = args.centers 59 | random_state = args.random_state 60 | 61 | X, y = make_blobs(n_samples=n_samples, n_features=2, centers=centers, random_state=random_state) 62 | y[y < int(centers) / 2] = 0 63 | y[y >= int(centers) / 2] = 1 64 | 65 | x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 66 | y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 67 | h = 0.05 68 | xx, yy = np.meshgrid(np.arange(x_min, x_max, h), 69 | np.arange(y_min, y_max, h)) 70 | data_meshgrid = torch.from_numpy(np.c_[xx.ravel(), yy.ravel()]).type(torch.float).to(device) 71 | 72 | X_tensor = torch.from_numpy(X).type(torch.float) 73 | y_tensor = torch.from_numpy(y).type(torch.float) 74 | train_dataset = TensorDataset(X_tensor, y_tensor) 75 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size) 76 | 77 | # Model arguments 78 | model_kwargs = dict(input_size=2, output_size=None, hidden_sizes=[128]) 79 | 80 | model1 = MLP(**model_kwargs) 81 | model1 = model1.to(device) 82 | optimizer1 = torch.optim.Adam(model1.parameters()) 83 | 84 | model2 = pickle.loads(pickle.dumps(model1)) # create a clone 85 | model2 = model2.to(device) 86 | optimizer2 = torchsso.optim.VOGN(model2, dataset_size=len(train_loader.dataset)) 87 | 88 | # Show all config 89 | print('===========================') 90 | print(f'model class: {model1.__class__}') 91 | for m in model1.children(): 92 | print(m) 93 | print(f'model args: {model_kwargs}') 94 | for key, val in vars(args).items(): 95 | print(f'{key}: {val}') 96 | 
print('---------------------------') 97 | print(f'optim1 class: {optimizer1.__class__}') 98 | print(f'optim2 class: {optimizer2.__class__}') 99 | print('===========================') 100 | 101 | figpaths = [] 102 | i = 0 # iteration 103 | 104 | # Run training 105 | for epoch in range(args.epochs): 106 | 107 | model1.train() 108 | model2.train() 109 | 110 | for data, target in train_loader: 111 | 112 | data, target = data.to(device), target.to(device) 113 | 114 | def closure1(): 115 | optimizer1.zero_grad() 116 | output = model1(data) 117 | loss = F.binary_cross_entropy_with_logits(output, target) 118 | loss.backward() 119 | return loss 120 | 121 | def closure2(): 122 | optimizer2.zero_grad() 123 | output = model2(data) 124 | loss = F.binary_cross_entropy_with_logits(output, target) 125 | loss.backward() 126 | return loss, output 127 | 128 | loss1 = optimizer1.step(closure1) 129 | loss2, _ = optimizer2.step(closure2) 130 | 131 | if (i + 1) % args.plot_interval == 0: 132 | # Setup figures 133 | fig = plt.figure(figsize=(21, 6)) 134 | gs = fig.add_gridspec(1, 3) 135 | 136 | # Decision boundary 137 | ax1 = fig.add_subplot(gs[0, 0]) 138 | ax1.set_xlabel('Input 1') 139 | ax1.set_ylabel('Input 2') 140 | ax1.set_title(f'Iteration {i+1}') 141 | 142 | # Entropy (Adam) 143 | ax2 = fig.add_subplot(gs[0, 1]) 144 | ax2.set_xlabel('Input 1') 145 | ax2.set_ylabel('Input 2') 146 | ax2.set_title(f'Entropy (Adam)') 147 | 148 | # Entropy (VOGN) 149 | ax3 = fig.add_subplot(gs[0, 2]) 150 | ax3.set_xlabel('Input 1') 151 | ax3.set_ylabel('Input 2') 152 | ax3.set_title(f'Entropy (VOGN)') 153 | 154 | model1.eval() 155 | model2.eval() 156 | 157 | # (Adam) 158 | prob = torch.sigmoid(model1(data_meshgrid)).view(xx.shape) 159 | entropy = get_entropy(prob) 160 | pred = torch.round(prob).detach().cpu().numpy() 161 | 162 | plot = ax1.contour(xx, yy, pred, colors=['blue'], linewidths=[2]) 163 | plot.collections[len(plot.collections)//2].set_label('Adam') 164 | im = ax2.pcolormesh(xx, yy, entropy) 
165 | fig.colorbar(im, ax=ax2) 166 | 167 | # (VOGN) get MC samples 168 | prob, probs = optimizer2.prediction(data_meshgrid, keep_probs=True) 169 | prob = prob.view(xx.shape) 170 | entropy = get_entropy(prob) 171 | 172 | probs = probs[:args.n_samples_for_mcplot] 173 | preds = [torch.round(p).detach().cpu().numpy().reshape(xx.shape) for p in probs] 174 | for pred in preds: 175 | ax1.contour(xx, yy, pred, colors=['red'], alpha=0.01) 176 | im = ax3.pcolormesh(xx, yy, entropy) 177 | fig.colorbar(im, ax=ax3) 178 | 179 | # (VOGN) get mean prediction 180 | prob = optimizer2.prediction(data_meshgrid, mc=0).view(xx.shape) 181 | pred = torch.round(prob).detach().cpu().numpy() 182 | 183 | plot = ax1.contour(xx, yy, pred, colors=['red'], linewidths=[2]) 184 | plot.collections[len(plot.collections)//2].set_label('VOGN') 185 | 186 | # plot samples 187 | for label, marker, color in zip([0, 1], ['o', 's'], ['white', 'gray']): 188 | _X = X[y == label] 189 | ax1.scatter(_X[:, 0], _X[:, 1], s=80, c=color, edgecolors='black', marker=marker) 190 | ax2.scatter(_X[:, 0], _X[:, 1], s=80, c=color, edgecolors='black', marker=marker) 191 | ax3.scatter(_X[:, 0], _X[:, 1], s=80, c=color, edgecolors='black', marker=marker) 192 | 193 | # save tmp figure 194 | ax1.grid(linestyle='--') 195 | ax2.grid(linestyle='--') 196 | ax3.grid(linestyle='--') 197 | ax1.set_yticks([-5, 0, 5, 10]) 198 | ax2.set_yticks([-5, 0, 5, 10]) 199 | ax3.set_yticks([-5, 0, 5, 10]) 200 | ax1.legend(loc='lower right') 201 | ax1.set_aspect(0.8) 202 | plt.tight_layout() 203 | figname = f'iteration{i+1}.png' 204 | figpath = os.path.join(args.fig_dir, figname) 205 | if not os.path.isdir(args.fig_dir): 206 | os.makedirs(args.fig_dir) 207 | fig.savefig(figpath) 208 | plt.close(fig) 209 | figpaths.append(figpath) 210 | 211 | i += 1 212 | 213 | print(f'Train Epoch: {epoch+1}\tLoss(Adam): {loss1:.6f} Loss(VOGN): {loss2:.6f}') 214 | 215 | # Create GIF from temp figures 216 | images = [] 217 | for figpath in figpaths: 218 | 
images.append(imageio.imread(figpath)) 219 | if not args.keep_figures: 220 | os.remove(figpath) 221 | imageio.mimsave(args.out, images, fps=1) 222 | 223 | 224 | def get_entropy(prob: torch.Tensor): 225 | entropy = - prob * torch.log(prob) - (1 - prob) * torch.log(1 - prob) 226 | entropy[entropy != entropy] = 0 # nan to zero 227 | entropy = entropy.detach().cpu().numpy() 228 | 229 | return entropy 230 | 231 | 232 | class MLP(nn.Module): 233 | def __init__(self, input_size, output_size, hidden_sizes=None, act_func="relu"): 234 | super(MLP, self).__init__() 235 | self.input_size = input_size 236 | self.hidden_sizes = hidden_sizes 237 | if output_size is not None: 238 | self.output_size = output_size 239 | self.squeeze_output = False 240 | else: 241 | self.output_size = 1 242 | self.squeeze_output = True 243 | 244 | # Set activation function 245 | if act_func == "relu": 246 | self.act = F.relu 247 | elif act_func == "tanh": 248 | self.act = F.tanh 249 | elif act_func == "sigmoid": 250 | self.act = torch.sigmoid 251 | else: 252 | raise ValueError(f'Invalid activation function: {act_func}') 253 | 254 | # Define layers 255 | if hidden_sizes is None: 256 | # Linear model 257 | self.hidden_layers = [] 258 | self.output_layer = nn.Linear(self.input_size, self.output_size) 259 | else: 260 | # Neural network 261 | features = zip([self.input_size] + hidden_sizes[:-1], hidden_sizes) 262 | self.hidden_layers = nn.ModuleList([nn.Linear(in_features, out_features) for in_features, out_features in features]) 263 | self.output_layer = nn.Linear(hidden_sizes[-1], self.output_size) 264 | 265 | def forward(self, x): 266 | x = x.view(-1, self.input_size) 267 | h = x 268 | for layer in self.hidden_layers: 269 | h = self.act(layer(h)) 270 | 271 | out = self.output_layer(h) 272 | if self.squeeze_output: 273 | out = torch.squeeze(out).view([-1]) 274 | 275 | return out 276 | 277 | 278 | if __name__ == '__main__': 279 | main() 280 | 
--------------------------------------------------------------------------------