├── .gitignore ├── LICENSE ├── README.md ├── docs ├── distributed_vi.png └── overview.png ├── examples ├── classification │ ├── README.md │ ├── configs │ │ └── cifar10 │ │ │ ├── lenet_adam.json │ │ │ ├── lenet_kfac.json │ │ │ ├── lenet_noisykfac.json │ │ │ └── lenet_vogn.json │ ├── main.py │ └── models │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── lenet.py │ │ ├── mlp.py │ │ ├── resnet.py │ │ └── vgg.py └── distributed │ ├── README.md │ └── classification │ ├── README.md │ ├── configs │ ├── cifar10 │ │ ├── alexnet_adam_bs256_8gpu.json │ │ ├── alexnet_kfac_bs256_8gpu.json │ │ ├── alexnet_noisykfac_bs256_8gpu.json │ │ ├── alexnet_sgd_bs256_8gpu.json │ │ ├── alexnet_vogn_bs256_8gpu.json │ │ ├── lenet_adam_bs128_4gpu.json │ │ ├── lenet_kfac_bs128_4gpu.json │ │ ├── lenet_sgd_bs128_4gpu.json │ │ ├── lenet_vogn_bs128_4gpu.json │ │ ├── resnet18_adam_bs256_8gpu.json │ │ ├── resnet18_sgd_bs256_8gpu.json │ │ ├── resnet18_vogn_bs256_8gpu.json │ │ ├── vgg19_adam_bs256_8gpu.json │ │ └── vgg19_vogn_bs256_8gpu.json │ ├── cifar100 │ │ ├── alexnet_adam_bs256_8gpu.json │ │ ├── alexnet_sgd_bs256_8gpu.json │ │ ├── alexnet_vogn_bs256_8gpu.json │ │ ├── resnet18_adam_bs256_8gpu.json │ │ └── resnet18_vogn_bs256_8gpu.json │ └── imagenet │ │ ├── resnet18_adam_bs4k_128gpu.json │ │ ├── resnet18_kfac_bs4k_128gpu.json │ │ ├── resnet18_kfac_bs4k_4gpu.json │ │ ├── resnet18_noisykfac_bs4k_128gpu.json │ │ ├── resnet18_noisykfac_bs4k_4gpu.json │ │ ├── resnet18_sgd_bs4k_128gpu.json │ │ ├── resnet18_vogn_bs4k_128gpu.json │ │ └── resnet18_vogn_bs4k_4gpu.json │ ├── main.py │ └── models │ ├── __init__.py │ ├── alexnet.py │ ├── lenet.py │ ├── resnet.py │ ├── resnet_b.py │ ├── resnext.py │ └── vgg.py ├── pytorch_sso_poster.pptx.pdf ├── setup.cfg ├── setup.py ├── tests └── test_samplegrad.py └── torchsso ├── __init__.py ├── autograd ├── __init__.py └── samplegrad.py ├── curv ├── __init__.py ├── cov │ ├── __init__.py │ ├── batchnorm.py │ ├── conv.py │ └── linear.py ├── curvature.py ├── fisher │ ├── __init__.py │ ├── batchnorm.py │ ├── conv.py │ └── linear.py └── hessian │ ├── __init__.py │ ├── conv.py │ └── linear.py ├── optim ├── __init__.py ├── firstorder.py ├── lr_scheduler.py ├── secondorder.py └── vi.py └── utils ├── __init__.py ├── accumulator.py ├── chainer_communicators ├── __init__.py ├── _utility.py ├── base.py └── pure_nccl_communicator.py ├── cholesky_cupy.py ├── cupy.py ├── inv_cupy.py └── logger.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # PyCharm 107 | .idea 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Cybertron AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # PyTorch-SSO (alpha release) 4 | 5 | Scalable Second-Order methods in PyTorch. 6 | 7 | - Open-source library for second-order optimization and Bayesian inference. 8 | 9 | - An earlier iteration of this library ([chainerkfac](https://github.com/tyohei/chainerkfac)) holds the world record for large-batch training of ResNet-50 on ImageNet by [Kronecker-Factored Approximate Curvature (K-FAC)](https://arxiv.org/abs/1503.05671), scaling to batch sizes of 131K. 10 | - Kazuki Osawa et al, “Large-Scale Distributed Second-Order Optimization Using Kronecker-Factored Approximate Curvature for Deep Convolutional Neural Networks”, **IEEE/CVF CVPR 2019**. 11 | - [[paper](http://openaccess.thecvf.com/content_CVPR_2019/html/Osawa_Large-Scale_Distributed_Second-Order_Optimization_Using_Kronecker-Factored_Approximate_Curvature_for_Deep_CVPR_2019_paper.html)] [[poster](https://kazukiosawa.github.io/cvpr19_poster.pdf)] 12 | - This library is basis for the Natural Gradient for Bayesian inference (Variational Inference) on ImageNet. 13 | - Kazuki Osawa et al, “Practical Deep Learning with Bayesian Principles”, **NeurIPS 2019**. 14 | - [[paper (preprint)](https://arxiv.org/abs/1906.02506)] 15 | 16 | ## Scalable Second-Order Optimization 17 | ![](docs/overview.png) 18 | 19 | ### Optimizers 20 | 21 | PyTorch-SSO provides the following optimizers. 22 | 23 | - Second-Order Optimization 24 | - `torchsso.optim.SecondOrderOptimizer` [[source](https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/secondorder.py)] 25 | - updates the parameters with the gradients pre-conditioned by the curvature of the loss function (`torch.nn.functional.cross_entropy`) for each `param_group`. 26 | - Variational Inference (VI) 27 | - `torchsso.optim.VIOptimizer` [[source](https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/vi.py)] 28 | - updates the posterior distribution (mean, covariance) of the parameters by using the curvature for each `param_group`. 29 | 30 | ### Curvatures 31 | 32 | You can specify a type of the information matrix to be used as the curvature from the following. 33 | 34 | - Hessian [WIP] 35 | 36 | - Fisher information matrix 37 | 38 | - Covariance matrix (empirical Fisher) 39 | 40 | 41 | 42 | Refer [Information matrices and generalization](https://arxiv.org/abs/1906.07774) by Valentin Thomas et al. (2019) for the definitions and the properties of these information matrices. 43 | 44 | 45 | 46 | Refer Section 6 of [Optimization Methods for Large-Scale Machine Learning](https://arxiv.org/abs/1606.04838) by L´eon Bottou et al. (2018) for a clear explanation of the second-order optimzation using these matrices as curvature. 47 | 48 | ### Approximation Methods 49 | 50 | PyTorch-SSO calculates the curvature as a layer-wise block-diagonal matrix. 51 | 52 | You can specify the approximation method for the curvatures in each layer from the follwing. 53 | 54 | 1. Full (No approximation) 55 | 2. Diagonal approximation 56 | 3. [Kronecker-Factored Approximate Curvature (K-FAC)](https://arxiv.org/abs/1503.05671) 57 | 58 | PyTorch-SSO currently supports the following layers (Modules) in PyTorch: 59 | 60 | | Layer (Module) | Full | Diagonal | K-FAC | 61 | | ------------------------- | ------------------ | ------------------ | ------------------ | 62 | | `torch.nn.Linear` | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | 63 | | `torch.nn.Conv2d` | - | :heavy_check_mark: | :heavy_check_mark: | 64 | | `torch.nn.BatchNorm1d/2d` | - | :heavy_check_mark: | - | 65 | 66 | To apply PyTorch-SSO, 67 | - Set`requires_grad` to `True` for each Module. 68 | - The network you define cannot contain any other modules. 69 | - E.g., You need to use `torch.nn.functional.relu/max_pool2d` instead of `torch.nn.ReLU/MaxPool2d` to define a ConvNet. 70 | 71 | ### Distributed Training 72 | 73 | PyTorch-SSO supports *data parallelism* and *MC samples parallelism* (for VI) 74 | for distributed training among multiple processes (GPUs). 75 | 76 | ## Installation 77 | To build PyTorch-SSO run (on a Python 3 environment) 78 | ```bash 79 | git clone git@github.com:cybertronai/pytorch-sso.git 80 | cd pytorch-sso 81 | python setup.py install 82 | ``` 83 | 84 | To use the library 85 | ```python 86 | import torchsso 87 | ``` 88 | 89 | ### Additional requirements 90 | 91 | PyTorch-SSO depends on [CuPy](https://cupy.chainer.org/) for fast GPU computation and [ChainerMN](https://github.com/chainer/chainermn) for communication. To use GPUs, you need to install the following requirements **before the installation of PyTorch-SSO**. 92 | 93 | | Running environment | Requirements | 94 | | ------------------- | ---------------------- | 95 | | single GPU | CuPy | 96 | | multiple GPUs | Cupy with NCCL, MPI4py | 97 | 98 | Refer [CuPy installation guide](https://docs-cupy.chainer.org/en/stable/install.html) and [ChainerMN installation guide](https://docs.chainer.org/en/stable/chainermn/installation/guide.html#chainermn-installation) for details. 99 | 100 | ## Examples 101 | 102 | - [Image classification with a single process](https://github.com/cybertronai/pytorch-sso/tree/master/examples/classification) (MNIST, CIFAR-10) 103 | - [Image classification with multiple processes](https://github.com/cybertronai/pytorch-sso/tree/master/examples/distributed/classification) (CIFAR-10/100, ImageNet) 104 | 105 | ## Authors 106 | 107 | Kazuki Osawa ([@kazukiosawa](https://github.com/kazukiosawa)) and Yaroslav Bulatov ([@yaroslavvb](https://github.com/yaroslavvb)) 108 | -------------------------------------------------------------------------------- /docs/distributed_vi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybertronai/pytorch-sso/d16173236728fdd1e943bc3b61243bfb28f95348/docs/distributed_vi.png -------------------------------------------------------------------------------- /docs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybertronai/pytorch-sso/d16173236728fdd1e943bc3b61243bfb28f95348/docs/overview.png -------------------------------------------------------------------------------- /examples/classification/README.md: -------------------------------------------------------------------------------- 1 | To run training LeNet-5 for CIFAR-10 classification 2 | ```bash 3 | python main.py --config --download 4 | ``` 5 | | optimizer | dataset | architecture | config file path | 6 | | --- | --- | --- | --- | 7 | | [Adam](https://arxiv.org/abs/1412.6980) | CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_adam.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_adam.json) | 8 | | [K-FAC](https://arxiv.org/abs/1503.05671)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_kfac.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_kfac.json) | 9 | | [Noisy K-FAC](https://arxiv.org/abs/1712.02390)| CIFAR-10 | LeNet-5 | [configs/cifar10/lenet_noisykfac.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_noisykfac.json) | 10 | | [VOGN](https://arxiv.org/abs/1806.04854)| CIFAR-10 | LeNet-5 + BatchNorm | [configs/cifar10/lenet_vogn.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/classification/configs/cifar10/lenet_vogn.json) | 11 | -------------------------------------------------------------------------------- /examples/classification/configs/cifar10/lenet_adam.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 100, 4 | "batch_size": 128, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 0.01 16 | } 17 | } -------------------------------------------------------------------------------- /examples/classification/configs/cifar10/lenet_kfac.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 50, 4 | "batch_size": 128, 5 | "val_batch_size": 5000, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "SecondOrderOptimizer", 12 | "optim_args": { 13 | "curv_type":"Fisher", 14 | "curv_shapes": { 15 | "Conv2d": "Kron", 16 | "Linear": "Kron", 17 | "BatchNorm1d": "Diag", 18 | "BatchNorm2d": "Diag" 19 | }, 20 | "lr": 1e-3, 21 | "momentum": 0.9, 22 | "momentum_type": "raw", 23 | "l2_reg": 1e-3, 24 | "acc_steps": 1 25 | }, 26 | "curv_args": { 27 | "damping": 1e-3, 28 | "ema_decay": 0.999, 29 | "pi_type": "tracenorm" 30 | }, 31 | "fisher_args": { 32 | "approx_type": "mc", 33 | "num_mc": 1 34 | }, 35 | "scheduler_name": "ExponentialLR", 36 | "scheduler_args": { 37 | "gamma": 0.9 38 | }, 39 | "log_interval": 64, 40 | "no_cuda": false 41 | } -------------------------------------------------------------------------------- /examples/classification/configs/cifar10/lenet_noisykfac.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 15, 4 | "batch_size": 64, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": false, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "VIOptimizer", 12 | "optim_args": { 13 | "curv_type": "Fisher", 14 | "curv_shapes": { 15 | "Conv2d": "Kron", 16 | "Linear": "Kron" 17 | }, 18 | "lr": 4e-3, 19 | "momentum": 0.9, 20 | "momentum_type": "preconditioned", 21 | "weight_decay": 0.1, 22 | "num_mc_samples": 4, 23 | "val_num_mc_samples": 0, 24 | "kl_weighting": 0.2, 25 | "prior_variance": 1 26 | }, 27 | "curv_args": { 28 | "damping": 1e-4, 29 | "ema_decay": 0.333, 30 | "pi_type": "tracenorm" 31 | }, 32 | "fisher_args": { 33 | "approx_type": "mc", 34 | "num_mc": 1 35 | }, 36 | "scheduler_name": "ExponentialLR", 37 | "scheduler_args": { 38 | "gamma": 0.9 39 | }, 40 | "no_cuda": false 41 | } 42 | -------------------------------------------------------------------------------- /examples/classification/configs/cifar10/lenet_vogn.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 30, 4 | "batch_size": 128, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5BatchNorm", 11 | "arch_args": { 12 | "affine": true 13 | }, 14 | "optim_name": "VIOptimizer", 15 | "optim_args": { 16 | "curv_type": "Cov", 17 | "curv_shapes": { 18 | "Conv2d": "Diag", 19 | "Linear": "Diag", 20 | "BatchNorm1d": "Diag", 21 | "BatchNorm2d": "Diag" 22 | }, 23 | "lr": 0.01, 24 | "grad_ema_decay": 0.1, 25 | "grad_ema_type": "raw", 26 | "num_mc_samples": 10, 27 | "val_num_mc_samples": 0, 28 | "kl_weighting": 1, 29 | "init_precision": 8e-3, 30 | "prior_variance": 1, 31 | "acc_steps": 1 32 | }, 33 | "curv_args": { 34 | "damping": 0, 35 | "ema_decay": 0.001 36 | }, 37 | "scheduler_name": "ExponentialLR", 38 | "scheduler_args": { 39 | "gamma": 0.9 40 | }, 41 | "no_cuda": false 42 | } 43 | -------------------------------------------------------------------------------- /examples/classification/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from importlib import import_module 4 | import shutil 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torchvision import datasets, transforms, models 10 | import torchsso 11 | from torchsso.optim import SecondOrderOptimizer, VIOptimizer 12 | from torchsso.utils import Logger 13 | 14 | DATASET_CIFAR10 = 'CIFAR-10' 15 | DATASET_CIFAR100 = 'CIFAR-100' 16 | DATASET_MNIST = 'MNIST' 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | # Data 22 | parser.add_argument('--dataset', type=str, 23 | choices=[DATASET_CIFAR10, DATASET_CIFAR100, DATASET_MNIST], default=DATASET_CIFAR10, 24 | help='name of dataset') 25 | parser.add_argument('--root', type=str, default='./data', 26 | help='root of dataset') 27 | parser.add_argument('--epochs', type=int, default=10, 28 | help='number of epochs to train') 29 | parser.add_argument('--batch_size', type=int, default=128, 30 | help='input batch size for training') 31 | parser.add_argument('--val_batch_size', type=int, default=128, 32 | help='input batch size for valing') 33 | parser.add_argument('--normalizing_data', action='store_true', 34 | help='[data pre processing] normalizing data') 35 | parser.add_argument('--random_crop', action='store_true', 36 | help='[data augmentation] random crop') 37 | parser.add_argument('--random_horizontal_flip', action='store_true', 38 | help='[data augmentation] random horizontal flip') 39 | # Training Settings 40 | parser.add_argument('--arch_file', type=str, default=None, 41 | help='name of file which defines the architecture') 42 | parser.add_argument('--arch_name', type=str, default='LeNet5', 43 | help='name of the architecture') 44 | parser.add_argument('--arch_args', type=json.loads, default=None, 45 | help='[JSON] arguments for the architecture') 46 | parser.add_argument('--optim_name', type=str, default=SecondOrderOptimizer.__name__, 47 | help='name of the optimizer') 48 | parser.add_argument('--optim_args', type=json.loads, default=None, 49 | help='[JSON] arguments for the optimizer') 50 | parser.add_argument('--curv_args', type=json.loads, default=dict(), 51 | help='[JSON] arguments for the curvature') 52 | parser.add_argument('--fisher_args', type=json.loads, default=dict(), 53 | help='[JSON] arguments for the fisher') 54 | parser.add_argument('--scheduler_name', type=str, default=None, 55 | help='name of the learning rate scheduler') 56 | parser.add_argument('--scheduler_args', type=json.loads, default=None, 57 | help='[JSON] arguments for the scheduler') 58 | # Options 59 | parser.add_argument('--download', action='store_true', default=False, 60 | help='if True, downloads the dataset (CIFAR-10 or 100) from the internet') 61 | parser.add_argument('--create_graph', action='store_true', default=False, 62 | help='create graph of the derivative') 63 | parser.add_argument('--no_cuda', action='store_true', default=False, 64 | help='disables CUDA training') 65 | parser.add_argument('--seed', type=int, default=1, 66 | help='random seed') 67 | parser.add_argument('--num_workers', type=int, default=0, 68 | help='number of sub processes for data loading') 69 | parser.add_argument('--log_interval', type=int, default=50, 70 | help='how many batches to wait before logging training status') 71 | parser.add_argument('--log_file_name', type=str, default='log', 72 | help='log file name') 73 | parser.add_argument('--checkpoint_interval', type=int, default=50, 74 | help='how many epochs to wait before logging training status') 75 | parser.add_argument('--resume', type=str, default=None, 76 | help='checkpoint path for resume training') 77 | parser.add_argument('--out', type=str, default='result', 78 | help='dir to save output files') 79 | parser.add_argument('--config', default='configs/cifar10/lenet_kfac.json', 80 | help='config file path') 81 | 82 | args = parser.parse_args() 83 | dict_args = vars(args) 84 | 85 | # Load config file 86 | if args.config is not None: 87 | with open(args.config) as f: 88 | config = json.load(f) 89 | dict_args.update(config) 90 | 91 | # Set device 92 | use_cuda = not args.no_cuda and torch.cuda.is_available() 93 | device = torch.device('cuda' if use_cuda else 'cpu') 94 | 95 | # Set random seed 96 | torch.manual_seed(args.seed) 97 | 98 | # Setup data augmentation & data pre processing 99 | train_transforms, val_transforms = [], [] 100 | if args.random_crop: 101 | train_transforms.append(transforms.RandomCrop(32, padding=4)) 102 | 103 | if args.random_horizontal_flip: 104 | train_transforms.append(transforms.RandomHorizontalFlip()) 105 | 106 | train_transforms.append(transforms.ToTensor()) 107 | val_transforms.append(transforms.ToTensor()) 108 | 109 | if args.normalizing_data: 110 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 111 | train_transforms.append(normalize) 112 | val_transforms.append(normalize) 113 | 114 | train_transform = transforms.Compose(train_transforms) 115 | val_transform = transforms.Compose(val_transforms) 116 | 117 | # Setup data loader 118 | if args.dataset == DATASET_CIFAR10: 119 | # CIFAR-10 120 | num_classes = 10 121 | dataset_class = datasets.CIFAR10 122 | elif args.dataset == DATASET_CIFAR100: 123 | # CIFAR-100 124 | num_classes = 100 125 | dataset_class = datasets.CIFAR100 126 | elif args.dataset == DATASET_MNIST: 127 | num_classes = 10 128 | dataset_class = datasets.MNIST 129 | else: 130 | assert False, f'unknown dataset {args.dataset}' 131 | 132 | train_dataset = dataset_class( 133 | root=args.root, train=True, download=args.download, transform=train_transform) 134 | val_dataset = dataset_class( 135 | root=args.root, train=False, download=args.download, transform=val_transform) 136 | 137 | train_loader = torch.utils.data.DataLoader( 138 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 139 | val_loader = torch.utils.data.DataLoader( 140 | val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers) 141 | 142 | # Setup model 143 | if args.arch_file is None: 144 | arch_class = getattr(models, args.arch_name) 145 | else: 146 | _, ext = os.path.splitext(args.arch_file) 147 | dirname = os.path.dirname(args.arch_file) 148 | 149 | if dirname == '': 150 | module_path = args.arch_file.replace(ext, '') 151 | elif dirname == '.': 152 | module_path = os.path.basename(args.arch_file).replace(ext, '') 153 | else: 154 | module_path = '.'.join(os.path.split(args.arch_file)).replace(ext, '') 155 | 156 | module = import_module(module_path) 157 | arch_class = getattr(module, args.arch_name) 158 | 159 | arch_kwargs = {} if args.arch_args is None else args.arch_args 160 | arch_kwargs['num_classes'] = num_classes 161 | 162 | model = arch_class(**arch_kwargs) 163 | setattr(model, 'num_classes', num_classes) 164 | model = model.to(device) 165 | 166 | optim_kwargs = {} if args.optim_args is None else args.optim_args 167 | 168 | # Setup optimizer 169 | if args.optim_name == SecondOrderOptimizer.__name__: 170 | optimizer = SecondOrderOptimizer(model, **optim_kwargs, curv_kwargs=args.curv_args) 171 | elif args.optim_name == VIOptimizer.__name__: 172 | optimizer = VIOptimizer(model, dataset_size=len(train_loader.dataset), seed=args.seed, 173 | **optim_kwargs, curv_kwargs=args.curv_args) 174 | else: 175 | optim_class = getattr(torch.optim, args.optim_name) 176 | optimizer = optim_class(model.parameters(), **optim_kwargs) 177 | 178 | # Setup lr scheduler 179 | if args.scheduler_name is None: 180 | scheduler = None 181 | else: 182 | scheduler_class = getattr(torchsso.optim.lr_scheduler, args.scheduler_name, None) 183 | if scheduler_class is None: 184 | scheduler_class = getattr(torch.optim.lr_scheduler, args.scheduler_name) 185 | scheduler_kwargs = {} if args.scheduler_args is None else args.scheduler_args 186 | scheduler = scheduler_class(optimizer, **scheduler_kwargs) 187 | 188 | start_epoch = 1 189 | 190 | # Load checkpoint 191 | if args.resume is not None: 192 | print('==> Resuming from checkpoint..') 193 | assert os.path.exists(args.resume), 'Error: no checkpoint file found' 194 | checkpoint = torch.load(args.resume) 195 | model.load_state_dict(checkpoint['model']) 196 | start_epoch = checkpoint['epoch'] 197 | 198 | # All config 199 | print('===========================') 200 | for key, val in vars(args).items(): 201 | if key == 'dataset': 202 | print('{}: {}'.format(key, val)) 203 | print('train data size: {}'.format(len(train_loader.dataset))) 204 | print('val data size: {}'.format(len(val_loader.dataset))) 205 | else: 206 | print('{}: {}'.format(key, val)) 207 | print('===========================') 208 | 209 | # Copy this file & config to args.out 210 | if not os.path.isdir(args.out): 211 | os.makedirs(args.out) 212 | shutil.copy(os.path.realpath(__file__), args.out) 213 | 214 | if args.config is not None: 215 | shutil.copy(args.config, args.out) 216 | if args.arch_file is not None: 217 | shutil.copy(args.arch_file, args.out) 218 | 219 | # Setup logger 220 | logger = Logger(args.out, args.log_file_name) 221 | logger.start() 222 | 223 | # Run training 224 | for epoch in range(start_epoch, args.epochs + 1): 225 | 226 | # train 227 | accuracy, loss, confidence = train(model, device, train_loader, optimizer, scheduler, epoch, args, logger) 228 | 229 | # val 230 | val_accuracy, val_loss = validate(model, device, val_loader, optimizer) 231 | 232 | # save log 233 | iteration = epoch * len(train_loader) 234 | log = {'epoch': epoch, 'iteration': iteration, 235 | 'accuracy': accuracy, 'loss': loss, 'confidence': confidence, 236 | 'val_accuracy': val_accuracy, 'val_loss': val_loss, 237 | 'lr': optimizer.param_groups[0]['lr'], 238 | 'momentum': optimizer.param_groups[0].get('momentum', 0)} 239 | logger.write(log) 240 | 241 | # save checkpoint 242 | if epoch % args.checkpoint_interval == 0 or epoch == args.epochs: 243 | path = os.path.join(args.out, 'epoch{}.ckpt'.format(epoch)) 244 | data = { 245 | 'model': model.state_dict(), 246 | 'optimizer': optimizer.state_dict(), 247 | 'epoch': epoch 248 | } 249 | torch.save(data, path) 250 | 251 | 252 | def train(model, device, train_loader, optimizer, scheduler, epoch, args, logger): 253 | 254 | def scheduler_type(_scheduler): 255 | if _scheduler is None: 256 | return 'none' 257 | return getattr(_scheduler, 'scheduler_type', 'epoch') 258 | 259 | model.train() 260 | 261 | total_correct = 0 262 | loss = None 263 | confidence = {'top1': 0, 'top1_true': 0, 'top1_false': 0, 'true': 0, 'false': 0} 264 | total_data_size = 0 265 | epoch_size = len(train_loader.dataset) 266 | num_iters_in_epoch = len(train_loader) 267 | base_num_iter = (epoch - 1) * num_iters_in_epoch 268 | 269 | for batch_idx, (data, target) in enumerate(train_loader): 270 | data, target = data.to(device), target.to(device) 271 | 272 | for name, param in model.named_parameters(): 273 | attr = 'p_pre_{}'.format(name) 274 | setattr(model, attr, param.detach().clone()) 275 | 276 | # update params 277 | def closure(): 278 | optimizer.zero_grad() 279 | output = model(data) 280 | loss = F.cross_entropy(output, target) 281 | loss.backward(create_graph=args.create_graph) 282 | 283 | return loss, output 284 | 285 | if isinstance(optimizer, SecondOrderOptimizer) and optimizer.curv_type == 'Fisher': 286 | closure = torchsso.get_closure_for_fisher(optimizer, model, data, target, **args.fisher_args) 287 | 288 | loss, output = optimizer.step(closure=closure) 289 | 290 | pred = output.argmax(dim=1, keepdim=True) 291 | correct = pred.eq(target.view_as(pred)).sum().item() 292 | 293 | loss = loss.item() 294 | total_correct += correct 295 | 296 | prob = F.softmax(output, dim=1) 297 | for p, idx in zip(prob, target): 298 | confidence['top1'] += torch.max(p).item() 299 | top1 = torch.argmax(p).item() 300 | if top1 == idx: 301 | confidence['top1_true'] += p[top1].item() 302 | else: 303 | confidence['top1_false'] += p[top1].item() 304 | confidence['true'] += p[idx].item() 305 | confidence['false'] += (1 - p[idx].item()) 306 | 307 | iteration = base_num_iter + batch_idx + 1 308 | total_data_size += len(data) 309 | 310 | if scheduler_type(scheduler) == 'iter': 311 | scheduler.step() 312 | 313 | if batch_idx % args.log_interval == 0: 314 | accuracy = 100. * total_correct / total_data_size 315 | elapsed_time = logger.elapsed_time 316 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, ' 317 | 'Accuracy: {:.0f}/{} ({:.2f}%), ' 318 | 'Elapsed Time: {:.1f}s'.format( 319 | epoch, total_data_size, epoch_size, 100. * (batch_idx + 1) / num_iters_in_epoch, 320 | loss, total_correct, total_data_size, accuracy, elapsed_time)) 321 | 322 | # save log 323 | lr = optimizer.param_groups[0]['lr'] 324 | log = {'epoch': epoch, 'iteration': iteration, 'elapsed_time': elapsed_time, 325 | 'accuracy': accuracy, 'loss': loss, 'lr': lr} 326 | 327 | for name, param in model.named_parameters(): 328 | attr = 'p_pre_{}'.format(name) 329 | p_pre = getattr(model, attr) 330 | p_norm = param.norm().item() 331 | p_shape = list(param.size()) 332 | p_pre_norm = p_pre.norm().item() 333 | g_norm = param.grad.norm().item() 334 | upd_norm = param.sub(p_pre).norm().item() 335 | noise_scale = getattr(param, 'noise_scale', 0) 336 | 337 | p_log = {'p_shape': p_shape, 'p_norm': p_norm, 'p_pre_norm': p_pre_norm, 338 | 'g_norm': g_norm, 'upd_norm': upd_norm, 'noise_scale': noise_scale} 339 | log[name] = p_log 340 | 341 | logger.write(log) 342 | 343 | if scheduler_type(scheduler) == 'epoch': 344 | scheduler.step(epoch - 1) 345 | 346 | accuracy = 100. * total_correct / epoch_size 347 | confidence['top1'] /= epoch_size 348 | confidence['top1_true'] /= total_correct 349 | confidence['top1_false'] /= (epoch_size - total_correct) 350 | confidence['true'] /= epoch_size 351 | confidence['false'] /= (epoch_size * (model.num_classes - 1)) 352 | 353 | return accuracy, loss, confidence 354 | 355 | 356 | def validate(model, device, val_loader, optimizer): 357 | model.eval() 358 | val_loss = 0 359 | correct = 0 360 | 361 | with torch.no_grad(): 362 | for data, target in val_loader: 363 | 364 | data, target = data.to(device), target.to(device) 365 | 366 | if isinstance(optimizer, VIOptimizer): 367 | output = optimizer.prediction(data) 368 | else: 369 | output = model(data) 370 | 371 | val_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss 372 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 373 | correct += pred.eq(target.view_as(pred)).sum().item() 374 | 375 | val_loss /= len(val_loader.dataset) 376 | val_accuracy = 100. * correct / len(val_loader.dataset) 377 | 378 | print('\nEval: Average loss: {:.4f}, Accuracy: {:.0f}/{} ({:.2f}%)\n'.format( 379 | val_loss, correct, len(val_loader.dataset), val_accuracy)) 380 | 381 | return val_accuracy, val_loss 382 | 383 | 384 | if __name__ == '__main__': 385 | main() 386 | -------------------------------------------------------------------------------- /examples/classification/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .lenet import * 3 | from .resnet import * 4 | from .alexnet import * 5 | from .mlp import * 6 | -------------------------------------------------------------------------------- /examples/classification/models/alexnet.py: -------------------------------------------------------------------------------- 1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. 2 | Without BN, the start learning rate should be 0.01 3 | (c) YANG, Wei 4 | ''' 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchsso.utils.accumulator import TensorAccumulator 8 | 9 | 10 | __all__ = ['alexnet', 'alexnet_mcdropout'] 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=10): 16 | super().__init__() 17 | self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5) 18 | self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2) 19 | self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) 20 | self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) 21 | self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 22 | self.fc = nn.Linear(256, num_classes) 23 | 24 | def forward(self, x): 25 | x = F.relu(self.conv1(x)) 26 | x = F.max_pool2d(x, kernel_size=2, stride=2) 27 | x = F.relu(self.conv2(x)) 28 | x = F.max_pool2d(x, kernel_size=2, stride=2) 29 | x = F.relu(self.conv3(x)) 30 | x = F.relu(self.conv4(x)) 31 | x = F.relu(self.conv5(x)) 32 | x = F.max_pool2d(x, kernel_size=2, stride=2) 33 | x = x.view(x.size(0), -1) 34 | x = self.fc(x) 35 | return x 36 | 37 | 38 | class AlexNetMCDropout(AlexNet): 39 | 40 | mc_dropout = True 41 | 42 | def __init__(self, num_classes=10, dropout_ratio=0.5, val_mc=10): 43 | super(AlexNetMCDropout, self).__init__(num_classes) 44 | self.dropout_ratio = dropout_ratio 45 | self.val_mc = val_mc 46 | 47 | def forward(self, x): 48 | dropout_ratio = self.dropout_ratio 49 | x = F.relu(F.dropout(self.conv1(x), p=dropout_ratio)) 50 | x = F.max_pool2d(x, kernel_size=2, stride=2) 51 | x = F.relu(F.dropout(self.conv2(x), p=dropout_ratio)) 52 | x = F.max_pool2d(x, kernel_size=2, stride=2) 53 | x = F.relu(F.dropout(self.conv3(x), p=dropout_ratio)) 54 | x = F.relu(F.dropout(self.conv4(x), p=dropout_ratio)) 55 | x = F.relu(F.dropout(self.conv5(x), p=dropout_ratio)) 56 | x = F.max_pool2d(x, kernel_size=2, stride=2) 57 | x = x.view(x.size(0), -1) 58 | x = self.fc(x) 59 | return x 60 | 61 | def prediction(self, x): 62 | 63 | acc_prob = TensorAccumulator() 64 | m = self.val_mc 65 | 66 | for _ in range(m): 67 | output = self.forward(x) 68 | prob = F.softmax(output, dim=1) 69 | acc_prob.update(prob, scale=1/m) 70 | 71 | prob = acc_prob.get() 72 | 73 | return prob 74 | 75 | 76 | def alexnet(**kwargs): 77 | r"""AlexNet model architecture from the 78 | `"One weird trick..." `_ paper. 79 | """ 80 | model = AlexNet(**kwargs) 81 | return model 82 | 83 | 84 | def alexnet_mcdropout(**kwargs): 85 | model = AlexNetMCDropout(**kwargs) 86 | return model 87 | 88 | -------------------------------------------------------------------------------- /examples/classification/models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchsso.utils.accumulator import TensorAccumulator 4 | 5 | 6 | class LeNet5(nn.Module): 7 | 8 | def __init__(self, num_classes=10): 9 | super().__init__() 10 | self.conv1 = nn.Conv2d(3, 6, 5) 11 | self.conv2 = nn.Conv2d(6, 16, 5) 12 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 13 | self.fc2 = nn.Linear(120, 84) 14 | self.fc3 = nn.Linear(84, num_classes) 15 | 16 | def forward(self, x): 17 | out = F.relu(self.conv1(x)) 18 | out = F.max_pool2d(out, 2) 19 | out = F.relu(self.conv2(out)) 20 | out = F.max_pool2d(out, 2) 21 | out = out.view(out.size(0), -1) 22 | out = F.relu(self.fc1(out)) 23 | out = F.relu(self.fc2(out)) 24 | out = self.fc3(out) 25 | return out 26 | 27 | 28 | class LeNet5MCDropout(LeNet5): 29 | 30 | def __init__(self, num_classes=10, dropout_ratio=0.1, val_mc=10): 31 | super(LeNet5MCDropout, self).__init__(num_classes=num_classes) 32 | self.dropout_ratio = dropout_ratio 33 | self.val_mc = val_mc 34 | 35 | def forward(self, x): 36 | p = self.dropout_ratio 37 | out = F.relu(F.dropout(self.conv1(x), p)) 38 | out = F.max_pool2d(out, 2) 39 | out = F.relu(F.dropout(self.conv2(out), p)) 40 | out = F.max_pool2d(out, 2) 41 | out = out.view(out.size(0), -1) 42 | out = F.relu(F.dropout(self.fc1(out), p)) 43 | out = F.relu(F.dropout(self.fc2(out), p)) 44 | out = F.dropout(self.fc2(out), p) 45 | return out 46 | 47 | def mc_prediction(self, x): 48 | 49 | acc_prob = TensorAccumulator() 50 | m = self.val_mc 51 | 52 | for _ in range(m): 53 | output = self.forward(x) 54 | prob = F.softmax(output, dim=1) 55 | acc_prob.update(prob, scale=1/m) 56 | 57 | prob = acc_prob.get() 58 | 59 | return prob 60 | 61 | 62 | class LeNet5BatchNorm(nn.Module): 63 | def __init__(self, num_classes=10, affine=False): 64 | super().__init__() 65 | self.conv1 = nn.Conv2d(3, 6, 5) 66 | self.bn1 = nn.BatchNorm2d(6, affine=affine) 67 | self.conv2 = nn.Conv2d(6, 16, 5) 68 | self.bn2 = nn.BatchNorm2d(16, affine=affine) 69 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 70 | self.bn3 = nn.BatchNorm1d(120, affine=affine) 71 | self.fc2 = nn.Linear(120, 84) 72 | self.bn4 = nn.BatchNorm1d(84, affine=affine) 73 | self.fc3 = nn.Linear(84, num_classes) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = F.max_pool2d(out, 2) 78 | out = F.relu(self.bn2(self.conv2(out))) 79 | out = F.max_pool2d(out, 2) 80 | out = out.view(out.size(0), -1) 81 | out = F.relu(self.bn3(self.fc1(out))) 82 | out = F.relu(self.bn4(self.fc2(out))) 83 | out = self.fc3(out) 84 | return out 85 | -------------------------------------------------------------------------------- /examples/classification/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | __all__ = ['mlp'] 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, num_classes=10): 11 | super().__init__() 12 | n_hid = 1000 13 | n_out = 10 14 | self.l1 = nn.Linear(28*28, n_hid) 15 | self.l2 = nn.Linear(n_hid, n_hid) 16 | self.l3 = nn.Linear(n_hid, n_out) 17 | 18 | def forward(self, x: torch.Tensor): 19 | x = x.view([-1, 28*28]) 20 | x = F.relu(self.l1(x)) 21 | x = F.relu(self.l2(x)) 22 | x = self.l3(x) 23 | return x 24 | 25 | 26 | def mlp(**kwargs): 27 | model = MLP(**kwargs) 28 | return model 29 | 30 | -------------------------------------------------------------------------------- /examples/classification/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(ResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | 100 | def ResNet18(num_classes=10): 101 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 102 | 103 | def ResNet34(num_classes=10): 104 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 105 | 106 | def ResNet50(num_classes=10): 107 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 108 | 109 | def ResNet101(num_classes=10): 110 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 111 | 112 | def ResNet152(num_classes=10): 113 | return ResNet(Bottleneck, [3,8,36,3], num_classes) 114 | 115 | 116 | def test(): 117 | net = ResNet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /examples/classification/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name='VGG19'): 16 | super(VGG, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.classifier = nn.Linear(512, 10) 19 | 20 | def forward(self, x): 21 | out = self.features(x) 22 | out = out.view(out.size(0), -1) 23 | out = self.classifier(out) 24 | return out 25 | 26 | def _make_layers(self, cfg): 27 | layers = [] 28 | in_channels = 3 29 | for x in cfg: 30 | if x == 'M': 31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 32 | else: 33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(x), 35 | nn.ReLU(inplace=True)] 36 | in_channels = x 37 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 38 | return nn.Sequential(*layers) 39 | 40 | 41 | def test(): 42 | net = VGG('VGG11') 43 | x = torch.randn(2,3,32,32) 44 | y = net(x) 45 | print(y.size()) 46 | 47 | # test() 48 | -------------------------------------------------------------------------------- /examples/distributed/README.md: -------------------------------------------------------------------------------- 1 | # Distributed training 2 | PyTorch-SSO supports data parallelism and MC samples parallelism (for VI) for distributed training among multiple processes (GPUs). 3 | 4 | ![](../../docs/distributed_vi.png) 5 | 6 | ## Applications 7 | - [Image classification](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification) 8 | -------------------------------------------------------------------------------- /examples/distributed/classification/README.md: -------------------------------------------------------------------------------- 1 | To run training on CIFAR-10/100 with multiple GPUs 2 | 3 | ```bash 4 | mpirun -np python main.py \ 5 | --dist_init_method \ 6 | --download \ 7 | --config 8 | ``` 9 | 10 | To run training on ImageNet with multiple GPUs 11 | 12 | ```bash 13 | mpirun -np python main.py \ 14 | --train_root \ 15 | --val_root \ 16 | --dist_init_method \ 17 | --config 18 | ``` 19 | For `init_method`, refer the [PyTorch tutorial for distirubted applications](https://pytorch.org/tutorials/intermediate/dist_tuto.html). 20 | 21 | | optimizer | dataset | architecture | GPUs | config file path | 22 | | --- | --- | --- | --- | --- | 23 | | [Adam](https://arxiv.org/abs/1412.6980) | ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_adam_bs4k_128gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_adam_bs4k_128gpu.json) | 24 | | [K-FAC](https://arxiv.org/abs/1503.05671) | ImageNet | ResNet-18 | 4 | [configs/imagenet/resnet18_kfac_bs4k_4gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_4gpu.json) | 25 | | [K-FAC](https://arxiv.org/abs/1503.05671)| ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_kfac_bs4k_128gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_128gpu.json) | 26 | | [Noisy K-FAC](https://arxiv.org/abs/1712.02390)| ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json) | 27 | | [VOGN](https://arxiv.org/abs/1806.04854)| ImageNet | ResNet-18 | 128 | [configs/imagenet/resnet18_vogn_bs4k_128gpu.json](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_128gpu.json) | 28 | 29 | - NOTE: 30 | - You need to run with `N` GPUs when you use `*{N}gpu.json` config file. 31 | - You need to set `--acc_steps` (or `"acc_steps"` in json config) to run with limited number of GPUs as below: 32 | - Mini-batch size (bs) = {examples per GPU} x {# GPUs} x {acc_steps} 33 | - Ex) 4096 (4k) = 32 x 8 x 16 34 | - The gradients of loss and the curvature are accumulated for `acc_steps` to build pseudo mini-batch size. 35 | 36 | Visit [configs](https://github.com/cybertronai/pytorch-sso/blob/master/examples/distributed/classification/configs) for other architecture, dataset, optimizer, number of GPUs. 37 | -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/alexnet_adam_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 0.001, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 1e-4 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/alexnet_kfac_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 61, 4 | "batch_size": 32, 5 | "val_batch_size": 1250, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "DistributedSecondOrderOptimizer", 12 | "optim_args": { 13 | "curv_type": "Fisher", 14 | "curv_shapes": { 15 | "Conv2d": "Kron", 16 | "Linear": "Kron", 17 | "BatchNorm1d": "Diag", 18 | "BatchNorm2d": "Diag" 19 | }, 20 | "lr": 1e-2, 21 | "l2_reg": 1e-3, 22 | "momentum": 0.9, 23 | "momentum_type": "raw" 24 | }, 25 | "curv_args": { 26 | "damping": 1e-3, 27 | "ema_decay": 0.999, 28 | "pi_type": "tracenorm" 29 | }, 30 | "fisher_args": { 31 | "approx_type": "mc", 32 | "num_mc": 1 33 | }, 34 | "scheduler_name": "MultiStepLR", 35 | "scheduler_args": { 36 | "milestones": [30, 50], 37 | "gamma": 0.1 38 | } 39 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/alexnet_noisykfac_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 81, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "DistributedVIOptimizer", 12 | "optim_args": { 13 | "curv_type": "Fisher", 14 | "curv_shapes": { 15 | "Conv2d": "Kron", 16 | "Linear": "Kron", 17 | "BatchNorm1d": "Diag", 18 | "BatchNorm2d": "Diag" 19 | }, 20 | "lr": 1e-3, 21 | "momentum": 0.9, 22 | "momentum_type": "raw", 23 | "num_mc_samples": 3, 24 | "val_num_mc_samples": 10, 25 | "kl_weighting": 1, 26 | "warmup_kl_weighting_init": 0.01, 27 | "warmup_kl_weighting_steps": 15821, 28 | "prior_variance": 2 29 | }, 30 | "curv_args": { 31 | "damping": 1e-3, 32 | "ema_decay": 0.999, 33 | "pi_type": "tracenorm" 34 | }, 35 | "fisher_args": { 36 | "approx_type": "mc", 37 | "num_mc": 1 38 | }, 39 | "scheduler_name": "MultiStepLR", 40 | "scheduler_args": { 41 | "milestones": [30, 50], 42 | "gamma": 0.1 43 | }, 44 | "num_mc_groups": 8 45 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/alexnet_sgd_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "SGD", 12 | "optim_args": { 13 | "lr": 0.1, 14 | "momentum": 0.9, 15 | "weight_decay": 1e-4 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/alexnet_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/alexnet.py", 11 | "arch_name": "AlexNet", 12 | "optim_name": "DistributedVIOptimizer", 13 | "optim_args": { 14 | "curv_type": "Cov", 15 | "curv_shapes": { 16 | "Conv2d": "Diag", 17 | "Linear": "Kron", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "lr": 1e-4, 22 | "momentum": 0.9, 23 | "momentum_type": "raw", 24 | "num_mc_samples": 3, 25 | "val_num_mc_samples": 10, 26 | "kl_weighting": 1, 27 | "warmup_kl_weighting_init": 0.5, 28 | "warmup_kl_weighting_steps": 1954, 29 | "prior_variance": 2 30 | }, 31 | "curv_args": { 32 | "damping": 1e-3, 33 | "ema_decay": 0.999 34 | }, 35 | "scheduler_name": "MultiStepLR", 36 | "scheduler_args": { 37 | "milestones": [80, 120], 38 | "gamma": 0.1 39 | }, 40 | "num_mc_groups": 8 41 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/lenet_adam_bs128_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 150, 4 | "batch_size": 32, 5 | "val_batch_size": 1250, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 1e-2 16 | } 17 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/lenet_kfac_bs128_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 50, 4 | "batch_size": 32, 5 | "val_batch_size": 1250, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "DistributedSecondOrderOptimizer", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "curv_type": "Fisher", 15 | "curv_shapes": { 16 | "Conv2d": "Kron", 17 | "Linear": "Kron", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "momentum": 0.9, 22 | "momentum_type": "raw", 23 | "l2_reg": 1e-3, 24 | "acc_steps": 1 25 | }, 26 | "curv_args": { 27 | "damping": 1e-3, 28 | "ema_decay": 0.999, 29 | "pi_type": "tracenorm" 30 | }, 31 | "fisher_args": { 32 | "approx_type": "mc", 33 | "num_mc": 1 34 | }, 35 | "scheduler_name": "ExponentialLR", 36 | "scheduler_args": { 37 | "gamma": 0.9 38 | }, 39 | "log_interval": 64, 40 | "no_cuda": false 41 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/lenet_sgd_bs128_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 150, 4 | "batch_size": 32, 5 | "val_batch_size": 1250, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/lenet.py", 10 | "arch_name": "LeNet5", 11 | "optim_name": "SGD", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "momentum": 0.9, 15 | "weight_decay": 1e-2 16 | } 17 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/lenet_vogn_bs128_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 211, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": false, 7 | "random_horizontal_flip": false, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 1, 10 | "arch_file": "models/lenet.py", 11 | "arch_name": "LeNet5", 12 | "optim_name": "DistributedVIOptimizer", 13 | "optim_args": { 14 | "curv_type": "Cov", 15 | "curv_shapes": { 16 | "Conv2d": "Diag", 17 | "Linear": "Diag", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "lr": 1e-4, 22 | "momentum": 0.9, 23 | "momentum_type": "raw", 24 | "num_mc_samples": 6, 25 | "val_num_mc_samples": 10, 26 | "kl_weighting": 1, 27 | "warmup_kl_weighting_init": 0.1, 28 | "warmup_kl_weighting_steps": 11719, 29 | "prior_variance": 1e-2 30 | }, 31 | "curv_args": { 32 | "damping": 1e-3, 33 | "ema_decay": 0.999 34 | }, 35 | "num_mc_groups": 4 36 | } 37 | -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/resnet18_adam_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/resnet_b.py", 10 | "arch_name": "resnet18", 11 | "arch_args": { 12 | "zero_init_residual": true, 13 | "norm_stat_momentum": 0.1 14 | }, 15 | "optim_name": "Adam", 16 | "optim_args": { 17 | "lr": 1e-3, 18 | "betas": [0.9, 0.999], 19 | "weight_decay": 5e-4 20 | }, 21 | "non_wd_for_bn": true, 22 | "scheduler_name": "MultiStepLR", 23 | "scheduler_args": { 24 | "milestones": [80, 120], 25 | "gamma": 0.1 26 | } 27 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/resnet18_sgd_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/resnet_b.py", 10 | "arch_name": "resnet18", 11 | "arch_args": { 12 | "zero_init_residual": true, 13 | "norm_stat_momentum": 0.1 14 | }, 15 | "optim_name": "SGD", 16 | "optim_args": { 17 | "lr": 0.1, 18 | "momentum": 0.9, 19 | "weight_decay": 1e-4 20 | }, 21 | "momentum_correction": true, 22 | "non_wd_for_bn": true, 23 | "scheduler_name": "MultiStepLR", 24 | "scheduler_args": { 25 | "milestones": [80, 120], 26 | "gamma": 0.1 27 | } 28 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/resnet18_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedVIOptimizer", 17 | "optim_args": { 18 | "curv_type": "Cov", 19 | "curv_shapes": { 20 | "Conv2d": "Diag", 21 | "Linear": "Diag", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1e-4, 26 | "momentum": 0.9, 27 | "momentum_type": "raw", 28 | "num_mc_samples": 5, 29 | "val_num_mc_samples": 20, 30 | "kl_weighting": 1, 31 | "prior_variance": 0.02, 32 | "non_reg_for_bn": true 33 | }, 34 | "curv_args": { 35 | "damping": 1e-3, 36 | "ema_decay": 0.999 37 | }, 38 | "momentum_correction": true, 39 | "scheduler_name": "MultiStepLR", 40 | "scheduler_args": { 41 | "milestones": [80, 120], 42 | "gamma": 0.1 43 | }, 44 | "num_mc_groups": 8 45 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/vgg19_adam_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/vgg.py", 10 | "arch_name": "VGG19", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 1e-3, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 1e-4 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar10/vgg19_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-10", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/vgg.py", 11 | "arch_name": "VGG19", 12 | "optim_name": "DistributedVIOptimizer", 13 | "optim_args": { 14 | "curv_type": "Cov", 15 | "curv_shapes": { 16 | "Conv2d": "Diag", 17 | "Linear": "Diag", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "lr": 1e-4, 22 | "momentum": 0.9, 23 | "momentum_type": "raw", 24 | "num_mc_samples": 5, 25 | "val_num_mc_samples": 10, 26 | "kl_weighting": 1, 27 | "prior_variance": 2 28 | }, 29 | "curv_args": { 30 | "damping": 1e-3, 31 | "ema_decay": 0.999 32 | }, 33 | "scheduler_name": "MultiStepLR", 34 | "scheduler_args": { 35 | "milestones": [80, 120], 36 | "gamma": 0.1 37 | }, 38 | "num_mc_groups": 8 39 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar100/alexnet_adam_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "Adam", 12 | "optim_args": { 13 | "lr": 0.001, 14 | "betas": [0.9, 0.999], 15 | "weight_decay": 1e-2 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar100/alexnet_sgd_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/alexnet.py", 10 | "arch_name": "AlexNet", 11 | "optim_name": "SGD", 12 | "optim_args": { 13 | "lr": 0.1, 14 | "momentum": 0.9, 15 | "weight_decay": 1e-4 16 | }, 17 | "scheduler_name": "MultiStepLR", 18 | "scheduler_args": { 19 | "milestones": [80, 120], 20 | "gamma": 0.1 21 | } 22 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar100/alexnet_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/alexnet.py", 11 | "arch_name": "AlexNet", 12 | "optim_name": "DistributedVIOptimizer", 13 | "optim_args": { 14 | "curv_type": "Cov", 15 | "curv_shapes": { 16 | "Conv2d": "Diag", 17 | "Linear": "Diag", 18 | "BatchNorm1d": "Diag", 19 | "BatchNorm2d": "Diag" 20 | }, 21 | "lr": 1e-4, 22 | "momentum": 0.9, 23 | "momentum_type": "raw", 24 | "num_mc_samples": 10, 25 | "val_num_mc_samples": 100, 26 | "kl_weighting": 1, 27 | "warmup_kl_weighting_init": 0.5, 28 | "warmup_kl_weighting_steps": 1954, 29 | "prior_variance": 0.02 30 | }, 31 | "curv_args": { 32 | "damping": 1e-3, 33 | "ema_decay": 0.999 34 | }, 35 | "scheduler_name": "MultiStepLR", 36 | "scheduler_args": { 37 | "milestones": [80, 120], 38 | "gamma": 0.1 39 | }, 40 | "num_mc_groups": 8 41 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar100/resnet18_adam_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "arch_file": "models/resnet_b.py", 10 | "arch_name": "resnet18", 11 | "arch_args": { 12 | "zero_init_residual": true, 13 | "norm_stat_momentum": 0.1 14 | }, 15 | "optim_name": "Adam", 16 | "optim_args": { 17 | "lr": 1e-3, 18 | "betas": [0.9, 0.999], 19 | "weight_decay": 1e-2 20 | }, 21 | "non_wd_for_bn": true, 22 | "scheduler_name": "MultiStepLR", 23 | "scheduler_args": { 24 | "milestones": [80, 120], 25 | "gamma": 0.1 26 | } 27 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/cifar100/resnet18_vogn_bs256_8gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CIFAR-100", 3 | "epochs": 161, 4 | "batch_size": 32, 5 | "val_batch_size": 128, 6 | "random_crop": true, 7 | "random_horizontal_flip": true, 8 | "normalizing_data": true, 9 | "dataset_size_scale": 10, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedVIOptimizer", 17 | "optim_args": { 18 | "curv_type": "Cov", 19 | "curv_shapes": { 20 | "Conv2d": "Diag", 21 | "Linear": "Diag", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1e-4, 26 | "momentum": 0.9, 27 | "momentum_type": "raw", 28 | "num_mc_samples": 5, 29 | "val_num_mc_samples": 20, 30 | "kl_weighting": 1, 31 | "warmup_kl_weighting_init": 0.5, 32 | "warmup_kl_weighting_steps": 1954, 33 | "prior_variance": 0.02, 34 | "non_reg_for_bn": true 35 | }, 36 | "curv_args": { 37 | "damping": 1e-3, 38 | "ema_decay": 0.999 39 | }, 40 | "momentum_correction": true, 41 | "scheduler_name": "MultiStepLR", 42 | "scheduler_args": { 43 | "milestones": [80, 120], 44 | "gamma": 0.1 45 | }, 46 | "num_mc_groups": 8 47 | } -------------------------------------------------------------------------------- /examples/distributed/classification/configs/imagenet/resnet18_adam_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 91, 4 | "batch_size": 32, 5 | "val_batch_size": 391, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": true, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "Adam", 17 | "optim_args": { 18 | "lr": 1.6e-3, 19 | "weight_decay": 1e-4, 20 | "betas": [0.9, 0.999] 21 | }, 22 | "non_wd_for_bn": true, 23 | "scheduler_name": "MultiStepLR", 24 | "scheduler_args": { 25 | "milestones": [30, 60, 80], 26 | "gamma": 0.1 27 | }, 28 | "warmup_epochs": 5, 29 | "warmup_scheduler_name": "GradualWarmupIterLR", 30 | "warmup_scheduler_args": { 31 | "initial_lr": 1.25e-5, 32 | "max_count": 1565 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 61, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedSecondOrderOptimizer", 17 | "optim_args": { 18 | "curv_type": "Fisher", 19 | "curv_shapes": { 20 | "Conv2d": "Kron", 21 | "Linear": "Kron", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1.6e-3, 26 | "l2_reg": 1e-4, 27 | "momentum": 0.9, 28 | "momentum_type": "raw", 29 | "non_reg_for_bn": true, 30 | "acc_steps": 1 31 | }, 32 | "curv_args": { 33 | "damping": 1e-4, 34 | "ema_decay": 1, 35 | "pi_type": "tracenorm" 36 | }, 37 | "fisher_args": { 38 | "approx_type": "mc", 39 | "num_mc": 1 40 | }, 41 | "momentum_correction": true, 42 | "scheduler_name": "MultiStepLR", 43 | "scheduler_args": { 44 | "milestones": [15, 30, 45], 45 | "gamma": 0.1 46 | }, 47 | "warmup_epochs": 5, 48 | "warmup_scheduler_name": "GradualWarmupIterLR", 49 | "warmup_scheduler_args": { 50 | "initial_lr": 1.25e-5, 51 | "max_count": 1565 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /examples/distributed/classification/configs/imagenet/resnet18_kfac_bs4k_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet10", 3 | "epochs": 30, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedSecondOrderOptimizer", 17 | "optim_args": { 18 | "curv_type": "Fisher", 19 | "curv_shapes": { 20 | "Conv2d": "Kron", 21 | "Linear": "Kron", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1.6e-3, 26 | "momentum": 0.9, 27 | "momentum_type": "raw", 28 | "l2_reg": 1e-4, 29 | "non_reg_for_bn": true, 30 | "acc_steps": 32 31 | }, 32 | "curv_args": { 33 | "damping": 1e-4, 34 | "ema_decay": 1, 35 | "pi_type": "tracenorm" 36 | }, 37 | "fisher_args": { 38 | "approx_type": "mc", 39 | "num_mc": 1 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 61, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "dataset_size_scale": 5, 10 | "normalizing_data": true, 11 | "arch_file": "models/resnet_b.py", 12 | "arch_name": "resnet18", 13 | "arch_args": { 14 | "zero_init_residual": false, 15 | "norm_stat_momentum": 0.1 16 | }, 17 | "optim_name": "DistributedVIOptimizer", 18 | "optim_args": { 19 | "curv_type": "Fisher", 20 | "curv_shapes": { 21 | "Conv2d": "Kron", 22 | "Linear": "Kron", 23 | "BatchNorm1d": "Diag", 24 | "BatchNorm2d": "Diag" 25 | }, 26 | "lr": 1.6e-3, 27 | "momentum": 0.9, 28 | "momentum_type": "raw", 29 | "num_mc_samples": 1, 30 | "val_num_mc_samples": 10, 31 | "kl_weighting": 1, 32 | "prior_variance": 7.5e-3, 33 | "non_reg_for_bn": true, 34 | "acc_steps": 1 35 | }, 36 | "curv_args": { 37 | "damping": 1e-4, 38 | "ema_decay": 0.9, 39 | "pi_type": "tracenorm" 40 | }, 41 | "fisher_args": { 42 | "approx_type": "mc", 43 | "num_mc": 1 44 | }, 45 | "momentum_correction": true, 46 | "scheduler_name": "MultiStepLR", 47 | "scheduler_args": { 48 | "milestones": [15, 30, 45], 49 | "gamma": 0.1 50 | }, 51 | "warmup_epochs": 5, 52 | "warmup_scheduler_name": "GradualWarmupIterLR", 53 | "warmup_scheduler_args": { 54 | "initial_lr": 1.25e-5, 55 | "max_count": 1565 56 | }, 57 | "num_mc_groups": 128 58 | } 59 | -------------------------------------------------------------------------------- /examples/distributed/classification/configs/imagenet/resnet18_noisykfac_bs4k_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet10", 3 | "epochs": 30, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": false, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedVIOptimizer", 17 | "optim_args": { 18 | "curv_type": "Fisher", 19 | "curv_shapes": { 20 | "Conv2d": "Kron", 21 | "Linear": "Kron", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1.6e-3, 26 | "momentum": 0.9, 27 | "momentum_type": "raw", 28 | "num_mc_samples": 1, 29 | "val_num_mc_samples": 0, 30 | "kl_weighting": 1, 31 | "prior_variance": 0.75, 32 | "non_reg_for_bn": true, 33 | "acc_steps": 32 34 | }, 35 | "curv_args": { 36 | "damping": 1e-4, 37 | "ema_decay": 1, 38 | "pi_type": "tracenorm" 39 | }, 40 | "fisher_args": { 41 | "approx_type": "mc", 42 | "num_mc": 1 43 | }, 44 | "num_mc_groups": 1 45 | } 46 | -------------------------------------------------------------------------------- /examples/distributed/classification/configs/imagenet/resnet18_sgd_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 91, 4 | "batch_size": 32, 5 | "val_batch_size": 391, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": true, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "SGD", 17 | "optim_args": { 18 | "lr": 1.6e-1, 19 | "momentum": 0.9, 20 | "weight_decay": 1e-4 21 | }, 22 | "momentum_correction": true, 23 | "non_wd_for_bn": true, 24 | "scheduler_name": "MultiStepLR", 25 | "scheduler_args": { 26 | "milestones": [30, 60, 80], 27 | "gamma": 0.1 28 | }, 29 | "warmup_epochs": 5, 30 | "warmup_scheduler_name": "GradualWarmupIterLR", 31 | "warmup_scheduler_args": { 32 | "initial_lr": 1.25e-3, 33 | "max_count": 1565 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_128gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 91, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": false, 8 | "random_horizontal_flip": false, 9 | "dataset_size_scale": 5, 10 | "normalizing_data": true, 11 | "arch_file": "models/resnet_b.py", 12 | "arch_name": "resnet18", 13 | "arch_args": { 14 | "zero_init_residual": false, 15 | "norm_stat_momentum": 0.1 16 | }, 17 | "optim_name": "DistributedVIOptimizer", 18 | "optim_args": { 19 | "curv_type": "Cov", 20 | "curv_shapes": { 21 | "Conv2d": "Diag", 22 | "Linear": "Diag", 23 | "BatchNorm1d": "Diag", 24 | "BatchNorm2d": "Diag" 25 | }, 26 | "lr": 1.6e-3, 27 | "momentum": 0.9, 28 | "momentum_type": "raw", 29 | "num_mc_samples": 1, 30 | "val_num_mc_samples": 10, 31 | "kl_weighting": 1, 32 | "prior_variance": 7.5e-3, 33 | "non_reg_for_bn": true 34 | }, 35 | "curv_args": { 36 | "damping": 1e-4, 37 | "ema_decay": 0.9 38 | }, 39 | "momentum_correction": true, 40 | "scheduler_name": "MultiStepLR", 41 | "scheduler_args": { 42 | "milestones": [30, 60, 80], 43 | "gamma": 0.1 44 | }, 45 | "warmup_epochs": 5, 46 | "warmup_scheduler_name": "GradualWarmupIterLR", 47 | "warmup_scheduler_args": { 48 | "initial_lr": 1.25e-5, 49 | "max_count": 1565 50 | }, 51 | "num_mc_groups": 128 52 | } 53 | -------------------------------------------------------------------------------- /examples/distributed/classification/configs/imagenet/resnet18_vogn_bs4k_4gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "ImageNet", 3 | "epochs": 90, 4 | "batch_size": 32, 5 | "val_batch_size": 32, 6 | "random_resized_crop": false, 7 | "random_crop": true, 8 | "random_horizontal_flip": true, 9 | "normalizing_data": true, 10 | "arch_file": "models/resnet_b.py", 11 | "arch_name": "resnet18", 12 | "arch_args": { 13 | "zero_init_residual": true, 14 | "norm_stat_momentum": 0.1 15 | }, 16 | "optim_name": "DistributedVIOptimizer", 17 | "optim_args": { 18 | "curv_type": "Cov", 19 | "curv_shapes": { 20 | "Conv2d": "Diag", 21 | "Linear": "Diag", 22 | "BatchNorm1d": "Diag", 23 | "BatchNorm2d": "Diag" 24 | }, 25 | "lr": 1.6e-2, 26 | "grad_ema_decay": 0.1, 27 | "grad_ema_type": "raw", 28 | "bias_correction": true, 29 | "non_reg_for_bn": true, 30 | "num_mc_samples": 1, 31 | "val_num_mc_samples": 10, 32 | "kl_weighting": 1, 33 | "prior_variance": 7.5e-4, 34 | "weight_decay": 0, 35 | "lars": false 36 | }, 37 | "curv_args": { 38 | "damping": 1e-4, 39 | "ema_decay": 1e-3 40 | }, 41 | "momentum_correction": true, 42 | "scheduler_name": "MultiStepLR", 43 | "scheduler_args": { 44 | "milestones": [30, 60, 80], 45 | "gamma": 0.1 46 | }, 47 | "warmup_epochs": 5, 48 | "warmup_scheduler_name": "GradualWarmupIterLR", 49 | "warmup_scheduler_args": { 50 | "initial_lr": 1.25e-4, 51 | "max_count": 1565 52 | }, 53 | "num_mc_groups": 4 54 | } 55 | -------------------------------------------------------------------------------- /examples/distributed/classification/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .lenet import * 3 | from .resnet import * 4 | from .resnext import * 5 | from .alexnet import * 6 | -------------------------------------------------------------------------------- /examples/distributed/classification/models/alexnet.py: -------------------------------------------------------------------------------- 1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. 2 | Without BN, the start learning rate should be 0.01 3 | (c) YANG, Wei 4 | ''' 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchsso.utils.accumulator import TensorAccumulator 8 | 9 | 10 | __all__ = ['alexnet', 'alexnet_mcdropout'] 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=10): 16 | super().__init__() 17 | self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5) 18 | self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2) 19 | self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1) 20 | self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1) 21 | self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 22 | self.fc = nn.Linear(256, num_classes) 23 | 24 | def forward(self, x): 25 | x = F.relu(self.conv1(x), inplace=True) 26 | x = F.max_pool2d(x, kernel_size=2, stride=2) 27 | x = F.relu(self.conv2(x), inplace=True) 28 | x = F.max_pool2d(x, kernel_size=2, stride=2) 29 | x = F.relu(self.conv3(x), inplace=True) 30 | x = F.relu(self.conv4(x), inplace=True) 31 | x = F.relu(self.conv5(x), inplace=True) 32 | x = F.max_pool2d(x, kernel_size=2, stride=2) 33 | x = x.view(x.size(0), -1) 34 | x = self.fc(x) 35 | return x 36 | 37 | 38 | class AlexNet2(nn.Module): 39 | 40 | def __init__(self, num_classes=10): 41 | super(AlexNet2, self).__init__() 42 | self.features = nn.Sequential( 43 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 44 | nn.ReLU(inplace=True), 45 | nn.MaxPool2d(kernel_size=2, stride=2), 46 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 47 | nn.ReLU(inplace=True), 48 | nn.MaxPool2d(kernel_size=2, stride=2), 49 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 54 | nn.ReLU(inplace=True), 55 | nn.MaxPool2d(kernel_size=2, stride=2), 56 | ) 57 | self.classifier = nn.Linear(256, num_classes) 58 | 59 | def forward(self, x): 60 | x = self.features(x) 61 | x = x.view(x.size(0), -1) 62 | x = self.classifier(x) 63 | return x 64 | 65 | 66 | class AlexNetMCDropout(AlexNet): 67 | 68 | def __init__(self, num_classes=10, dropout_ratio=0.5, val_mc=10): 69 | super(AlexNetMCDropout, self).__init__(num_classes) 70 | self.dropout_ratio = dropout_ratio 71 | self.val_mc = val_mc 72 | 73 | def forward(self, x): 74 | dropout_ratio = self.dropout_ratio 75 | x = F.relu(F.dropout(self.conv1(x), p=dropout_ratio), inplace=True) 76 | x = F.max_pool2d(x, kernel_size=2, stride=2) 77 | x = F.relu(F.dropout(self.conv2(x), p=dropout_ratio), inplace=True) 78 | x = F.max_pool2d(x, kernel_size=2, stride=2) 79 | x = F.relu(F.dropout(self.conv3(x), p=dropout_ratio), inplace=True) 80 | x = F.relu(F.dropout(self.conv4(x), p=dropout_ratio), inplace=True) 81 | x = F.relu(F.dropout(self.conv5(x), p=dropout_ratio), inplace=True) 82 | x = F.max_pool2d(x, kernel_size=2, stride=2) 83 | x = x.view(x.size(0), -1) 84 | x = self.fc(x) 85 | return x 86 | 87 | def mc_prediction(self, x): 88 | 89 | acc_prob = TensorAccumulator() 90 | m = self.val_mc 91 | 92 | for _ in range(m): 93 | output = self.forward(x) 94 | prob = F.softmax(output, dim=1) 95 | acc_prob.update(prob, scale=1/m) 96 | 97 | prob = acc_prob.get() 98 | 99 | return prob 100 | 101 | 102 | def alexnet(**kwargs): 103 | r"""AlexNet model architecture from the 104 | `"One weird trick..." `_ paper. 105 | """ 106 | model = AlexNet(**kwargs) 107 | return model 108 | 109 | 110 | def alexnet_mcdropout(**kwargs): 111 | model = AlexNetMCDropout(**kwargs) 112 | return model 113 | 114 | -------------------------------------------------------------------------------- /examples/distributed/classification/models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchsso.utils.accumulator import TensorAccumulator 4 | 5 | 6 | class LeNet5(nn.Module): 7 | 8 | def __init__(self, num_classes=10): 9 | super().__init__() 10 | self.conv1 = nn.Conv2d(3, 6, 5) 11 | self.conv2 = nn.Conv2d(6, 16, 5) 12 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 13 | self.fc2 = nn.Linear(120, 84) 14 | self.fc3 = nn.Linear(84, num_classes) 15 | 16 | def forward(self, x): 17 | out = F.relu(self.conv1(x)) 18 | out = F.max_pool2d(out, 2) 19 | out = F.relu(self.conv2(out)) 20 | out = F.max_pool2d(out, 2) 21 | out = out.view(out.size(0), -1) 22 | out = F.relu(self.fc1(out)) 23 | out = F.relu(self.fc2(out)) 24 | out = self.fc3(out) 25 | return out 26 | 27 | 28 | class LeNet5MCDropout(LeNet5): 29 | 30 | def __init__(self, num_classes=10, dropout_ratio=0.1, val_mc=10): 31 | super(LeNet5MCDropout, self).__init__(num_classes=num_classes) 32 | self.dropout_ratio = dropout_ratio 33 | self.val_mc = val_mc 34 | 35 | def forward(self, x): 36 | p = self.dropout_ratio 37 | out = F.relu(F.dropout(self.conv1(x), p)) 38 | out = F.max_pool2d(out, 2) 39 | out = F.relu(F.dropout(self.conv2(out), p)) 40 | out = F.max_pool2d(out, 2) 41 | out = out.view(out.size(0), -1) 42 | out = F.relu(F.dropout(self.fc1(out), p)) 43 | out = F.relu(F.dropout(self.fc2(out), p)) 44 | out = F.dropout(self.fc3(out), p) 45 | return out 46 | 47 | def mc_prediction(self, x): 48 | 49 | acc_prob = TensorAccumulator() 50 | m = self.val_mc 51 | 52 | for _ in range(m): 53 | output = self.forward(x) 54 | prob = F.softmax(output, dim=1) 55 | acc_prob.update(prob, scale=1/m) 56 | 57 | prob = acc_prob.get() 58 | 59 | return prob 60 | 61 | 62 | class LeNet5BatchNorm(nn.Module): 63 | def __init__(self, num_classes=10, affine=False): 64 | super().__init__() 65 | self.conv1 = nn.Conv2d(3, 6, 5) 66 | self.bn1 = nn.BatchNorm2d(6, affine=affine) 67 | self.conv2 = nn.Conv2d(6, 16, 5) 68 | self.bn2 = nn.BatchNorm2d(16, affine=affine) 69 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 70 | self.bn3 = nn.BatchNorm1d(120, affine=affine) 71 | self.fc2 = nn.Linear(120, 84) 72 | self.bn4 = nn.BatchNorm1d(84, affine=affine) 73 | self.fc3 = nn.Linear(84, num_classes) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = F.max_pool2d(out, 2) 78 | out = F.relu(self.bn2(self.conv2(out))) 79 | out = F.max_pool2d(out, 2) 80 | out = out.view(out.size(0), -1) 81 | out = F.relu(self.bn3(self.fc1(out))) 82 | out = F.relu(self.bn4(self.fc2(out))) 83 | out = self.fc3(out) 84 | return out 85 | -------------------------------------------------------------------------------- /examples/distributed/classification/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes,track_running_stats=False) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes,track_running_stats=False) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(ResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | 100 | def ResNet18(num_classes=10): 101 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 102 | 103 | def ResNet34(num_classes=10): 104 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 105 | 106 | def ResNet50(num_classes=10): 107 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 108 | 109 | def ResNet101(num_classes=10): 110 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 111 | 112 | def ResNet152(num_classe=10): 113 | return ResNet(Bottleneck, [3,8,36,3], num_classes) 114 | 115 | 116 | def test(): 117 | net = ResNet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /examples/distributed/classification/models/resnet_b.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, groups=groups, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 33 | base_width=64, norm_layer=None, norm_stat_momentum=0.1): 34 | super(BasicBlock, self).__init__() 35 | if norm_layer is None: 36 | norm_layer = nn.BatchNorm2d 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = norm_layer(planes, momentum=norm_stat_momentum) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes) 44 | self.bn2 = norm_layer(planes, momentum=norm_stat_momentum) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | identity = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | 61 | out += identity 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 71 | base_width=64, norm_layer=None, norm_stat_momentum=0.1): 72 | super(Bottleneck, self).__init__() 73 | if norm_layer is None: 74 | norm_layer = nn.BatchNorm2d 75 | width = int(planes * (base_width / 64.)) * groups 76 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv1x1(inplanes, width) 78 | self.bn1 = norm_layer(width, momentum=norm_stat_momentum) 79 | self.conv2 = conv3x3(width, width, stride, groups) 80 | self.bn2 = norm_layer(width, momentum=norm_stat_momentum) 81 | self.conv3 = conv1x1(width, planes * self.expansion) 82 | self.bn3 = norm_layer(planes * self.expansion, momentum=norm_stat_momentum) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | identity = self.downsample(x) 103 | 104 | out += identity 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet(nn.Module): 111 | 112 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 113 | groups=1, width_per_group=64, norm_layer=None, norm_stat_momentum=0.1): 114 | super(ResNet, self).__init__() 115 | if norm_layer is None: 116 | norm_layer = nn.BatchNorm2d 117 | 118 | self.inplanes = 64 119 | self.groups = groups 120 | self.base_width = width_per_group 121 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 122 | bias=False) 123 | self.bn1 = norm_layer(self.inplanes, momentum=norm_stat_momentum) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0], 127 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) 128 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 129 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) 130 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 131 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) 132 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 133 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum) 134 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 135 | self.fc = nn.Linear(512 * block.expansion, num_classes) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 140 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 141 | nn.init.constant_(m.weight, 1) 142 | nn.init.constant_(m.bias, 0) 143 | 144 | # Zero-initialize the last BN in each residual branch, 145 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 146 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 147 | if zero_init_residual: 148 | for m in self.modules(): 149 | if isinstance(m, Bottleneck): 150 | nn.init.constant_(m.bn3.weight, 0) 151 | elif isinstance(m, BasicBlock): 152 | nn.init.constant_(m.bn2.weight, 0) 153 | 154 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None, norm_stat_momentum=0.1): 155 | if norm_layer is None: 156 | norm_layer = nn.BatchNorm2d 157 | downsample = None 158 | if stride != 1 or self.inplanes != planes * block.expansion: 159 | downsample = nn.Sequential( 160 | conv1x1(self.inplanes, planes * block.expansion, stride), 161 | norm_layer(planes * block.expansion, momentum=norm_stat_momentum), 162 | ) 163 | 164 | layers = [] 165 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 166 | self.base_width, norm_layer, norm_stat_momentum)) 167 | self.inplanes = planes * block.expansion 168 | for _ in range(1, blocks): 169 | layers.append(block(self.inplanes, planes, groups=self.groups, 170 | base_width=self.base_width, 171 | norm_layer=norm_layer, norm_stat_momentum=norm_stat_momentum)) 172 | 173 | return nn.Sequential(*layers) 174 | 175 | def forward(self, x): 176 | x = self.conv1(x) 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | x = self.maxpool(x) 180 | 181 | x = self.layer1(x) 182 | x = self.layer2(x) 183 | x = self.layer3(x) 184 | x = self.layer4(x) 185 | 186 | x = self.avgpool(x) 187 | x = x.view(x.size(0), -1) 188 | x = self.fc(x) 189 | 190 | return x 191 | 192 | 193 | def resnet18(pretrained=False, **kwargs): 194 | """Constructs a ResNet-18 model. 195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 202 | return model 203 | 204 | 205 | def resnet34(pretrained=False, **kwargs): 206 | """Constructs a ResNet-34 model. 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 214 | return model 215 | 216 | 217 | def resnet50(pretrained=False, **kwargs): 218 | """Constructs a ResNet-50 model. 219 | 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | """ 223 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 224 | if pretrained: 225 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 226 | return model 227 | 228 | 229 | def resnet101(pretrained=False, **kwargs): 230 | """Constructs a ResNet-101 model. 231 | 232 | Args: 233 | pretrained (bool): If True, returns a model pre-trained on ImageNet 234 | """ 235 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 236 | if pretrained: 237 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 238 | return model 239 | 240 | 241 | def resnet152(pretrained=False, **kwargs): 242 | """Constructs a ResNet-152 model. 243 | 244 | Args: 245 | pretrained (bool): If True, returns a model pre-trained on ImageNet 246 | """ 247 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 248 | if pretrained: 249 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 250 | return model 251 | 252 | 253 | def resnext50_32x4d(pretrained=False, **kwargs): 254 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs) 255 | # if pretrained: 256 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) 257 | return model 258 | 259 | 260 | def resnext101_32x8d(pretrained=False, **kwargs): 261 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) 262 | # if pretrained: 263 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) 264 | return model 265 | -------------------------------------------------------------------------------- /examples/distributed/classification/models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 42 | super(ResNeXt, self).__init__() 43 | self.cardinality = cardinality 44 | self.bottleneck_width = bottleneck_width 45 | self.in_planes = 64 46 | 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(64) 49 | self.layer1 = self._make_layer(num_blocks[0], 1) 50 | self.layer2 = self._make_layer(num_blocks[1], 2) 51 | self.layer3 = self._make_layer(num_blocks[2], 2) 52 | # self.layer4 = self._make_layer(num_blocks[3], 2) 53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 54 | 55 | def _make_layer(self, num_blocks, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for stride in strides: 59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 61 | # Increase bottleneck_width by 2 after each stage. 62 | self.bottleneck_width *= 2 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.layer1(out) 68 | out = self.layer2(out) 69 | out = self.layer3(out) 70 | # out = self.layer4(out) 71 | out = F.avg_pool2d(out, 8) 72 | out = out.view(out.size(0), -1) 73 | out = self.linear(out) 74 | return out 75 | 76 | 77 | def ResNeXt29_2x64d(num_classes=10): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64, num_classes=num_classes) 79 | 80 | def ResNeXt29_4x64d(num_classes=10): 81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64, num_classes=num_classes) 82 | 83 | def ResNeXt29_8x64d(num_classes=10): 84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64, num_classes=num_classes) 85 | 86 | def ResNeXt29_32x4d(num_classes=10): 87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4, num_classes=num_classes) 88 | 89 | def test_resnext(): 90 | net = ResNeXt29_2x64d() 91 | x = torch.randn(1,3,32,32) 92 | y = net(x) 93 | print(y.size()) 94 | 95 | # test_resnext() 96 | -------------------------------------------------------------------------------- /examples/distributed/classification/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchsso.utils.accumulator import TensorAccumulator 5 | 6 | 7 | cfg = { 8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 12 | } 13 | 14 | 15 | class VGG(nn.Module): 16 | def __init__(self, num_classes=10, vgg_name='VGG19'): 17 | super(VGG, self).__init__() 18 | self.features = self._make_layers(cfg[vgg_name]) 19 | self.classifier = nn.Linear(512, num_classes) 20 | 21 | def forward(self, x): 22 | out = self.features(x) 23 | out = out.view(out.size(0), -1) 24 | out = self.classifier(out) 25 | return out 26 | 27 | def _make_layers(self, cfg): 28 | layers = [] 29 | in_channels = 3 30 | for x in cfg: 31 | if x == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(x), 36 | nn.ReLU(inplace=True)] 37 | in_channels = x 38 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 39 | return nn.Sequential(*layers) 40 | 41 | 42 | class VGG19(nn.Module): 43 | 44 | def __init__(self, num_classes=10): 45 | super(VGG19, self).__init__() 46 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 47 | self.bn1_1 = nn.BatchNorm2d(64) 48 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 49 | self.bn1_2 = nn.BatchNorm2d(64) 50 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 51 | self.bn2_1 = nn.BatchNorm2d(128) 52 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 53 | self.bn2_2 = nn.BatchNorm2d(128) 54 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 55 | self.bn3_1 = nn.BatchNorm2d(256) 56 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 57 | self.bn3_2 = nn.BatchNorm2d(256) 58 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 59 | self.bn3_3 = nn.BatchNorm2d(256) 60 | self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) 61 | self.bn3_4 = nn.BatchNorm2d(256) 62 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 63 | self.bn4_1 = nn.BatchNorm2d(512) 64 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 65 | self.bn4_2 = nn.BatchNorm2d(512) 66 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 67 | self.bn4_3 = nn.BatchNorm2d(512) 68 | self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) 69 | self.bn4_4 = nn.BatchNorm2d(512) 70 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 71 | self.bn5_1 = nn.BatchNorm2d(512) 72 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 73 | self.bn5_2 = nn.BatchNorm2d(512) 74 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 75 | self.bn5_3 = nn.BatchNorm2d(512) 76 | self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) 77 | self.bn5_4 = nn.BatchNorm2d(512) 78 | self.fc = nn.Linear(512, num_classes) 79 | 80 | def forward(self, x): 81 | h = F.relu(self.bn1_1(self.conv1_1(x)), inplace=True) 82 | h = F.relu(self.bn1_2(self.conv1_2(h)), inplace=True) 83 | h = F.max_pool2d(h, kernel_size=2, stride=2) 84 | h = F.relu(self.bn2_1(self.conv2_1(h)), inplace=True) 85 | h = F.relu(self.bn2_2(self.conv2_2(h)), inplace=True) 86 | h = F.max_pool2d(h, kernel_size=2, stride=2) 87 | h = F.relu(self.bn3_1(self.conv3_1(h)), inplace=True) 88 | h = F.relu(self.bn3_2(self.conv3_2(h)), inplace=True) 89 | h = F.relu(self.bn3_3(self.conv3_3(h)), inplace=True) 90 | h = F.relu(self.bn3_4(self.conv3_4(h)), inplace=True) 91 | h = F.max_pool2d(h, kernel_size=2, stride=2) 92 | h = F.relu(self.bn4_1(self.conv4_1(h)), inplace=True) 93 | h = F.relu(self.bn4_2(self.conv4_2(h)), inplace=True) 94 | h = F.relu(self.bn4_3(self.conv4_3(h)), inplace=True) 95 | h = F.relu(self.bn4_4(self.conv4_4(h)), inplace=True) 96 | h = F.max_pool2d(h, kernel_size=2, stride=2) 97 | h = F.relu(self.bn5_1(self.conv5_1(h)), inplace=True) 98 | h = F.relu(self.bn5_2(self.conv5_2(h)), inplace=True) 99 | h = F.relu(self.bn5_3(self.conv5_3(h)), inplace=True) 100 | h = F.relu(self.bn5_4(self.conv5_4(h)), inplace=True) 101 | h = F.max_pool2d(h, kernel_size=2, stride=2) 102 | h = h.view(h.size(0), -1) 103 | out = self.fc(h) 104 | return out 105 | 106 | 107 | class VGG19MCDropout(VGG19): 108 | 109 | def __init__(self, num_classes=10, dropout_ratio=0.1, val_mc=10): 110 | super(VGG19MCDropout, self).__init__(num_classes) 111 | self.dropout_ratio = dropout_ratio 112 | self.val_mc = val_mc 113 | 114 | def forward(self, x): 115 | p = self.dropout_ratio 116 | h = F.relu(self.bn1_1(F.dropout(self.conv1_1(x), p)), inplace=True) 117 | h = F.relu(self.bn1_2(F.dropout(self.conv1_2(h), p)), inplace=True) 118 | h = F.max_pool2d(h, kernel_size=2, stride=2) 119 | h = F.relu(self.bn2_1(F.dropout(self.conv2_1(h), p)), inplace=True) 120 | h = F.relu(self.bn2_2(F.dropout(self.conv2_2(h), p)), inplace=True) 121 | h = F.max_pool2d(h, kernel_size=2, stride=2) 122 | h = F.relu(self.bn3_1(F.dropout(self.conv3_1(h), p)), inplace=True) 123 | h = F.relu(self.bn3_2(F.dropout(self.conv3_2(h), p)), inplace=True) 124 | h = F.relu(self.bn3_3(F.dropout(self.conv3_3(h), p)), inplace=True) 125 | h = F.relu(self.bn3_4(F.dropout(self.conv3_4(h), p)), inplace=True) 126 | h = F.max_pool2d(h, kernel_size=2, stride=2) 127 | h = F.relu(self.bn4_1(F.dropout(self.conv4_1(h), p)), inplace=True) 128 | h = F.relu(self.bn4_2(F.dropout(self.conv4_2(h), p)), inplace=True) 129 | h = F.relu(self.bn4_3(F.dropout(self.conv4_3(h), p)), inplace=True) 130 | h = F.relu(self.bn4_4(F.dropout(self.conv4_4(h), p)), inplace=True) 131 | h = F.max_pool2d(h, kernel_size=2, stride=2) 132 | h = F.relu(self.bn5_1(F.dropout(self.conv5_1(h), p)), inplace=True) 133 | h = F.relu(self.bn5_2(F.dropout(self.conv5_2(h), p)), inplace=True) 134 | h = F.relu(self.bn5_3(F.dropout(self.conv5_3(h), p)), inplace=True) 135 | h = F.relu(self.bn5_4(F.dropout(self.conv5_4(h), p)), inplace=True) 136 | h = F.max_pool2d(h, kernel_size=2, stride=2) 137 | h = h.view(h.size(0), -1) 138 | out = F.dropout(self.fc(h), p) 139 | return out 140 | 141 | def mc_prediction(self, x): 142 | 143 | acc_prob = TensorAccumulator() 144 | m = self.val_mc 145 | 146 | for _ in range(m): 147 | output = self.forward(x) 148 | prob = F.softmax(output, dim=1) 149 | acc_prob.update(prob, scale=1/m) 150 | 151 | prob = acc_prob.get() 152 | 153 | return prob 154 | -------------------------------------------------------------------------------- /pytorch_sso_poster.pptx.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybertronai/pytorch-sso/d16173236728fdd1e943bc3b61243bfb28f95348/pytorch_sso_poster.pptx.pdf -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = torchsso 3 | version = 0.1.1 4 | url = https://github.com/cybertronai/pytorch-sso 5 | author = Kazuki Osawa 6 | author_email = osawa1021@gmail.com 7 | license_file = LICENSE 8 | description = PyTorch-SSO: Scalable Second-Order Optimization Methods in PyTorch. 9 | long_description = file: README.md 10 | classifiers = 11 | Programming Language :: Python :: 3 12 | 13 | [options] 14 | zip_safe = False 15 | packages = find: 16 | install_requires = 17 | torch 18 | torchvision 19 | chainer 20 | Pillow 21 | numpy 22 | scipy 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | setup() 3 | -------------------------------------------------------------------------------- /tests/test_samplegrad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchsso.autograd import save_sample_grads 6 | 7 | 8 | class LeNet5BatchNorm(nn.Module): 9 | def __init__(self, num_classes=10, affine=True): 10 | super().__init__() 11 | self.conv1 = nn.Conv2d(3, 6, 5) 12 | self.bn1 = nn.BatchNorm2d(6, affine=affine) 13 | self.conv2 = nn.Conv2d(6, 16, 5) 14 | self.bn2 = nn.BatchNorm2d(16, affine=affine) 15 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 16 | self.bn3 = nn.BatchNorm1d(120, affine=affine) 17 | self.fc2 = nn.Linear(120, 84) 18 | self.bn4 = nn.BatchNorm1d(84, affine=affine) 19 | self.fc3 = nn.Linear(84, num_classes) 20 | 21 | def forward(self, x): 22 | out = F.relu(self.bn1(self.conv1(x))) 23 | out = F.max_pool2d(out, 2) 24 | out = F.relu(self.bn2(self.conv2(out))) 25 | out = F.max_pool2d(out, 2) 26 | out = out.view(out.size(0), -1) 27 | out = F.relu(self.bn3(self.fc1(out))) 28 | out = F.relu(self.bn4(self.fc2(out))) 29 | out = self.fc3(out) 30 | return out 31 | 32 | 33 | def test_samplegrad(): 34 | model = LeNet5BatchNorm() 35 | n = 10 36 | c, h, w = 3, 32, 32 37 | x = torch.randn(n, c, h, w) 38 | 39 | with save_sample_grads(model): 40 | out = model(x) 41 | loss = out.sum() 42 | loss.backward() 43 | 44 | for module in model.children(): 45 | print(module) 46 | for p in module.parameters(): 47 | if p.requires_grad: 48 | error = (p.grads.sum(0) - p.grad).max() 49 | print(f'\t{p.size()} : error = {error}') 50 | 51 | 52 | if __name__ == '__main__': 53 | test_samplegrad() 54 | -------------------------------------------------------------------------------- /torchsso/__init__.py: -------------------------------------------------------------------------------- 1 | from torchsso import optim # NOQA 2 | from torchsso import autograd # NOQA 3 | from torchsso import utils # NOQA 4 | 5 | from torchsso.curv.curvature import Curvature, DiagCurvature, KronCurvature # NOQA 6 | from torchsso.curv.cov.linear import CovLinear, DiagCovLinear, KronCovLinear # NOQA 7 | from torchsso.curv.cov.conv import CovConv2d, DiagCovConv2d, KronCovConv2d # NOQA 8 | from torchsso.curv.cov.batchnorm import CovBatchNorm1d, DiagCovBatchNorm1d, CovBatchNorm2d, DiagCovBatchNorm2d # NOQA 9 | 10 | from torchsso.curv.hessian import KronHessian # NOQA 11 | from torchsso.curv.hessian.linear import KronHessianLinear # NOQA 12 | from torchsso.curv.hessian.conv import KronHessianConv2d # NOQA 13 | 14 | from torchsso.curv.fisher import get_closure_for_fisher # NOQA 15 | from torchsso.curv.fisher import Fisher # NOQA 16 | from torchsso.curv.fisher.linear import DiagFisherLinear, KronFisherLinear # NOQA 17 | from torchsso.curv.fisher.conv import DiagFisherConv2d, KronFisherConv2d # NOQA 18 | from torchsso.curv.fisher.batchnorm import DiagFisherBatchNorm2d # NOQA 19 | -------------------------------------------------------------------------------- /torchsso/autograd/__init__.py: -------------------------------------------------------------------------------- 1 | from torchsso.autograd.samplegrad import * # NOQA 2 | -------------------------------------------------------------------------------- /torchsso/autograd/samplegrad.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | @contextmanager 9 | def save_sample_grads(model: nn.Module): 10 | 11 | handles = [] 12 | for module in model.children(): 13 | params = list(module.parameters()) 14 | params = [p for p in params if p.requires_grad] 15 | if len(params) == 0: 16 | continue 17 | 18 | handles.append(module.register_forward_hook(_forward_postprocess)) 19 | handles.append(module.register_backward_hook(_backward_postprocess)) 20 | 21 | yield 22 | for handle in handles: 23 | handle.remove() 24 | 25 | 26 | def _forward_postprocess(module: nn.Module, input: torch.Tensor, output: torch.Tensor): 27 | data_input = input[0].clone().detach() 28 | 29 | if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 30 | bnorm = module 31 | f = bnorm.num_features 32 | if isinstance(module, nn.BatchNorm1d): 33 | shape = (1, f) 34 | elif isinstance(module, nn.BatchNorm2d): 35 | shape = (1, f, 1, 1) 36 | else: 37 | shape = (1, f, 1, 1, 1) 38 | # restore normalized input 39 | data_input_norm = (output - bnorm.bias.view(shape)).div(bnorm.weight.view(shape)) 40 | data_input = data_input_norm 41 | 42 | setattr(module, 'data_input', data_input) 43 | 44 | 45 | def _backward_postprocess(module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor): 46 | grad_output = grad_output[0].clone().detach() 47 | data_input = getattr(module, 'data_input', None) 48 | assert data_input is not None, 'backward is called before forward.' 49 | assert data_input.size(0) == grad_output.size(0) 50 | 51 | args = [module, data_input, grad_output] 52 | if isinstance(module, nn.Linear): 53 | grad_linear(*args) 54 | elif isinstance(module, nn.Conv2d): 55 | grad_conv2d(*args) 56 | elif isinstance(module, nn.BatchNorm1d): 57 | grad_batchnorm1d(*args) 58 | elif isinstance(module, nn.BatchNorm2d): 59 | grad_batchnorm2d(*args) 60 | else: 61 | raise ValueError(f'Unsupported module class: {module.__class__}.') 62 | 63 | 64 | def grad_linear(module: nn.Module, data_input: torch.Tensor, grad_output: torch.Tensor): 65 | 66 | assert isinstance(module, nn.Linear) 67 | linear = module 68 | assert data_input.ndimension() == 2 # n x f_in 69 | assert grad_output.ndimension() == 2 # n x f_out 70 | 71 | if linear.weight.requires_grad: 72 | grads = torch.einsum('bi,bj->bij', grad_output, data_input) # n x f_out x f_in 73 | setattr(linear.weight, 'grads', grads) # n x f_out x f_in 74 | 75 | if hasattr(linear, 'bias') and linear.bias.requires_grad: 76 | setattr(linear.bias, 'grads', grad_output) # n x f_out 77 | 78 | 79 | def grad_conv2d(module: nn.Module, data_input: torch.Tensor, grad_output: torch.Tensor): 80 | 81 | assert isinstance(module, nn.Conv2d) 82 | conv2d = module 83 | assert data_input.ndimension() == 4 # n x c_in x h_in x w_in 84 | assert grad_output.ndimension() == 4 # n x c_out x h_out x w_out 85 | 86 | if conv2d.weight.requires_grad: 87 | # n x (c_in)(k_h)(k_w) x (h_out)(w_out) 88 | input2d = F.unfold(data_input, 89 | kernel_size=conv2d.kernel_size, stride=conv2d.stride, 90 | padding=conv2d.padding, dilation=conv2d.dilation) 91 | 92 | # n x c_out x h_out x w_out 93 | n, c_out, h, w = grad_output.size() 94 | # n x c_out x (h_out)(w_out) 95 | grad_output2d = grad_output.view(n, c_out, -1) 96 | 97 | c_out, c_in, k_h, k_w = conv2d.weight.size() 98 | 99 | grads_2d = torch.einsum('bik,bjk->bij', grad_output2d, input2d) # n x c_out x (c_in)(k_h)(k_w) 100 | setattr(conv2d.weight, 'grads', grads_2d.view(n, c_out, c_in, k_h, k_w)) # n x c_out x c_in x k_h x k_w 101 | 102 | if hasattr(conv2d, 'bias') and conv2d.bias.requires_grad: 103 | setattr(conv2d.bias, 'grads', grad_output.sum(dim=(2, 3))) # n x c_out 104 | 105 | 106 | def grad_batchnorm1d(module: nn.Module, data_input: torch.Tensor, grad_output: torch.Tensor): 107 | assert isinstance(module, nn.BatchNorm1d) 108 | batchnorm1d = module 109 | assert data_input.ndimension() == 2 # n x f 110 | assert grad_output.ndimension() == 2 # n x f 111 | assert batchnorm1d.affine 112 | 113 | if batchnorm1d.weight.requires_grad: 114 | grads = data_input.mul(grad_output) # n x f 115 | setattr(batchnorm1d.weight, 'grads', grads) 116 | 117 | if batchnorm1d.bias.requires_grad: 118 | setattr(batchnorm1d.bias, 'grads', grad_output) # n x f 119 | 120 | 121 | def grad_batchnorm2d(module: nn.Module, data_input: torch.Tensor, grad_output: torch.Tensor): 122 | assert isinstance(module, nn.BatchNorm2d) 123 | batchnorm2d = module 124 | assert data_input.ndimension() == 4 # n x c x h x w 125 | assert grad_output.ndimension() == 4 # n x c x h x w 126 | assert batchnorm2d.affine 127 | 128 | if batchnorm2d.weight.requires_grad: 129 | grads = data_input.mul(grad_output).sum(dim=(2, 3)) # n x c 130 | setattr(batchnorm2d.weight, 'grads', grads) 131 | 132 | if batchnorm2d.bias.requires_grad: 133 | setattr(batchnorm2d.bias, 'grads', grad_output.sum(dim=(2, 3))) # n x c 134 | 135 | -------------------------------------------------------------------------------- /torchsso/curv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybertronai/pytorch-sso/d16173236728fdd1e943bc3b61243bfb28f95348/torchsso/curv/__init__.py -------------------------------------------------------------------------------- /torchsso/curv/cov/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybertronai/pytorch-sso/d16173236728fdd1e943bc3b61243bfb28f95348/torchsso/curv/cov/__init__.py -------------------------------------------------------------------------------- /torchsso/curv/cov/batchnorm.py: -------------------------------------------------------------------------------- 1 | from torchsso import Curvature, DiagCurvature 2 | 3 | 4 | class CovBatchNorm1d(Curvature): 5 | 6 | def update_in_backward(self, grad_output_data): 7 | pass 8 | 9 | 10 | class DiagCovBatchNorm1d(DiagCurvature): 11 | 12 | def update_in_backward(self, grad_output): 13 | data_input = getattr(self._module, 'data_input', None) # n x f 14 | assert data_input is not None 15 | 16 | in_in = data_input.mul(data_input) # n x f 17 | grad_grad = grad_output.mul(grad_output) # n x f 18 | 19 | data_w = in_in.mul(grad_grad).mean(dim=0) # f x 1 20 | 21 | self._data = [data_w] 22 | 23 | if self.bias: 24 | data_b = grad_grad.mean(dim=0) # f x 1 25 | self._data.append(data_b) 26 | 27 | 28 | class CovBatchNorm2d(Curvature): 29 | 30 | def update_in_backward(self, grad_output): 31 | pass 32 | 33 | 34 | class DiagCovBatchNorm2d(DiagCurvature): 35 | 36 | def update_in_backward(self, grad_out): 37 | data_input = getattr(self._module, 'data_input', None) # n x c x h x w 38 | assert data_input is not None 39 | 40 | in_in = data_input.mul(data_input).sum(dim=(2, 3)) # n x c 41 | grad_grad = grad_out.mul(grad_out).sum(dim=(2, 3)) # n x c 42 | 43 | data_w = in_in.mul(grad_grad).mean(dim=0) # c x 1 44 | 45 | self._data = [data_w] 46 | 47 | if self.bias: 48 | data_b = grad_grad.mean(dim=0) # c x 1 49 | self._data.append(data_b) 50 | -------------------------------------------------------------------------------- /torchsso/curv/cov/conv.py: -------------------------------------------------------------------------------- 1 | from torchsso import Curvature, DiagCurvature, KronCurvature 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class CovConv2d(Curvature): 7 | 8 | def update_in_backward(self, grad_output): 9 | pass 10 | 11 | def precgrad(self, params): 12 | pass 13 | 14 | 15 | class DiagCovConv2d(DiagCurvature): 16 | 17 | def update_in_backward(self, grad_output): 18 | conv2d = self._module 19 | data_input = getattr(conv2d, 'data_input', None) # n x c_in x h_in x w_in 20 | assert data_input is not None 21 | 22 | # n x (c_in)(k_h)(k_w) x (h_out)(w_out) 23 | input2d = F.unfold(data_input, 24 | kernel_size=conv2d.kernel_size, stride=conv2d.stride, 25 | padding=conv2d.padding, dilation=conv2d.dilation) 26 | 27 | # n x c_out x h_out x w_out 28 | n, c_out, h, w = grad_output.shape 29 | # n x c_out x (h_out)(w_out) 30 | grad_output2d = grad_output.reshape(n, c_out, -1) 31 | 32 | grad_in = torch.einsum('bik,bjk->bij', 33 | grad_output2d, input2d) # n x c_out x (c_in)(k_h)(k_w) 34 | 35 | data_w = grad_in.mul(grad_in).mean(dim=0) # c_out x (c_in)(k_h)(k_w) 36 | data_w = data_w.reshape((c_out, -1, *conv2d.kernel_size)) # c_out x c_in x k_h x k_w 37 | self._data = [data_w] 38 | 39 | if self.bias: 40 | grad_grad = grad_output2d.mul(grad_output2d) # n x c_out x (h_out)(w_out) 41 | data_b = grad_grad.sum(dim=2).mean(dim=0) # c_out 42 | self._data.append(data_b) 43 | 44 | 45 | class KronCovConv2d(KronCurvature): 46 | 47 | def update_in_forward(self, data_input): 48 | conv2d = self._module 49 | 50 | # n x (c_in)(k_h)(k_w) x (h_out)(w_out) 51 | input2d = F.unfold(data_input, 52 | kernel_size=conv2d.kernel_size, stride=conv2d.stride, 53 | padding=conv2d.padding, dilation=conv2d.dilation) 54 | 55 | n, a, _ = input2d.shape 56 | 57 | # (c_in)(k_h)(k_w) x n(h_out)(w_out) 58 | m = input2d.transpose(0, 1).reshape(a, -1) 59 | a, b = m.shape 60 | if self.bias: 61 | # {(c_in)(k_h)(k_w) + 1} x n(h_out)(w_out) 62 | m = torch.cat((m, m.new_ones((1, b))), 0) 63 | 64 | # (c_in)(k_h)(k_w) x (c_in)(k_h)(k_w) or 65 | # {(c_in)(k_h)(k_w) + 1} x {(c_in)(k_h)(k_w) + 1} 66 | A = torch.einsum('ik,jk->ij', m, m).div(n) 67 | self._A = A 68 | 69 | def update_in_backward(self, grad_output): 70 | n, c, h, w = grad_output.shape # n x c_out x h_out x w_out 71 | m = grad_output.transpose(0, 1).reshape(c, -1) # c_out x n(h_out)(w_out) 72 | 73 | G = torch.einsum('ik,jk->ij', m, m).div(n*h*w) # c_out x c_out 74 | self._G = G 75 | 76 | def precondition_grad(self, params): 77 | A_inv, G_inv = self.inv 78 | 79 | # todo check params == list? 80 | oc, _, _, _ = params[0].shape 81 | if self.bias: 82 | grad2d = torch.cat( 83 | (params[0].grad.reshape(oc, -1), params[1].grad.view(-1, 1)), 1) 84 | preconditioned_grad2d = G_inv.mm(grad2d).mm(A_inv) 85 | 86 | params[0].grad.copy_(preconditioned_grad2d[:, 0:-1].reshape_as(params[0])) 87 | params[1].grad.copy_(preconditioned_grad2d[:, -1]) 88 | else: 89 | grad2d = params[0].grad.reshape(oc, -1) 90 | preconditioned_grad2d = G_inv.mm(grad2d).mm(A_inv) 91 | 92 | params[0].grad.copy_(preconditioned_grad2d.reshape_as(params[0])) 93 | 94 | def sample_params(self, params, mean, std_scale): 95 | A_ic, G_ic = self.std 96 | oc, ic, h, w = mean[0].shape 97 | if self.bias: 98 | m = torch.cat( 99 | (mean[0].reshape(oc, -1), mean[1].view(-1, 1)), 1) 100 | param = m.add(std_scale, G_ic.mm( 101 | torch.randn_like(m)).mm(A_ic)) 102 | params[0].data.copy_(param[:, 0:-1].reshape(oc, ic, h, w)) 103 | params[1].data.copy_(param[:, -1]) 104 | else: 105 | m = mean[0].reshape(oc, -1) 106 | param = m.add(std_scale, G_ic.mm( 107 | torch.randn_like(m)).mm(A_ic)) 108 | params[0].data = param.reshape(oc, ic, h, w) 109 | 110 | def _get_shape(self): 111 | linear = self._module 112 | w = getattr(linear, 'weight') 113 | c_out, c_in, k_h, k_w = w.shape 114 | 115 | G_shape = (c_out, c_out) 116 | 117 | dim = c_in * k_h * k_w 118 | if self.bias: 119 | A_shape = (dim + 1, dim + 1) 120 | else: 121 | A_shape = (dim, dim) 122 | 123 | return A_shape, G_shape 124 | 125 | -------------------------------------------------------------------------------- /torchsso/curv/cov/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsso import Curvature, DiagCurvature, KronCurvature 3 | 4 | 5 | class CovLinear(Curvature): 6 | 7 | def update_in_backward(self, grad_output): 8 | data_input = getattr(self._module, 'data_input', None) # n x f_in 9 | assert data_input is not None 10 | 11 | n = data_input.shape[0] 12 | 13 | if self.bias: 14 | ones = torch.ones((n, 1), device=data_input.device, dtype=data_input.dtype) 15 | data_input = torch.cat((data_input, ones), 1) # n x (f_in+1) 16 | 17 | grad = torch.einsum('bi,bj->bij', grad_output, data_input) # n x f_out x f_in 18 | grad = grad.reshape((n, -1)) # n x (f_out)(f_in) 19 | 20 | data = torch.einsum('bi,bj->ij', grad, grad) 21 | 22 | self._data = [data] 23 | 24 | def precondition_grad(self, params): 25 | pass 26 | 27 | 28 | class DiagCovLinear(DiagCurvature): 29 | 30 | def update_in_backward(self, grad_output): 31 | data_input = getattr(self._module, 'data_input', None) # n x f_in 32 | assert data_input is not None 33 | 34 | n = data_input.shape[0] 35 | 36 | in_in = data_input.mul(data_input) # n x f_in 37 | grad_grad = grad_output.mul(grad_output) # n x f_out 38 | 39 | data_w = torch.einsum('ki,kj->ij', grad_grad, 40 | in_in).div(n) # f_out x f_in 41 | self._data = [data_w] 42 | 43 | if self.bias: 44 | data_b = grad_grad.mean(dim=0) # f_out x 1 45 | self._data.append(data_b) 46 | 47 | 48 | class KronCovLinear(KronCurvature): 49 | 50 | def update_in_forward(self, input_data): 51 | n = input_data.shape[0] # n x f_in 52 | if self.bias: 53 | ones = input_data.new_ones((n, 1)) 54 | # shape: n x (f_in+1) 55 | input_data = torch.cat((input_data, ones), 1) 56 | 57 | # f_in x f_in or (f_in+1) x (f_in+1) 58 | A = torch.einsum('ki,kj->ij', input_data, input_data).div(n) 59 | self._A = A 60 | 61 | def update_in_backward(self, grad_output): 62 | n = grad_output.shape[0] # n x f_out 63 | 64 | # f_out x f_out 65 | G = torch.einsum( 66 | 'ki,kj->ij', grad_output, grad_output).div(n) 67 | self._G = G 68 | 69 | def precondition_grad(self, params): 70 | A_inv, G_inv = self.inv 71 | 72 | # todo check params == list? 73 | if self.bias: 74 | grad = torch.cat( 75 | (params[0].grad, params[1].grad.view(-1, 1)), 1) 76 | preconditioned_grad = G_inv.mm(grad).mm(A_inv) 77 | 78 | params[0].grad.copy_(preconditioned_grad[:, :-1]) 79 | params[1].grad.copy_(preconditioned_grad[:, -1]) 80 | else: 81 | grad = params[0].grad 82 | preconditioned_grad = G_inv.mm(grad).mm(A_inv) 83 | 84 | params[0].grad.copy_(preconditioned_grad) 85 | 86 | def sample_params(self, params, mean, std_scale): 87 | A_ic, G_ic = self.std 88 | 89 | if self.bias: 90 | m = torch.cat( 91 | (mean[0], mean[1].view(-1, 1)), 1) 92 | param = m.add(std_scale, G_ic.mm( 93 | torch.randn_like(m)).mm(A_ic)) 94 | params[0].data.copy_(param[:, 0:-1]) 95 | params[1].data.copy_(param[:, -1]) 96 | else: 97 | m = mean[0] 98 | param = mean.add(std_scale, G_ic.mm( 99 | torch.randn_like(m)).mm(A_ic)) 100 | params[0].data = param 101 | 102 | def _get_shape(self): 103 | linear = self._module 104 | w = getattr(linear, 'weight') 105 | f_out, f_in = w.shape 106 | 107 | G_shape = (f_out, f_out) 108 | 109 | if self.bias: 110 | A_shape = (f_in + 1, f_in + 1) 111 | else: 112 | A_shape = (f_in, f_in) 113 | 114 | return A_shape, G_shape 115 | 116 | -------------------------------------------------------------------------------- /torchsso/curv/curvature.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchsso 6 | 7 | PI_TYPE_TRACENORM = 'tracenorm' 8 | 9 | 10 | class Curvature(object): 11 | r"""Base implementation of the curvatures for each layer. 12 | 13 | This class computes/maintains curvature (data) and EMA/inverse of it for a given layer (module) 14 | which are used for torchsso.optim.SecondOrderOptimizer. 15 | Standard deviation (std) is calculated for torchsso.optim.VIOptimizer based on the inverse. 16 | IE, data -> ema -> inv (-> std) 17 | 18 | Args: 19 | module (torch.nn.Module): a layer with trainable params for which the curvature is computed 20 | ema_decay (float, optional): decay rate for EMA of curvature 21 | damping (float, optional): value to be added to the diagonal of EMA before inverting it 22 | use_max_ema (bool, optional): whether to use the maximum value as EMA 23 | use_sqrt_ema (bool, optional): whether to take the squre root of EMA 24 | """ 25 | 26 | def __init__(self, module: nn.Module, ema_decay=1., damping=1e-7, 27 | use_max_ema=False, use_sqrt_ema=False, 28 | pi_type=PI_TYPE_TRACENORM): 29 | 30 | if ema_decay < 0 or 1 < ema_decay: 31 | raise ValueError("Invalid ema_decay: {}".format(ema_decay)) 32 | if damping < 0: 33 | raise ValueError("Invalid damping: {}".format(damping)) 34 | if pi_type not in [PI_TYPE_TRACENORM]: 35 | raise ValueError("Invalid pi_type: {}".format(pi_type)) 36 | 37 | self._module = module 38 | self.ema_decay = ema_decay 39 | self._damping = damping 40 | self._l2_reg = 0 41 | self._l2_reg_ema = 0 42 | 43 | self._data = None 44 | self._acc_data = None 45 | self.ema = None 46 | self.ema_max = None 47 | self.inv = None 48 | self.std = None 49 | 50 | self.use_sqrt_ema = use_sqrt_ema 51 | self.use_max_ema = use_max_ema 52 | 53 | self.pi_type = pi_type 54 | 55 | module.register_forward_hook(self.forward_postprocess) 56 | module.register_backward_hook(self.backward_postprocess) 57 | 58 | @property 59 | def data(self): 60 | return self._data 61 | 62 | @data.setter 63 | def data(self, value): 64 | self._data = value 65 | 66 | @property 67 | def shape(self): 68 | if self._data is None: 69 | return self._get_shape() 70 | 71 | return tuple([d.shape for d in self._data]) 72 | 73 | @property 74 | def device(self): 75 | return next(self._module.parameters()).device 76 | 77 | def _get_shape(self): 78 | size = 0 79 | for p in self._module.parameters(): 80 | size += p.view(-1).shape[0] 81 | 82 | return tuple((size, size)) 83 | 84 | def element_wise_init(self, value): 85 | init_data = [] 86 | for s in self.shape: 87 | diag = torch.ones(s[0], device=self.device).mul(value) # 1d 88 | diag = torch.diag(diag) # 1d -> 2d 89 | init_data.append(diag) 90 | 91 | self._data = init_data 92 | 93 | @property 94 | def module(self): 95 | return self._module 96 | 97 | @property 98 | def bias(self): 99 | bias = getattr(self._module, 'bias', None) 100 | return False if bias is None else True 101 | 102 | @property 103 | def damping(self): 104 | return self._damping + self._l2_reg_ema 105 | 106 | @property 107 | def l2_reg(self): 108 | return self._l2_reg 109 | 110 | @l2_reg.setter 111 | def l2_reg(self, value): 112 | self._l2_reg = value 113 | 114 | @property 115 | def l2_reg_ema(self): 116 | return self._l2_reg_ema 117 | 118 | def forward_postprocess(self, module, input, output): 119 | assert self._module == module 120 | 121 | data_input = input[0].detach() 122 | 123 | if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 124 | bnorm = module 125 | f = bnorm.num_features 126 | if isinstance(module, nn.BatchNorm1d): 127 | shape = (1, f) 128 | elif isinstance(module, nn.BatchNorm2d): 129 | shape = (1, f, 1, 1) 130 | else: 131 | shape = (1, f, 1, 1, 1) 132 | # restore normalized input 133 | data_input_norm = (output - bnorm.bias.view(shape)).div(bnorm.weight.view(shape)) 134 | data_input = data_input_norm 135 | 136 | setattr(module, 'data_input', data_input) 137 | setattr(module, 'data_output', output) 138 | 139 | self.update_in_forward(data_input) 140 | 141 | def backward_postprocess(self, module, grad_input, grad_output): 142 | assert self._module == module 143 | 144 | index = 1 if self.bias else 0 145 | grad_input = None if grad_input[index] is None else grad_input[index].detach() 146 | grad_output = grad_output[0] 147 | 148 | setattr(module, 'grad_input', grad_input) 149 | setattr(module, 'grad_output', grad_output) 150 | 151 | self.update_in_backward(grad_output) 152 | 153 | # adjust grad scale along with 'reduction' in loss function 154 | batch_size = grad_output.shape[0] 155 | self.adjust_data_scale(batch_size**2) 156 | 157 | def adjust_data_scale(self, scale): 158 | self._data = [d.mul(scale) for d in self._data] 159 | 160 | def update_in_forward(self, data_input): 161 | pass 162 | 163 | def update_in_backward(self, grad_output): 164 | raise NotImplementedError 165 | 166 | def step(self, update_std=False, update_inv=True): 167 | # TODO(oosawak): Add check for ema/inv timing 168 | self.update_ema() 169 | if update_inv: 170 | self.update_inv() 171 | if update_std: 172 | self.update_std() 173 | 174 | def update_ema(self): 175 | data = self.data 176 | ema = self.ema 177 | ema_max = self.ema_max 178 | beta = self.ema_decay 179 | if ema is None or beta == 1: 180 | self.ema = [d.clone() for d in data] 181 | if self.use_max_ema and ema_max is None: 182 | self.ema_max = [e.clone() for e in self.ema] 183 | self._l2_reg_ema = self._l2_reg 184 | else: 185 | self.ema = [d.mul(beta).add(1 - beta, e) 186 | for d, e in zip(data, ema)] 187 | self._l2_reg_ema = self._l2_reg * beta + self._l2_reg_ema * (1 - beta) 188 | 189 | if self.use_max_ema: 190 | for e, e_max in zip(self.ema, self.ema_max): 191 | torch.max(e, e_max, out=e_max) 192 | 193 | def update_inv(self): 194 | ema = self.ema if not self.use_max_ema else self.ema_max 195 | self.inv = [self._inv(e) for e in ema] 196 | 197 | def _inv(self, X): 198 | X_damp = add_value_to_diagonal(X, self.damping) 199 | 200 | return torchsso.utils.inv(X_damp) 201 | 202 | def precondition_grad(self, params): 203 | raise NotImplementedError 204 | 205 | def update_std(self): 206 | raise NotImplementedError 207 | 208 | def sample_params(self, params, mean, std_scale): 209 | raise NotImplementedError 210 | 211 | def std_norm(self): 212 | raise NotImplementedError 213 | 214 | 215 | class DiagCurvature(Curvature): 216 | 217 | def _get_shape(self): 218 | return tuple(p.shape for p in self.module.parameters()) 219 | 220 | def element_wise_init(self, value): 221 | self._data = [torch.ones(s, device=self.device).mul(value) for s in self.shape] 222 | 223 | def update_in_backward(self, grad_output_data): 224 | raise NotImplementedError 225 | 226 | def _inv(self, X): 227 | if self.use_sqrt_ema: 228 | X = X.sqrt() 229 | 230 | X_damp = X.add(X.new_ones(X.shape).mul(self.damping)) 231 | 232 | return 1 / X_damp 233 | 234 | def precondition_grad(self, params): 235 | for p, inv in zip(params, self.inv): 236 | preconditioned_grad = inv.mul(p.grad) 237 | 238 | p.grad.copy_(preconditioned_grad) 239 | 240 | def update_std(self): 241 | self.std = [inv.sqrt() for inv in self.inv] 242 | 243 | def sample_params(self, params, mean, std_scale): 244 | for p, m, std in zip(params, mean, self.std): 245 | noise = torch.randn_like(m) 246 | p.data.copy_(torch.addcmul(m, std_scale, noise, std)) 247 | 248 | def std_norm(self): 249 | if self.std is None: 250 | return 0 251 | 252 | return sum(std.norm().item() for std in self.std) 253 | 254 | 255 | class KronCurvature(Curvature): 256 | 257 | def __init__(self, *args, **kwargs): 258 | super(KronCurvature, self).__init__(*args, **kwargs) 259 | 260 | self._A = None 261 | self._G = None 262 | 263 | @property 264 | def data(self): 265 | return [self._A, self._G] 266 | 267 | @data.setter 268 | def data(self, value): 269 | self._A, self._G = value 270 | 271 | @property 272 | def shape(self): 273 | if self._A is None or self._G is None: 274 | return self._get_shape() 275 | 276 | return self._A.shape, self._G.shape 277 | 278 | def _get_shape(self): 279 | raise NotImplementedError 280 | 281 | def element_wise_init(self, value): 282 | super(KronCurvature, self).element_wise_init(math.sqrt(value)) 283 | self._A, self._G = self._data 284 | 285 | @property 286 | def A(self): 287 | return self._A 288 | 289 | @property 290 | def G(self): 291 | return self._G 292 | 293 | def update_in_forward(self, input_data): 294 | raise NotImplementedError 295 | 296 | def update_in_backward(self, grad_output_data): 297 | raise NotImplementedError 298 | 299 | def adjust_data_scale(self, scale): 300 | self._G.mul_(scale) 301 | 302 | def update_inv(self): 303 | A, G = self.ema 304 | 305 | if self.pi_type == PI_TYPE_TRACENORM: 306 | pi = torch.sqrt((A.trace()/A.shape[0])/(G.trace()/G.shape[0])) 307 | else: 308 | pi = 1. 309 | 310 | r = self.damping**0.5 311 | self.inv = [torchsso.utils.inv(add_value_to_diagonal(X, value)) 312 | for X, value in zip([A, G], [r*pi, r/pi])] 313 | 314 | def precondition_grad(self, params): 315 | raise NotImplementedError 316 | 317 | def update_std(self): 318 | A_inv, G_inv = self.inv 319 | 320 | self.std = [torchsso.utils.cholesky(X) 321 | for X in [A_inv, G_inv]] 322 | 323 | def sample_params(self, params, mean, std_scale): 324 | raise NotImplementedError 325 | 326 | def std_norm(self): 327 | if self.std is None: 328 | return 0 329 | 330 | A_ic, G_ic = self.std 331 | return A_ic.norm().item() * G_ic.norm().item() 332 | 333 | 334 | def add_value_to_diagonal(X, value): 335 | if torch.cuda.is_available(): 336 | indices = torch.cuda.LongTensor([[i, i] for i in range(X.shape[0])]) 337 | else: 338 | indices = torch.LongTensor([[i, i] for i in range(X.shape[0])]) 339 | values = X.new_ones(X.shape[0]).mul(value) 340 | return X.index_put(tuple(indices.t()), values, accumulate=True) 341 | -------------------------------------------------------------------------------- /torchsso/curv/fisher/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torchsso.utils import TensorAccumulator 5 | 6 | 7 | class Fisher(object): 8 | 9 | def __init__(self): 10 | self.prob = None 11 | self._do_backward = True 12 | self._acc_cov = TensorAccumulator() 13 | 14 | @property 15 | def do_backward(self): 16 | return self._do_backward 17 | 18 | def turn_on_backward(self): 19 | self._do_backward = True 20 | 21 | def turn_off_backward(self): 22 | self._do_backward = False 23 | 24 | def accumulate_cov(self, cov): 25 | self._acc_cov.update(cov) 26 | 27 | def finalize(self): 28 | return self._acc_cov.get() 29 | 30 | def update_as_presoftmax(self, prob): 31 | raise NotImplementedError('This method supports only torchsso.KronFisherLinear.') 32 | 33 | 34 | def get_closure_for_fisher(optimizer, model, data, target, approx_type=None, num_mc=1): 35 | 36 | _APPROX_TYPE_MC = 'mc' 37 | 38 | def turn_off_param_grad(): 39 | for group in optimizer.param_groups: 40 | group['curv'].turn_on_backward() 41 | for param in group['params']: 42 | param.requires_grad = False 43 | 44 | def turn_on_param_grad(): 45 | for group in optimizer.param_groups: 46 | group['curv'].turn_off_backward() 47 | for param in group['params']: 48 | param.requires_grad = True 49 | 50 | def closure(): 51 | 52 | for group in optimizer.param_groups: 53 | assert isinstance(group['curv'], Fisher), f"Invalid Curvature type: {type(group['curv'])}." 54 | 55 | optimizer.zero_grad() 56 | output = model(data) 57 | prob = F.softmax(output, dim=1) 58 | 59 | is_sampling = approx_type is None or approx_type == _APPROX_TYPE_MC 60 | 61 | if is_sampling: 62 | turn_off_param_grad() 63 | 64 | if approx_type == _APPROX_TYPE_MC: 65 | dist = torch.distributions.Categorical(prob) 66 | _target = dist.sample((num_mc,)) 67 | for group in optimizer.param_groups: 68 | group['curv'].prob = torch.ones_like(prob[:, 0]).div(num_mc) 69 | 70 | for i in range(num_mc): 71 | loss = F.cross_entropy(output, _target[i]) 72 | loss.backward(retain_graph=True) 73 | else: 74 | for i in range(model.num_classes): 75 | for group in optimizer.param_groups: 76 | group['curv'].prob = prob[:, i] 77 | loss = F.cross_entropy(output, torch.ones_like(target).mul(i)) 78 | loss.backward(retain_graph=True) 79 | 80 | turn_on_param_grad() 81 | 82 | else: 83 | raise ValueError('Invalid approx type: {}'.format(approx_type)) 84 | 85 | loss = F.cross_entropy(output, target) 86 | loss.backward() 87 | 88 | return loss, output 89 | 90 | return closure 91 | -------------------------------------------------------------------------------- /torchsso/curv/fisher/batchnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsso import DiagCovBatchNorm2d, Fisher 3 | 4 | 5 | class DiagFisherBatchNorm2d(DiagCovBatchNorm2d, Fisher): 6 | 7 | def __init__(self, *args, **kwargs): 8 | DiagCovBatchNorm2d.__init__(self, *args, **kwargs) 9 | Fisher.__init__(self) 10 | 11 | def update_in_backward(self, grad_out): 12 | if self.do_backward: 13 | assert self.prob is not None 14 | data_input = getattr(self._module, 'data_input', None) # n x c x h x w 15 | assert data_input is not None 16 | 17 | n = grad_out.shape[0] # n x c x h x w 18 | pg = torch.mul(grad_out, self.prob.reshape(n, 1, 1, 1)) 19 | 20 | grad_grad = pg.mul(grad_out).sum(dim=(2, 3)) # n x c 21 | in_in = data_input.mul(data_input).sum(dim=(2, 3)) # n x c 22 | 23 | data_w = in_in.mul(grad_grad).mean(dim=0) # c x 1 24 | 25 | self._data = [data_w] 26 | 27 | if self.bias: 28 | data_b = grad_grad.mean(dim=0) # c x 1 29 | self._data.append(data_b) 30 | self.accumulate_cov(self._data) 31 | else: 32 | self._data = self.finalize() 33 | -------------------------------------------------------------------------------- /torchsso/curv/fisher/conv.py: -------------------------------------------------------------------------------- 1 | from torchsso import DiagCovConv2d, KronCovConv2d, Fisher 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class DiagFisherConv2d(DiagCovConv2d, Fisher): 7 | 8 | def __init__(self, *args, **kwargs): 9 | DiagCovConv2d.__init__(self, *args, **kwargs) 10 | Fisher.__init__(self) 11 | 12 | def update_in_backward(self, grad_output): 13 | 14 | if self.do_backward: 15 | assert self.prob is not None 16 | 17 | conv2d = self._module 18 | data_input = getattr(conv2d, 'data_input', None) # n x c_in x h_in x w_in 19 | assert data_input is not None 20 | 21 | # n x (c_in)(k_h)(k_w) x (h_out)(w_out) 22 | input2d = F.unfold(data_input, 23 | kernel_size=conv2d.kernel_size, stride=conv2d.stride, 24 | padding=conv2d.padding, dilation=conv2d.dilation) 25 | 26 | # n x c_out x h_out x w_out 27 | n, c_out, h, w = grad_output.shape 28 | # n x c_out x (h_out)(w_out) 29 | grad_output2d = grad_output.reshape(n, c_out, -1) 30 | 31 | grad_in = torch.einsum('bik,bjk->bij', 32 | grad_output2d, input2d) # n x c_out x (c_in)(k_h)(k_w) 33 | 34 | pgi = torch.mul(grad_in, self.prob.reshape(n, 1, 1)) 35 | data_w = pgi.mul(grad_in).mean(dim=0) # c_out x (c_in)(k_h)(k_w) 36 | data_w = data_w.reshape((c_out, -1, *conv2d.kernel_size)) # c_out x c_in x k_h x k_w 37 | self._data = [data_w] 38 | 39 | if self.bias: 40 | pg = torch.mul(grad_output2d, self.prob.reshape(n, 1, 1)) 41 | grad_grad = pg.mul(grad_output2d) # n x c_out x (h_out)(w_out) 42 | data_b = grad_grad.sum(dim=2).mean(dim=0) # c_out 43 | self._data.append(data_b) 44 | 45 | self.accumulate_cov(self._data) 46 | else: 47 | self._data = self.finalize() 48 | 49 | 50 | class KronFisherConv2d(KronCovConv2d, Fisher): 51 | 52 | def __init__(self, *args, **kwargs): 53 | KronCovConv2d.__init__(self, *args, **kwargs) 54 | Fisher.__init__(self) 55 | 56 | def update_in_backward(self, grad_output): 57 | if self.do_backward: 58 | assert self.prob is not None 59 | n, c, h, w = grad_output.shape # n x c_out x h_out x w_out 60 | 61 | pg = torch.mul(grad_output, self.prob.reshape(n, 1, 1, 1)) 62 | pm = pg.transpose(0, 1).reshape(c, -1) # c_out x n(h_out)(w_out) 63 | m = grad_output.transpose(0, 1).reshape(c, -1) # c_out x n(h_out)(w_out) 64 | 65 | G = torch.einsum('ik,jk->ij', pm, m).div(n*h*w) # c_out x c_out 66 | self._G = G 67 | self.accumulate_cov(G) 68 | else: 69 | self._G = self.finalize() 70 | -------------------------------------------------------------------------------- /torchsso/curv/fisher/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsso import DiagCovLinear, KronCovLinear, Fisher 3 | 4 | 5 | class DiagFisherLinear(DiagCovLinear, Fisher): 6 | 7 | def __init__(self, *args, **kwargs): 8 | DiagCovLinear.__init__(self, *args, **kwargs) 9 | Fisher.__init__(self) 10 | 11 | def update_in_backward(self, grad_output): 12 | if self.do_backward: 13 | assert self.prob is not None 14 | 15 | data_input = getattr(self._module, 'data_input', None) # n x f_in 16 | assert data_input is not None 17 | 18 | n = data_input.shape[0] 19 | 20 | in_in = data_input.mul(data_input) # n x f_in 21 | 22 | pg = torch.mul(grad_output, self.prob.reshape(n, 1)) 23 | grad_grad = pg.mul(grad_output) # n x f_out 24 | 25 | data_w = torch.einsum('ki,kj->ij', grad_grad, 26 | in_in).div(n) # f_out x f_in 27 | self._data = [data_w] 28 | 29 | if self.bias: 30 | data_b = grad_grad.mean(dim=0) # f_out x 1 31 | self._data.append(data_b) 32 | 33 | self.accumulate_cov(self._data) 34 | else: 35 | self._data = self.finalize() 36 | 37 | 38 | class KronFisherLinear(KronCovLinear, Fisher): 39 | 40 | def __init__(self, *args, **kwargs): 41 | KronCovLinear.__init__(self, *args, **kwargs) 42 | Fisher.__init__(self) 43 | 44 | def update_in_backward(self, grad_output): 45 | if self.do_backward: 46 | assert self.prob is not None 47 | n = grad_output.shape[0] # n x f_out 48 | 49 | pg = torch.mul(grad_output, self.prob.reshape(n, 1)) 50 | 51 | # f_out x f_out 52 | G = torch.einsum( 53 | 'ki,kj->ij', pg, grad_output).div(n) 54 | self._G = G 55 | self.accumulate_cov(G) 56 | else: 57 | self._G = self.finalize() 58 | 59 | def update_as_presoftmax(self, prob): 60 | n, dim = prob.shape 61 | cov = torch.einsum('ki,kj->ij', prob, prob).div(n) 62 | fisher_presoftmax = (torch.diag(prob.sum(dim=0)) - cov).div(n) 63 | self._G = fisher_presoftmax 64 | 65 | -------------------------------------------------------------------------------- /torchsso/curv/hessian/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsso import KronCurvature 3 | 4 | 5 | class KronHessian(KronCurvature): 6 | 7 | def update_in_forward(self, input_data): 8 | raise NotImplementedError 9 | 10 | def update_in_backward(self, grad_output): 11 | output = getattr(self._module, 'data_output') 12 | 13 | device = grad_output.device 14 | n = grad_output.shape[0] 15 | dim = grad_output.shape[1] 16 | 17 | post_curv = self.post_curv 18 | 19 | if post_curv is not None: 20 | post_module = post_curv.module 21 | 22 | import time 23 | 24 | print('-----------------') 25 | start = time.time() 26 | 27 | print(self.module) 28 | 29 | if post_curv is not None: 30 | post_module = post_curv.module 31 | print(post_module) 32 | 33 | post_output = getattr(post_module, 'data_output') 34 | post_dim = post_output.shape[1] 35 | 36 | post_out_grad_out = torch.zeros((n, post_dim, dim)) # n x post_dim x dim 37 | if post_dim <= dim: 38 | post_output = reshape_4d_to_2d(post_output) 39 | print('n: {}, dim: {}'.format(len(post_output), post_dim)) 40 | for i in range(post_dim): 41 | outputs = tuple(po[i] for po in post_output) 42 | grad = torch.autograd.grad(outputs, output, create_graph=True) 43 | post_out_grad_out[:, i, :] = reshape_4d_to_2d(grad[0], reduce=True) # n x dim 44 | else: 45 | post_grad_output = getattr(post_module, 'grad_output') 46 | grad_output = reshape_4d_to_2d(grad_output) 47 | print('n: {}, dim: {}'.format(len(grad_output), dim)) 48 | for i in range(dim): 49 | outputs = tuple(g[i] for g in grad_output) 50 | grad = torch.autograd.grad(outputs, post_grad_output, create_graph=True) 51 | post_out_grad_out[:, :, i] = reshape_4d_to_2d(grad[0], reduce=True) # n x post_dim 52 | 53 | post_out_grad_out = post_out_grad_out.to(device) 54 | 55 | recursive_approx = getattr(post_curv, 'recursive_approx', False) 56 | if recursive_approx: 57 | equation = 'bij,ik,bkl->bjl' 58 | post_hessian_output = post_curv.G # post_dim x post_dim 59 | else: 60 | equation = 'bij,bik,bkl->bjl' 61 | post_hessian_output = getattr(post_module, 'hessian_output', None) # n x post_dim x post_dim 62 | 63 | msg = 'hessian of loss w.r.t. outputs of post layer' \ 64 | ' have to be computed beforehand.' 65 | assert post_hessian_output is not None, msg 66 | 67 | # compute sample hessian_output based on hessian_output of post module 68 | hessian_output = torch.einsum(equation, 69 | post_out_grad_out, # n x post_dim x dim 70 | post_hessian_output, # n x post_dim x post_dim 71 | post_out_grad_out) # n x post_dim x dim 72 | 73 | del post_module.hessian_output 74 | del post_out_grad_out 75 | 76 | else: 77 | # compute sample hessian_output from scratch 78 | hessian_output = torch.zeros((n, dim, dim)) 79 | print('n: {}, dim: {}'.format(len(grad_output), dim)) 80 | for i in range(dim): 81 | outputs = tuple(g[i] for g in reshape_4d_to_2d(grad_output)) 82 | grad = torch.autograd.grad(outputs, output, create_graph=True) 83 | hessian_output[:, i, :] = reshape_4d_to_2d(grad[0], reduce=True) 84 | 85 | hessian_output = hessian_output.to(device) 86 | setattr(self._module, 'hessian_output', hessian_output) 87 | 88 | # refresh hessian_output 89 | self._G = hessian_output.sum((0,)) # dim x dim 90 | 91 | elapsed = time.time() - start 92 | print('{}s'.format(elapsed)) 93 | 94 | def precondition_grad(self, params): 95 | raise NotImplementedError 96 | 97 | def sample_params(self, params, mean, std_scale): 98 | raise NotImplementedError 99 | 100 | def backward_postprocess(self, module, grad_input, grad_output): 101 | # skip hook for higher order derivative 102 | order = getattr(module, 'derivative_order', 1) 103 | if order > 1: 104 | return 105 | 106 | super(KronHessian, self).backward_postprocess(module, grad_input, grad_output) 107 | 108 | # skip hook for higher order derivative 109 | setattr(module, 'derivative_order', 2) 110 | 111 | def reset_derivative_order(self): 112 | module = self._module 113 | setattr(module, 'derivative_order', 1) 114 | 115 | def step(self, update_std=False): 116 | super(KronHessian, self).step(update_std) 117 | self.reset_derivative_order() 118 | 119 | 120 | def reshape_4d_to_2d(data, reduce=False): 121 | ndim = len(data.shape) 122 | if ndim == 2: 123 | return data 124 | 125 | assert ndim == 4, 'number of dimension of data is expected to be 4, got {}.'.format(ndim) 126 | 127 | if reduce: 128 | # n x c x h x w -> n x c 129 | return data.sum((2, 3)) 130 | else: 131 | n, c, h, w = data.shape 132 | # n x c x h x w -> n x h x w x c -> n*h*w x c 133 | data = data.transpose(1, 2).transpose(2, 3).contiguous().view(n*h*w, c) 134 | return data 135 | 136 | -------------------------------------------------------------------------------- /torchsso/curv/hessian/conv.py: -------------------------------------------------------------------------------- 1 | from torchsso import KronCovConv2d, KronHessian 2 | 3 | 4 | class KronHessianConv2d(KronCovConv2d, KronHessian): 5 | 6 | def __init__(self, module, ema_decay=1., damping=0, post_curv=None, recursive_approx=False): 7 | KronHessian.__init__(self, module, ema_decay, damping, post_curv, recursive_approx) 8 | 9 | def update_in_backward(self, grad_output): 10 | KronHessian.update_in_backward(self, grad_output) 11 | -------------------------------------------------------------------------------- /torchsso/curv/hessian/linear.py: -------------------------------------------------------------------------------- 1 | from torchsso import KronCovLinear, KronHessian 2 | 3 | 4 | class KronHessianLinear(KronCovLinear, KronHessian): 5 | 6 | def __init__(self, module, ema_decay=1., damping=0, post_curv=None, recursive_approx=False): 7 | KronHessian.__init__(self, module, ema_decay, damping, post_curv, recursive_approx) 8 | 9 | def update_in_backward(self, grad_output): 10 | KronHessian.update_in_backward(self, grad_output) 11 | -------------------------------------------------------------------------------- /torchsso/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from torchsso.optim.firstorder import DistributedFirstOrderOptimizer # NOQA 2 | from torchsso.optim.secondorder import SecondOrderOptimizer, DistributedSecondOrderOptimizer # NOQA 3 | from torchsso.optim.vi import VIOptimizer, DistributedVIOptimizer, VOGN # NOQA 4 | from torchsso.optim import lr_scheduler # NOQA 5 | -------------------------------------------------------------------------------- /torchsso/optim/firstorder.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | from torch.nn.utils import parameters_to_vector, vector_to_parameters 3 | 4 | 5 | class DistributedFirstOrderOptimizer(Optimizer): 6 | 7 | def __init__(self, optimizer, model, dist, lars=False): 8 | super(DistributedFirstOrderOptimizer, self).__setattr__( 9 | 'actual_optimizer', optimizer 10 | ) 11 | super(DistributedFirstOrderOptimizer, self).__setattr__( 12 | 'model', model 13 | ) 14 | super(DistributedFirstOrderOptimizer, self).__setattr__( 15 | 'dist', dist 16 | ) 17 | super(DistributedFirstOrderOptimizer, self).__setattr__( 18 | 'lars', lars 19 | ) 20 | 21 | def step(self, closure=None, thr=1e-2, eps=1e-9): 22 | loss = None 23 | if closure is not None: 24 | loss = closure() 25 | world_size = self.dist.get_world_size() 26 | grads = [p.grad for p in self.model.parameters()] 27 | # pack 28 | packed_tensor = parameters_to_vector(grads) 29 | # all reduce 30 | self.dist.all_reduce(packed_tensor) 31 | # unpack 32 | vector_to_parameters(packed_tensor.div_(world_size), grads) 33 | 34 | if self.lars: 35 | for group in self.param_groups: 36 | for p in group['params']: 37 | setattr(p, 'data_pre', p.data.detach().clone()) 38 | 39 | self.actual_optimizer.step(closure=None) 40 | 41 | if self.lars: 42 | for group in self.param_groups: 43 | for p in group['params']: 44 | d_norm_pre = p.data_pre.norm() 45 | if d_norm_pre > thr: 46 | upd = p.data - p.data_pre 47 | upd_norm = upd.norm() 48 | rate = group['lr'] * d_norm_pre / (upd_norm + eps) 49 | p.data = p.data_pre.add(rate, upd) 50 | 51 | return loss 52 | 53 | def __getattr__(self, item): 54 | return getattr(self.actual_optimizer, item) 55 | 56 | def __setattr__(self, key, value): 57 | if key == 'step': 58 | super().__setattr__(key, value) 59 | else: 60 | setattr(self.actual_optimizer, key, value) 61 | -------------------------------------------------------------------------------- /torchsso/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | 3 | 4 | class _IterLRScheduler(object): 5 | def __init__(self, optimizer, last_iter=-1): 6 | if not isinstance(optimizer, Optimizer): 7 | raise TypeError('{} is not an Optimizer'.format( 8 | type(optimizer).__name__)) 9 | self.optimizer = optimizer 10 | if last_iter == -1: 11 | for group in optimizer.param_groups: 12 | group.setdefault('initial_lr', group['lr']) 13 | else: 14 | for i, group in enumerate(optimizer.param_groups): 15 | if 'initial_lr' not in group: 16 | raise KeyError("param 'initial_lr' is not specified " 17 | "in param_groups[{}] when resuming an optimizer".format(i)) 18 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 19 | self.step(last_iter + 1) 20 | self.last_iter = last_iter 21 | self.scheduler_type = 'iter' 22 | 23 | def state_dict(self): 24 | """Returns the state of the scheduler as a :class:`dict`. 25 | It contains an entry for every variable in self.__dict__ which 26 | is not the optimizer. 27 | """ 28 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 29 | 30 | def load_state_dict(self, state_dict): 31 | """Loads the schedulers state. 32 | Arguments: 33 | state_dict (dict): scheduler state. Should be an object returned 34 | from a call to :meth:`state_dict`. 35 | """ 36 | self.__dict__.update(state_dict) 37 | 38 | def get_lr(self): 39 | raise NotImplementedError 40 | 41 | def step(self, iter=None): 42 | if iter is None: 43 | iter = self.last_iter + 1 44 | self.last_iter = iter 45 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 46 | param_group['lr'] = lr 47 | 48 | 49 | class PolynomialDecayIterLR(_IterLRScheduler): 50 | """Set the learning rate of each parameter group to the initial lr decayed 51 | by gamma every iter. When last_iter=-1, sets initial lr as lr. 52 | Args: 53 | optimizer (Optimizer): Wrapped optimizer. 54 | gamma (float): Multiplicative factor of learning rate decay. 55 | last_iter (int): The index of last iter. Default: -1. 56 | """ 57 | 58 | def __init__(self, optimizer, rate, max_count, target=None, start_iter=0, last_iter=-1): 59 | self.rate = rate 60 | self.max_count = max_count 61 | self.target = target 62 | self.start_iter = start_iter 63 | super(PolynomialDecayIterLR, self).__init__(optimizer, last_iter) 64 | 65 | def get_lr(self): 66 | if self.last_iter < self.start_iter: 67 | return [param_group['lr'] 68 | for param_group in self.optimizer.param_groups] 69 | decay = max(1-(self.last_iter-self.start_iter) / (self.max_count-self.start_iter), 0) 70 | if self.target is not None: 71 | if self.rate > 0: 72 | return [self.target if self.target / (base_lr * decay ** self.rate) > 1 73 | else base_lr * decay ** self.rate 74 | for base_lr in self.base_lrs] 75 | else: 76 | return [self.target if self.target / (base_lr * decay ** self.rate) < 1 77 | else base_lr * decay ** self.rate 78 | for base_lr in self.base_lrs] 79 | return [base_lr * decay ** self.rate 80 | for base_lr in self.base_lrs] 81 | 82 | 83 | class GradualWarmupIterLR(_IterLRScheduler): 84 | """Set the learning rate of each parameter group to the initial lr decayed 85 | by gamma every iter. When last_iter=-1, sets initial lr as lr. 86 | Args: 87 | optimizer (Optimizer): Wrapped optimizer. 88 | gamma (float): Multiplicative factor of learning rate decay. 89 | last_iter (int): The index of last iter. Default: -1. 90 | """ 91 | 92 | def __init__(self, optimizer, initial_lr, max_count, last_iter=-1): 93 | self.initial_lr = initial_lr 94 | self.max_count = max_count 95 | super(GradualWarmupIterLR, self).__init__(optimizer, last_iter) 96 | 97 | def get_lr(self): 98 | if self.last_iter > self.max_count: 99 | return [param_group['lr'] 100 | for param_group in self.optimizer.param_groups] 101 | else: 102 | alpha = self.last_iter / self.max_count 103 | return [self.initial_lr*(1-alpha) + base_lr*alpha 104 | for base_lr in self.base_lrs] 105 | 106 | 107 | class MomentumCorrectionLR(object): 108 | 109 | def __init__(self, scheduler): 110 | super(MomentumCorrectionLR, self).__setattr__( 111 | 'scheduler', scheduler) 112 | 113 | for group in self.optimizer.param_groups: 114 | group['init_momentum'] = group['momentum'] 115 | 116 | def step(self, count=None): 117 | self.scheduler.step(count) 118 | 119 | for group in self.optimizer.param_groups: 120 | lr = group['lr'] 121 | lr_pre = group.get('lr_pre', None) 122 | 123 | if lr_pre is not None: 124 | m = group.get('init_momentum', 0) 125 | group['momentum'] = m * lr / lr_pre 126 | 127 | group['lr_pre'] = group['lr'] 128 | 129 | def __getattr__(self, item): 130 | return getattr(self.scheduler, item) 131 | 132 | def __setattr__(self, key, value): 133 | setattr(self.scheduler, key, value) 134 | -------------------------------------------------------------------------------- /torchsso/optim/secondorder.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import math 3 | 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.optim import Optimizer 9 | import torchsso 10 | from torchsso.utils import TensorAccumulator 11 | from torchsso.utils.chainer_communicators import create_communicator 12 | from torchsso.utils.chainer_communicators import _utility 13 | 14 | 15 | class SecondOrderOptimizer(Optimizer): 16 | r"""An optimizer for Second-Order Optimization. 17 | 18 | This optimizer manages the curvatures for each layer as a collection 19 | of torchsso.Curvature instance. 20 | This optimizer updates the params with the gradients pre-conditioned 21 | by the inverse of the curvature for each layer. 22 | 23 | Args: 24 | model (torch.nn.Module): model with parameters to be trained 25 | curv_type (str): type of the curvature ('Hessian', 'Fisher', or 'Cov') 26 | curv_shapes (dict): shape the curvatures for each type of layer 27 | curv_kwargs (dict): arguments (with keys) to be passed to torchsso.Curvature.__init__() 28 | lr (float, optional): learning rate 29 | momentum (float, optional): momentum factor 30 | momentum_type (str, optional): type of gradients of which momentum 31 | is calculated ('raw' or 'preconditioned') 32 | grad_ema_decay (float, optional): decay rate for EMA of gradients 33 | grad_ema_type (str, optional): type of gradients of which EMA 34 | is calculated ('raw' or 'preconditioned') 35 | l2_reg (float, optional): L2 penalty 36 | weight_decay (float, optional): weight decay 37 | normalizing_weights (bool, optional): whether the scale of the params 38 | are normalized after each step 39 | weight_scale (float, optional): the scale of the params for normalizing weights 40 | acc_steps (int, optional): number of steps for which gradients and curvatures 41 | are accumulated before each step 42 | non_reg_for_bn (bool, optional): whether the regularization is applied to BatchNorm params 43 | bias_correction (bool, optional): whether the bias correction (refer torch.optim.Adam) is applied 44 | lars (bool, optional): whether LARS (https://arxiv.org/abs/1708.03888) is applied 45 | lars_type (str, optional): type of gradients of which LARS 46 | is applied ('raw' or 'preconditioned') 47 | update_inv (bool, optional): whether to update curvature inverses at each step 48 | precondition_grad (bool, optional): whether to apply preconditioning 49 | (if False, this optimizer works as SGD) 50 | 51 | Example: 52 | >>> curv_shapes = {"Conv2d": "Kron", "Linear": "Diag"} 53 | >>> curv_kwargs = {"damping": 1e-3, "ema_decay": 0.999} 54 | >>> optimizer = torchsso.optim.SecondOrderOptimizer(model, "Cov", curv_shapes, curv_kwargs) 55 | >>> 56 | >>> def closure(): 57 | >>> optimizer.zero_grad() 58 | >>> output = model(data) 59 | >>> loss = F.cross_entropy(output, target) 60 | >>> loss.backward(create_graph=args.create_graph) 61 | >>> return loss, output 62 | >>> 63 | >>> optimizer.step(closure=closure) 64 | """ 65 | 66 | def __init__(self, model: nn.Module, curv_type: str, curv_shapes: dict, curv_kwargs: dict, 67 | lr=0.01, momentum=0., momentum_type='preconditioned', 68 | grad_ema_decay=1., grad_ema_type='raw', l2_reg=0., weight_decay=0., 69 | normalizing_weights=False, weight_scale=None, 70 | acc_steps=1, non_reg_for_bn=False, bias_correction=False, 71 | lars=False, lars_type='preconditioned', update_inv=True, precondition_grad=True): 72 | 73 | if lr < 0: 74 | raise ValueError("Invalid learning rate: {}".format(lr)) 75 | if momentum < 0: 76 | raise ValueError("Invalid momentum: {}".format(momentum)) 77 | if momentum > 0 and momentum_type not in ['raw', 'preconditioned']: 78 | raise ValueError("Invalid momentum type: {}".format(momentum_type)) 79 | if grad_ema_decay < 0 or 1 < grad_ema_decay: 80 | raise ValueError("Invalid grad_ema value: {}".format(grad_ema_decay)) 81 | if grad_ema_decay > 0 and grad_ema_type not in ['raw', 'preconditioned']: 82 | raise ValueError("Invalid grad_ema type: {}".format(grad_ema_type)) 83 | if l2_reg < 0: 84 | raise ValueError("Invalid l2_reg value: {}".format(l2_reg)) 85 | if weight_decay < 0: 86 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 87 | if acc_steps < 1: 88 | raise ValueError("Invalid acc_steps: {}".format(acc_steps)) 89 | if lars and lars_type not in ['raw', 'preconditioned']: 90 | raise ValueError("Invalid LARS type: {}".format(lars_type)) 91 | if normalizing_weights and weight_scale is not None and weight_scale <= 0: 92 | raise ValueError("Invalid weight scale for LARS: {}".format(weight_scale)) 93 | 94 | self.model = model 95 | defaults = {'lr': lr, 'momentum': momentum, 'momentum_type': momentum_type, 96 | 'grad_ema_decay': grad_ema_decay, 'grad_ema_type': grad_ema_type, 97 | 'l2_reg': l2_reg, 'weight_decay': weight_decay, 98 | 'normalizing_weights': normalizing_weights, 'weight_scale': weight_scale, 99 | 'acc_steps': acc_steps, 'bias_correction': bias_correction, 100 | 'lars': lars, 'lars_type': lars_type} 101 | defaults.update(curv_kwargs) 102 | self.defaults = defaults 103 | self.state = defaultdict(dict) 104 | self.optim_state = {'step': 0, 'acc_step': 0} 105 | 106 | self.param_groups = [] 107 | self.curv_type = curv_type 108 | self.curv_shapes = {} if curv_shapes is None else curv_shapes 109 | self.update_inv = update_inv 110 | self.precondition_grad = precondition_grad 111 | 112 | for module in model.modules(): 113 | if len(list(module.children())) > 0: 114 | continue 115 | params = list(module.parameters()) 116 | if len(params) == 0: 117 | continue 118 | 119 | curv_class = self.get_curv_class(module) 120 | curvature = curv_class(module, **curv_kwargs) 121 | 122 | group = { 123 | 'params': params, 124 | 'curv': curvature, 125 | 'acc_curv': TensorAccumulator(), 126 | 'acc_grads': TensorAccumulator() 127 | } 128 | 129 | self.add_param_group(group) 130 | self.init_buffer(params) 131 | 132 | if non_reg_for_bn and \ 133 | isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 134 | group['l2_reg'] = 0 135 | group['weight_decay'] = 0 136 | group['normalizing_weights'] = False 137 | 138 | def init_buffer(self, params): 139 | for p in params: 140 | state = self.state[p] 141 | state['momentum_buffer'] = torch.zeros_like(p.data) 142 | state['grad_ema_buffer'] = torch.zeros_like(p.data) 143 | 144 | @property 145 | def local_param_groups(self): 146 | return self.param_groups 147 | 148 | def get_curv_class(self, module): 149 | module_name = module.__class__.__name__ 150 | curv_shape = self.curv_shapes.get(module_name, '') 151 | curv_name = curv_shape + self.curv_type + module_name 152 | curv_class = getattr(torchsso, curv_name, None) 153 | 154 | assert curv_class is not None, f"Failed to lookup Curvature class {curv_name} for {module}." 155 | 156 | return curv_class 157 | 158 | def step(self, closure=None): 159 | """Performs a single optimization step. 160 | 161 | Arguments: 162 | closure (callable, optional): A closure that reevaluates the model 163 | and returns the loss. 164 | """ 165 | 166 | n = self.defaults['acc_steps'] 167 | loss = None 168 | 169 | if closure is not None: 170 | # forward and backward 171 | loss = closure() 172 | 173 | # accumulate 174 | for group in self.param_groups: 175 | params = group['params'] 176 | 177 | grads = [p.grad.data for p in params] 178 | group['acc_grads'].update(grads, scale=1/n) 179 | 180 | curv = group['curv'] 181 | if curv is not None: 182 | group['acc_curv'].update(curv.data, scale=1/n) 183 | 184 | # update acc step 185 | self.optim_state['acc_step'] += 1 186 | if self.optim_state['acc_step'] < n: 187 | return loss 188 | else: 189 | self.optim_state['acc_step'] = 0 190 | 191 | self.backward_postprocess() 192 | 193 | self.optim_state['step'] += 1 194 | 195 | for group in self.local_param_groups: 196 | 197 | self.update_preprocess(group, grad_type='raw') 198 | 199 | # update curvature 200 | params, curv = group['params'], group['curv'] 201 | if curv is not None: 202 | curv.step(update_inv=self.update_inv) 203 | if self.precondition_grad: 204 | curv.precondition_grad(params) 205 | 206 | # update params 207 | self.update_preprocess(group, grad_type='preconditioned') 208 | self.update(group) 209 | self.update_postprocess(group) 210 | 211 | return loss 212 | 213 | def backward_postprocess(self, target='params'): 214 | for group in self.param_groups: 215 | params = group[target] 216 | 217 | acc_grads = group['acc_grads'].get() 218 | for p, acc_grad in zip(params, acc_grads): 219 | p.grad = acc_grad.clone() 220 | 221 | curv = group['curv'] 222 | if curv is not None: 223 | curv.data = group['acc_curv'].get() 224 | 225 | def update(self, group, target='params'): 226 | params = group[target] 227 | for p in params: 228 | grad = p.grad 229 | if grad is None: 230 | continue 231 | p.data.add_(-group['lr'], grad) 232 | 233 | def update_preprocess(self, group, target='params', grad_type='raw'): 234 | assert grad_type in ['raw', 'preconditioned'], 'Invalid grad type: {}.'.format(grad_type) 235 | params = group[target] 236 | state = self.state 237 | 238 | def apply_l2_reg(p, grad): 239 | if group['l2_reg'] != 0: 240 | if grad.is_sparse: 241 | raise RuntimeError( 242 | "l2 regularization option is not compatible with sparse gradients") 243 | grad.add_(group['l2_reg'], p.data) 244 | curv = group['curv'] 245 | if curv is not None: 246 | curv.l2_reg = group['l2_reg'] 247 | 248 | def apply_weight_decay(p, grad): 249 | if group['weight_decay'] != 0: 250 | if hasattr(grad, 'is_sparse') and grad.is_sparse: 251 | raise RuntimeError( 252 | "weight_decay option is not compatible with sparse gradients") 253 | grad.add_(group['weight_decay'], p.data) 254 | 255 | def apply_momentum(p, grad): 256 | momentum = group['momentum'] 257 | 258 | if momentum != 0: 259 | buf = state[p]['momentum_buffer'] 260 | buf.mul_(momentum).add_(grad) 261 | grad.copy_(buf) 262 | 263 | def apply_grad_ema_decay(p, grad): 264 | grad_ema_decay = group['grad_ema_decay'] 265 | if grad_ema_decay != 1: 266 | buf = state[p]['grad_ema_buffer'] 267 | buf.mul_(1 - grad_ema_decay).add_(grad.mul(grad_ema_decay)) 268 | grad.copy_(buf) 269 | 270 | def apply_bias_correction(grad): 271 | curv = group['curv'] 272 | beta1 = 1 - group['grad_ema_decay'] 273 | beta2 = 1 - curv.ema_decay 274 | 275 | bias_correction1 = 1 - beta1 ** self.optim_state['step'] 276 | bias_correction2 = 1 - beta2 ** self.optim_state['step'] 277 | if getattr(curv, 'use_sqrt_ema', False): 278 | bias_correction2 = math.sqrt(bias_correction2) 279 | 280 | grad.mul_(bias_correction2 / bias_correction1) 281 | 282 | def apply_lars(p, grad, thr=1e-2, eps=1e-9): 283 | d_norm = p.data.norm() 284 | if d_norm > thr: 285 | g_norm = grad.norm() 286 | rate = d_norm / (g_norm + eps) 287 | grad.mul_(rate) 288 | 289 | for p in params: 290 | 291 | grad = p.grad 292 | 293 | if grad is None: 294 | continue 295 | 296 | if grad_type == 'raw': 297 | apply_l2_reg(p, grad) 298 | 299 | if grad_type == 'preconditioned': 300 | apply_weight_decay(p, grad) 301 | 302 | if group['momentum_type'] == grad_type: 303 | apply_momentum(p, grad) 304 | 305 | if group['grad_ema_type'] == grad_type: 306 | apply_grad_ema_decay(p, grad) 307 | 308 | if grad_type == 'preconditioned' and group['bias_correction']: 309 | apply_bias_correction(grad) 310 | 311 | if group['lars_type'] == grad_type and group['lars']: 312 | apply_lars(p, grad) 313 | 314 | def update_postprocess(self, group, target='params'): 315 | params = group[target] 316 | curv = group['curv'] 317 | 318 | def apply_normalizing_weights(p, thr=1e-2, eps=1e-9): 319 | d_norm = p.data.norm() 320 | if d_norm > thr: 321 | scale = group['weight_scale'] 322 | if scale is None: 323 | scale = np.sqrt(2.0 * w.data.shape[0]) 324 | p.data.div_(d_norm + eps).mul_(scale) 325 | 326 | if group['normalizing_weights']: 327 | for p, _p in zip(params, group['params']): 328 | w = getattr(curv.module, 'weight', None) 329 | if w is not None and w is _p: 330 | apply_normalizing_weights(p) 331 | 332 | 333 | class DistributedSecondOrderOptimizer(SecondOrderOptimizer): 334 | 335 | def __init__(self, *args, **kwargs): 336 | 337 | self.actual_optimizer.__init__(self, *args, **kwargs) 338 | 339 | self.comm = create_communicator() 340 | 341 | local_size = self.comm.size 342 | local_rank = self.comm.rank 343 | indices = np.array_split(np.arange(len(self.param_groups)), local_size) 344 | indices = [local_indices.tolist() for local_indices in indices] 345 | local_indices = indices[local_rank] 346 | local_param_groups = [self.param_groups[i] for i in local_indices] 347 | 348 | self.indices = indices 349 | self.local_indices = local_indices 350 | self._local_param_groups = local_param_groups 351 | setattr(self.comm, 'indices', indices) 352 | 353 | @property 354 | def actual_optimizer(self): 355 | return SecondOrderOptimizer 356 | 357 | @property 358 | def local_param_groups(self): 359 | return self._local_param_groups 360 | 361 | def extractors_for_rsv(self): 362 | extractors = [_utility.extract_attr_from_params('grad'), 363 | _utility.extract_attr_from_curv('data', True)] 364 | return extractors 365 | 366 | def extractors_for_agv(self): 367 | extractors = [_utility.extract_attr_from_params('data')] 368 | return extractors 369 | 370 | def backward_postprocess(self, target='params'): 371 | self.actual_optimizer.backward_postprocess(self, target) 372 | # reduce_scatter_v 373 | self.comm.reduce_scatterv_data(self.param_groups, self.extractors_for_rsv()) 374 | 375 | def is_updated(self): 376 | return self.optim_state['acc_step'] == 0 377 | 378 | def step(self, closure=None): 379 | """Performs a single optimization step. 380 | 381 | Arguments: 382 | closure (callable, optional): A closure that reevaluates the model 383 | and returns the loss. 384 | """ 385 | 386 | ret = self.actual_optimizer.step(self, closure) 387 | 388 | if self.is_updated(): 389 | # all_gather_v 390 | self.comm.allgatherv_data(self.param_groups, self.extractors_for_agv()) 391 | 392 | return ret 393 | 394 | -------------------------------------------------------------------------------- /torchsso/optim/vi.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchsso.optim import SecondOrderOptimizer, DistributedSecondOrderOptimizer 7 | from torchsso.utils import TensorAccumulator 8 | from torchsso.utils.chainer_communicators import _utility 9 | 10 | 11 | class VIOptimizer(SecondOrderOptimizer): 12 | r"""An optimizer for Variational Inference (VI) based on torch.optim.SecondOrderOptimizer. 13 | 14 | This optimizer manages the posterior distribution (mean and covariance of multivariate Gaussian) 15 | of params for each layer. 16 | 17 | Args: 18 | model (torch.nn.Module): model with parameters to be trained 19 | model (float): dataset size 20 | curv_type (str): type of the curvature ('Hessian', 'Fisher', or 'Cov') 21 | curv_shapes (dict): shape the curvatures for each type of layer 22 | curv_kwargs (dict): arguments (with keys) to be passed to torchsso.Curvature.__init__() 23 | lr (float, optional): learning rate 24 | momentum (float, optional): momentum factor 25 | momentum_type (str, optional): type of gradients of which momentum 26 | is calculated ('raw' or 'preconditioned') 27 | grad_ema_decay (float, optional): decay rate for EMA of gradients 28 | grad_ema_type (str, optional): type of gradients of which EMA 29 | is calculated ('raw' or 'preconditioned') 30 | weight_decay (float, optional): weight decay 31 | normalizing_weights (bool, optional): whether the scale of the params 32 | are normalized after each step 33 | weight_scale (float, optional): the scale of the params for normalizing weights 34 | acc_steps (int, optional): number of steps for which gradients and curvatures 35 | are accumulated before each step 36 | non_reg_for_bn (bool, optional): whether the regularization is applied to BatchNorm params 37 | bias_correction (bool, optional): whether the bias correction (refer torch.optim.Adam) is applied 38 | lars (bool, optional): whether LARS (https://arxiv.org/abs/1708.03888) is applied 39 | lars_type (str, optional): type of gradients of which LARS 40 | is applied ('raw' or 'preconditioned') 41 | num_mc_samples (int, optional): number of MC samples taken from the posterior in each step 42 | val_num_mc_samples (int, optional): number of MC samples taken from the posterior for evaluation 43 | kl_weighting (float, optional): KL weighting (https://arxiv.org/abs/1712.02390) 44 | warmup_kl_weighting_init (float, optional): initial KL weighting for warming up the value 45 | warmup_kl_weighting_steps (float, optional): number of steps until the value reaches the kl_weighting 46 | prior_variance (float, optional): variance of the prior distribution (Gaussian) of each param 47 | init_precision (float, optional): initial (diagonal) precision of the posterior of params 48 | """ 49 | 50 | def __init__(self, model: nn.Module, dataset_size: float, curv_type: str, curv_shapes: dict, curv_kwargs: dict, 51 | lr=0.01, momentum=0., momentum_type='preconditioned', 52 | grad_ema_decay=1., grad_ema_type='raw', weight_decay=0., 53 | normalizing_weights=False, weight_scale=None, 54 | acc_steps=1, non_reg_for_bn=False, bias_correction=False, 55 | lars=False, lars_type='preconditioned', 56 | num_mc_samples=10, val_num_mc_samples=10, 57 | kl_weighting=1, warmup_kl_weighting_init=0.01, warmup_kl_weighting_steps=None, 58 | prior_variance=1, init_precision=None, 59 | seed=1, total_steps=1000): 60 | 61 | if dataset_size < 0: 62 | raise ValueError("Invalid dataset size: {}".format(dataset_size)) 63 | if num_mc_samples < 1: 64 | raise ValueError("Invalid number of MC samples: {}".format(num_mc_samples)) 65 | if val_num_mc_samples < 0: 66 | raise ValueError("Invalid number of MC samples for validation: {}".format(val_num_mc_samples)) 67 | if kl_weighting < 0: 68 | raise ValueError("Invalid KL weighting: {}".format(kl_weighting)) 69 | if warmup_kl_weighting_steps is not None and warmup_kl_weighting_init < 0: 70 | raise ValueError("Invalid initial KL weighting: {}".format(warmup_kl_weighting_init)) 71 | if prior_variance < 0: 72 | raise ValueError("Invalid prior variance: {}".format(prior_variance)) 73 | if init_precision is not None and init_precision < 0: 74 | raise ValueError("Invalid initial precision: {}".format(init_precision)) 75 | 76 | init_kl_weighting = kl_weighting if warmup_kl_weighting_steps is None else warmup_kl_weighting_init 77 | l2_reg = init_kl_weighting / dataset_size / prior_variance if prior_variance != 0 else 0 78 | std_scale = math.sqrt(init_kl_weighting / dataset_size) 79 | 80 | super(VIOptimizer, self).__init__(model, curv_type, curv_shapes, curv_kwargs, 81 | lr=lr, momentum=momentum, momentum_type=momentum_type, 82 | grad_ema_decay=grad_ema_decay, grad_ema_type=grad_ema_type, 83 | l2_reg=l2_reg, weight_decay=weight_decay, 84 | normalizing_weights=normalizing_weights, weight_scale=weight_scale, 85 | acc_steps=acc_steps, non_reg_for_bn=non_reg_for_bn, 86 | bias_correction=bias_correction, 87 | lars=lars, lars_type=lars_type) 88 | 89 | self.defaults['std_scale'] = std_scale 90 | self.defaults['kl_weighting'] = kl_weighting 91 | self.defaults['warmup_kl_weighting_init'] = warmup_kl_weighting_init 92 | self.defaults['warmup_kl_weighting_steps'] = warmup_kl_weighting_steps 93 | self.defaults['num_mc_samples'] = num_mc_samples 94 | self.defaults['val_num_mc_samples'] = val_num_mc_samples 95 | self.defaults['total_steps'] = total_steps 96 | self.defaults['seed_base'] = seed 97 | 98 | for group in self.param_groups: 99 | group['std_scale'] = 0 if group['l2_reg'] == 0 else std_scale 100 | group['mean'] = [p.data.detach().clone() for p in group['params']] 101 | self.init_buffer(group['mean']) 102 | 103 | if init_precision is not None: 104 | curv = group['curv'] 105 | curv.element_wise_init(init_precision) 106 | curv.step(update_std=(group['std_scale'] > 0)) 107 | 108 | def zero_grad(self): 109 | r"""Clears the gradients of all optimized :class:`torch.Tensor` s.""" 110 | for group in self.param_groups: 111 | for m in group['mean']: 112 | if m.grad is not None: 113 | m.grad.detach_() 114 | m.grad.zero_() 115 | 116 | super(VIOptimizer, self).zero_grad() 117 | 118 | @property 119 | def seed(self): 120 | return self.optim_state['step'] + self.defaults['seed_base'] 121 | 122 | def set_random_seed(self, seed=None): 123 | if seed is None: 124 | seed = self.seed 125 | torch.manual_seed(seed) 126 | if torch.cuda.is_available(): 127 | torch.cuda.manual_seed_all(seed) 128 | 129 | def sample_params(self): 130 | 131 | for group in self.param_groups: 132 | params, mean = group['params'], group['mean'] 133 | curv = group['curv'] 134 | if curv is not None and curv.std is not None: 135 | # sample from posterior 136 | curv.sample_params(params, mean, group['std_scale']) 137 | else: 138 | for p, m in zip(params, mean): 139 | p.data.copy_(m.data) 140 | 141 | def copy_mean_to_params(self): 142 | for group in self.param_groups: 143 | params, mean = group['params'], group['mean'] 144 | for p, m in zip(params, mean): 145 | p.data.copy_(m.data) 146 | if getattr(p, 'grad', None) is not None \ 147 | and getattr(m, 'grad', None) is not None: 148 | p.grad.copy_(m.grad) 149 | 150 | def adjust_kl_weighting(self): 151 | warmup_steps = self.defaults['warmup_kl_weighting_steps'] 152 | if warmup_steps is None: 153 | return 154 | 155 | current_step = self.optim_state['step'] 156 | if warmup_steps < current_step: 157 | return 158 | 159 | target_kl = self.defaults['kl_weighting'] 160 | init_kl = self.defaults['warmup_kl_weighting_init'] 161 | 162 | rate = current_step / warmup_steps 163 | kl_weighting = init_kl + rate * (target_kl - init_kl) 164 | 165 | rate = kl_weighting / init_kl 166 | l2_reg = rate * self.defaults['l2_reg'] 167 | std_scale = math.sqrt(rate) * self.defaults['std_scale'] 168 | for group in self.param_groups: 169 | if group['l2_reg'] > 0: 170 | group['l2_reg'] = l2_reg 171 | if group['std_scale'] > 0: 172 | group['std_scale'] = std_scale 173 | 174 | def step(self, closure=None): 175 | """Performs a single optimization step. 176 | 177 | Arguments: 178 | closure (callable, optional): A closure that reevaluates the model 179 | and returns the loss. 180 | 181 | def closure(): 182 | # forward/backward 183 | return loss, output 184 | """ 185 | 186 | m = self.defaults['num_mc_samples'] 187 | n = self.defaults['acc_steps'] 188 | 189 | acc_loss = TensorAccumulator() 190 | acc_prob = TensorAccumulator() 191 | 192 | self.set_random_seed() 193 | 194 | for _ in range(m): 195 | 196 | # sampling 197 | self.sample_params() 198 | 199 | # forward and backward 200 | loss, output = closure() 201 | 202 | acc_loss.update(loss, scale=1/m) 203 | if output.ndim == 2: 204 | prob = F.softmax(output, dim=1) 205 | elif output.ndim == 1: 206 | prob = torch.sigmoid(output) 207 | else: 208 | raise ValueError(f'Invalid ndim {output.ndim}') 209 | acc_prob.update(prob, scale=1/n) 210 | 211 | # accumulate 212 | for group in self.param_groups: 213 | params = group['params'] 214 | 215 | grads = [p.grad.data for p in params] 216 | group['acc_grads'].update(grads, scale=1/m/n) 217 | 218 | curv = group['curv'] 219 | if curv is not None: 220 | group['acc_curv'].update(curv.data, scale=1/m/n) 221 | 222 | loss, prob = acc_loss.get(), acc_prob.get() 223 | 224 | # update acc step 225 | self.optim_state['acc_step'] += 1 226 | if self.optim_state['acc_step'] < n: 227 | return loss, prob 228 | else: 229 | self.optim_state['acc_step'] = 0 230 | 231 | self.backward_postprocess(target='mean') 232 | self.optim_state['step'] += 1 233 | 234 | # update distribution 235 | for group in self.local_param_groups: 236 | 237 | self.update_preprocess(group, target='mean', grad_type='raw') 238 | 239 | # update covariance 240 | mean, curv = group['mean'], group['curv'] 241 | if curv is not None: 242 | curv.step(update_std=(group['std_scale'] > 0)) 243 | curv.precondition_grad(mean) 244 | 245 | # update mean 246 | self.update_preprocess(group, target='mean', grad_type='preconditioned') 247 | self.update(group, target='mean') 248 | self.update_postprocess(group, target='mean') 249 | 250 | # copy mean to param 251 | params = group['params'] 252 | for p, m in zip(params, mean): 253 | p.data.copy_(m.data) 254 | p.grad.copy_(m.grad) 255 | 256 | self.adjust_kl_weighting() 257 | 258 | return loss, prob 259 | 260 | def prediction(self, data, mc=None, keep_probs=False): 261 | 262 | self.set_random_seed(self.optim_state['step']) 263 | 264 | acc_prob = TensorAccumulator() 265 | probs = [] 266 | 267 | mc_samples = self.defaults['val_num_mc_samples'] if mc is None else mc 268 | 269 | use_mean = mc_samples == 0 270 | n = 1 if use_mean else mc_samples 271 | 272 | for _ in range(n): 273 | 274 | if use_mean: 275 | self.copy_mean_to_params() 276 | else: 277 | # sampling 278 | self.sample_params() 279 | 280 | output = self.model(data) 281 | if output.ndim == 2: 282 | prob = F.softmax(output, dim=1) 283 | elif output.ndim == 1: 284 | prob = torch.sigmoid(output) 285 | else: 286 | raise ValueError(f'Invalid ndim {output.ndim}') 287 | 288 | acc_prob.update(prob, scale=1/n) 289 | if keep_probs: 290 | probs.append(prob) 291 | 292 | self.copy_mean_to_params() 293 | 294 | prob = acc_prob.get() 295 | 296 | if keep_probs: 297 | return prob, probs 298 | else: 299 | return prob 300 | 301 | 302 | class VOGN(VIOptimizer): 303 | 304 | def __init__(self, *args, **kwargs): 305 | default_kwargs = dict(lr=1e-3, 306 | curv_type='Cov', 307 | curv_shapes={ 308 | 'Linear': 'Diag', 309 | 'Conv2d': 'Diag', 310 | 'BatchNorm1d': 'Diag', 311 | 'BatchNorm2d': 'Diag' 312 | }, 313 | curv_kwargs={'ema_decay': 0.01, 'damping': 1e-7}, 314 | warmup_kl_weighting_init=0.01, warmup_kl_weighting_steps=1000, 315 | grad_ema_decay=0.1, num_mc_samples=50, val_num_mc_samples=100) 316 | 317 | default_kwargs.update(kwargs) 318 | 319 | super(VOGN, self).__init__(*args, **default_kwargs) 320 | 321 | 322 | class DistributedVIOptimizer(DistributedSecondOrderOptimizer, VIOptimizer): 323 | 324 | def __init__(self, *args, mc_group_id=0, **kwargs): 325 | super(DistributedVIOptimizer, self).__init__(*args, **kwargs) 326 | self.defaults['seed_base'] += mc_group_id * self.defaults['total_steps'] 327 | 328 | @property 329 | def actual_optimizer(self): 330 | return VIOptimizer 331 | 332 | def zero_grad(self): 333 | self.actual_optimizer.zero_grad(self) 334 | 335 | def extractors_for_rsv(self): 336 | extractors = [_utility.extract_attr_from_params('grad', target='mean'), 337 | _utility.extract_attr_from_curv('data', True)] 338 | return extractors 339 | 340 | def extractors_for_agv(self): 341 | extractors = [_utility.extract_attr_from_params('data', target='mean'), 342 | _utility.extract_attr_from_curv('std', True)] 343 | return extractors 344 | 345 | def step(self, closure=None): 346 | ret = super(DistributedVIOptimizer, self).step(closure) 347 | 348 | if self.is_updated(): 349 | self.copy_mean_to_params() 350 | 351 | return ret 352 | 353 | -------------------------------------------------------------------------------- /torchsso/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from torchsso.utils.logger import Logger # NOQA 2 | from torchsso.utils.inv_cupy import inv # NOQA 3 | from torchsso.utils.cholesky_cupy import cholesky # NOQA 4 | from torchsso.utils.accumulator import TensorAccumulator # NOQA 5 | -------------------------------------------------------------------------------- /torchsso/utils/accumulator.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class TensorAccumulator(object): 5 | 6 | def __init__(self): 7 | self._accumulation = None 8 | 9 | def check_type(self, data): 10 | accumulation = self._accumulation 11 | 12 | if isinstance(data, list): 13 | assert type(data[0]) == Tensor, 'the type of data has to be list of torch.Tensor or torch.Tensor' 14 | else: 15 | assert type(data) == Tensor, 'the type of data has to be list of torch.Tensor or torch.Tensor' 16 | 17 | if accumulation is not None: 18 | assert type(data) == type(accumulation), \ 19 | 'the type of data ({}) is different from ' \ 20 | 'the type of the accumulation ({})'.format( 21 | type(data), type(accumulation)) 22 | 23 | def update(self, data, scale=1.): 24 | self.check_type(data) 25 | 26 | accumulation = self._accumulation 27 | 28 | if isinstance(data, list): 29 | if accumulation is None: 30 | self._accumulation = [d.mul(scale) for d in data] 31 | else: 32 | self._accumulation = [acc.add(scale, d) 33 | for acc, d in zip(accumulation, data)] 34 | else: 35 | if accumulation is None: 36 | self._accumulation = data.mul(scale) 37 | else: 38 | self._accumulation = accumulation.add(scale, data) 39 | 40 | def get(self, clear=True): 41 | accumulation = self._accumulation 42 | if accumulation is None: 43 | return 44 | 45 | if isinstance(accumulation, list): 46 | data = [d.clone() for d in self._accumulation] 47 | else: 48 | data = accumulation.clone() 49 | 50 | if clear: 51 | self.clear() 52 | 53 | return data 54 | 55 | def clear(self): 56 | self._accumulation = None 57 | 58 | -------------------------------------------------------------------------------- /torchsso/utils/chainer_communicators/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def create_communicator(communicator_name='pure_nccl', 5 | mpi_comm=None, 6 | rsv_comm_dtype=np.float32, 7 | agv_comm_dtype=np.float32, 8 | use_hiercoll=False, 9 | dims=None, 10 | ): 11 | if mpi_comm is None: 12 | import mpi4py.MPI 13 | mpi_comm = mpi4py.MPI.COMM_WORLD 14 | 15 | if communicator_name != 'pure_nccl' and rsv_comm_dtype != np.float32: 16 | raise ValueError( 17 | 'rsv_comm_dtype is only available at \'pure_nccl\' communicator') 18 | 19 | if communicator_name != 'pure_nccl' and agv_comm_dtype != np.float32: 20 | raise ValueError( 21 | 'agv_comm_dtype is only available at \'pure_nccl\' communicator') 22 | 23 | if communicator_name != 'pure_nccl' and dims is not None: 24 | raise ValueError( 25 | 'dims is only available at \'pure_nccl\' communicator') 26 | 27 | if communicator_name == 'pure_nccl': 28 | from torchsso.utils.chainer_communicators.pure_nccl_communicator \ 29 | import PureNCCLCommunicator 30 | return PureNCCLCommunicator(mpi_comm, 31 | rsv_comm_dtype=rsv_comm_dtype, 32 | agv_comm_dtype=agv_comm_dtype, 33 | use_hiercoll=use_hiercoll, 34 | dims=dims 35 | ) 36 | else: 37 | raise ValueError( 38 | 'Unrecognized communicator_name: {}'.format(communicator_name)) 39 | -------------------------------------------------------------------------------- /torchsso/utils/chainer_communicators/_utility.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy 4 | try: 5 | import cupy 6 | from torchsso.utils.cupy import to_cupy 7 | except: 8 | pass 9 | # print("No cupy detected") 10 | 11 | from chainer.backends import cuda 12 | 13 | import torch 14 | 15 | 16 | class Packer(object): 17 | 18 | def __init__(self): 19 | self.unpack_kernel = cupy.ElementwiseKernel( 20 | 'raw T vec, int32 matrix_size', 21 | 'raw T mat', 22 | """ 23 | int x = i % matrix_size; 24 | int y = i / matrix_size; 25 | if( x < y ) { 26 | int tmp = y; 27 | y = x; 28 | x = tmp; 29 | } 30 | mat[i] = vec[matrix_size * y - y * (y + 1) / 2 + x]; 31 | """, 32 | 'unpack' 33 | ) 34 | 35 | def pack(self, arrays, gpu_buf, sizeof_dtype, stream, offset=0): 36 | buf_offset = offset * sizeof_dtype 37 | for local_arrays in arrays: 38 | for array, triangular in local_arrays: 39 | if triangular: 40 | nbytes = self._put_triangular_matrix_to_device_memory( 41 | array, gpu_buf, buf_offset, stream) 42 | else: 43 | nbytes = array.size * sizeof_dtype 44 | gpu_buf.from_device(array, nbytes, buf_offset, stream) 45 | buf_offset += nbytes 46 | 47 | def unpack(self, arrays, gpu_buf, sizeof_dtype, stream, offset=0): 48 | buf_offset = offset * sizeof_dtype 49 | for local_arrays in arrays: 50 | for array, triangular in local_arrays: 51 | if triangular: 52 | nbytes = self._get_triangular_matrix_from_device_memory( 53 | array, gpu_buf, buf_offset, stream) 54 | else: 55 | nbytes = array.size * sizeof_dtype 56 | gpu_buf.to_device(array, nbytes, buf_offset, stream) 57 | buf_offset += nbytes 58 | 59 | def _put_triangular_matrix_to_device_memory( 60 | self, array, memory, offset, stream): 61 | """Puts a triangular matrix to ``DeviceMemory`` 62 | """ 63 | if array.dtype.char == 'f' or array.dtype.char == 'd': 64 | dtype = array.dtype.char 65 | else: 66 | dtype = numpy.find_common_type((array.dtype.char, 'f'), ()).char 67 | 68 | cublas_handle = cupy.cuda.device.get_cublas_handle() 69 | 70 | if array.shape[0] != array.shape[1]: 71 | raise RuntimeError('non square matrix') 72 | 73 | n = array.shape[0] 74 | nelems = n * (n + 1) // 2 75 | nbytes = nelems * array.dtype.itemsize 76 | 77 | if dtype == 'f': 78 | trttp = cupy.cuda.cublas.strttp 79 | else: 80 | trttp = cupy.cuda.cublas.dtrttp 81 | 82 | with stream: 83 | trttp(cublas_handle, cupy.cuda.cublas.CUBLAS_FILL_MODE_LOWER, n, 84 | array.data.ptr, n, memory.ptr() + offset) 85 | 86 | return nbytes 87 | 88 | def _get_triangular_matrix_from_device_memory( 89 | self, array, memory, offset, stream): 90 | """Gets a triangular matrix from ``DeviceMemory`` 91 | """ 92 | if array.shape[0] != array.shape[1]: 93 | raise RuntimeError('non square matrix') 94 | 95 | n = array.shape[0] 96 | nelems = n * (n + 1) // 2 97 | nbytes = nelems * array.dtype.itemsize 98 | 99 | with stream: 100 | self.unpack_kernel( 101 | memory.array(nelems, offset=offset, dtype=array.dtype), 102 | n, array, size=n * n) 103 | 104 | return nbytes 105 | 106 | 107 | def _check_array(array, name): 108 | xp = cuda.get_array_module(array) 109 | with cuda.get_device_from_array(array): 110 | if not array.dtype == xp.float32: 111 | warnings.warn('non FP32 dtype detected in {}'.format(name)) 112 | array = array.astype(xp.float32) 113 | if not (array.flags.c_contiguous or array.flags.f_contiguous): 114 | warnings.warn('non contiguous array detected in {}'.format(name)) 115 | array = xp.ascontiguousarray(array) 116 | return array 117 | 118 | 119 | def extract(param_groups, indices, extractors): 120 | """Extracts arrays from given fisher blocks using indices and extractors 121 | 122 | Args: 123 | fblocks: List of ``FisherBlock`` instances 124 | indices: List of ``int``s 125 | extractors: Callable that extract arrays from a given ``FisherBlock`` 126 | 127 | Return: 128 | List of tuple(array, bool). Second item indicates triangular flag. 129 | """ 130 | arrays = [] 131 | for local_indices in indices: 132 | local_arrays = [] 133 | for index in local_indices: 134 | for extractor in extractors: 135 | for array in extractor(param_groups[index]): 136 | local_arrays.append(array) 137 | arrays.append(local_arrays) 138 | return arrays 139 | 140 | 141 | def extract_attr_from_params(attr, target='params', triangular=False): 142 | """Extracts arrays from all ``Parameter``s in a given ``FisherBlock`` 143 | """ 144 | 145 | def _extract_attr_from_params(group): 146 | arrays = [] 147 | for param in group[target]: 148 | x = getattr(param, attr, None) 149 | if x is not None: 150 | #x = _check_array(x, fblock.linkname) 151 | #setattr(param, attr, x) 152 | x_ten = x.data 153 | x_cp = to_cupy(x_ten) 154 | arrays.append((x_cp, triangular)) 155 | return arrays 156 | 157 | return _extract_attr_from_params 158 | 159 | 160 | def extract_attr_from_curv(attr, triangular=False): 161 | """Extracts arrays from all ``Parameter``s in a given ``FisherBlock`` 162 | """ 163 | 164 | def _extract_attr_from_curv(group): 165 | arrays = [] 166 | 167 | curv = group['curv'] 168 | if curv is None: 169 | return arrays 170 | 171 | target = getattr(curv, attr, None) 172 | if target is None: 173 | if curv.data is not None: 174 | zeros = [] 175 | for x in curv.data: 176 | zeros.append(torch.zeros_like(x)) 177 | setattr(curv, attr, zeros) 178 | target = getattr(curv, attr) 179 | else: 180 | return arrays 181 | 182 | for x in target: 183 | #x = _check_array(x, fblock.linkname) 184 | #setattr(param, attr, x) 185 | x_ten = x.data 186 | x_cp = to_cupy(x_ten) 187 | _triangular = triangular and x_cp.ndim == 2 and x_cp.shape[0] == x_cp.shape[1] 188 | arrays.append((x_cp, _triangular)) 189 | 190 | return arrays 191 | 192 | return _extract_attr_from_curv 193 | 194 | 195 | def get_nelems(arrays): 196 | """Computes number of elements from given arrays using the triangular flag. 197 | """ 198 | nelems = 0 199 | for local_arrays in arrays: 200 | for array, triangular in local_arrays: 201 | if triangular: 202 | if array.shape[0] != array.shape[1]: 203 | raise RuntimeError('get_nelems: not a square matrix') 204 | nelems += array.shape[0] * (array.shape[0] + 1) // 2 205 | else: 206 | nelems += array.size 207 | return nelems 208 | 209 | 210 | def assign(gpu_buf, nbytes): 211 | if nbytes > gpu_buf.size: 212 | gpu_buf.assign(nbytes) 213 | return True 214 | return False 215 | 216 | 217 | def allocate_asgrad(fblocks, attr): 218 | for fblock in fblocks: 219 | for _, param in sorted(fblock.link.namedparams()): 220 | if not hasattr(param, attr): 221 | # We need to allocate memory space for recieving data 222 | _grad = param.grad.copy() 223 | _grad.fill(0.) 224 | setattr(param, attr, _grad) 225 | -------------------------------------------------------------------------------- /torchsso/utils/chainer_communicators/base.py: -------------------------------------------------------------------------------- 1 | from chainermn.communicators import mpi_communicator_base 2 | import warnings 3 | 4 | from torchsso.utils.chainer_communicators import _utility 5 | 6 | 7 | class KFACCommunicatorBase(mpi_communicator_base.MpiCommunicatorBase): 8 | 9 | def __init__(self, mpi_comm): 10 | super(KFACCommunicatorBase, self).__init__(mpi_comm) 11 | self.indices = None 12 | self.packer = _utility.Packer() 13 | 14 | def allreduce_grad(self): 15 | # We don't use AllReduce for training K-FAC 16 | warnings.warn('AllReduce called, skipping...') 17 | 18 | def reduce_scatterv_data(self, fblocks, extractors): 19 | raise NotImplementedError 20 | 21 | def allgatherv_data(self, fblocks, extractors): 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /torchsso/utils/cholesky_cupy.py: -------------------------------------------------------------------------------- 1 | try: 2 | import cupy 3 | from torchsso.utils.cupy import to_cupy, from_cupy 4 | except: 5 | # print("No cupy detected") 6 | pass 7 | 8 | 9 | def cholesky(m, upper=True): 10 | m_cp = to_cupy(m) 11 | m_chl_cp = cupy.linalg.decomposition.cholesky(m_cp) 12 | if upper: 13 | m_chl_cp = m_chl_cp.transpose() 14 | return from_cupy(m_chl_cp) 15 | -------------------------------------------------------------------------------- /torchsso/utils/cupy.py: -------------------------------------------------------------------------------- 1 | try: 2 | import cupy 3 | except: 4 | # print("No cupy detected") 5 | pass 6 | 7 | from torch.utils.dlpack import to_dlpack, from_dlpack 8 | 9 | 10 | def to_cupy(m_tensor): 11 | return cupy.fromDlpack(to_dlpack(m_tensor)) 12 | 13 | 14 | def from_cupy(m_cp): 15 | return from_dlpack(m_cp.toDlpack()) 16 | -------------------------------------------------------------------------------- /torchsso/utils/inv_cupy.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import scipy 3 | import torch 4 | 5 | try: 6 | import cupy 7 | from cupy import cuda 8 | from cupy.cuda import cublas 9 | from cupy.cuda import device 10 | from cupy.linalg import util 11 | if cuda.cusolver_enabled: 12 | from cupy.cuda import cusolver 13 | from torchsso.utils.cupy import to_cupy, from_cupy 14 | except: 15 | pass 16 | # print("No cupy detected") 17 | 18 | 19 | import warnings 20 | 21 | 22 | use_cholesky = True 23 | 24 | # Based cupy (cupy/cupy/linalg/solve.py) @ 067f830 25 | 26 | 27 | def inv(m): 28 | if torch.cuda.is_available(): 29 | m_cp = to_cupy(m) 30 | m_inv_cp = inv_core(m_cp, use_cholesky) 31 | return from_cupy(m_inv_cp) 32 | else: 33 | result = torch.from_numpy(scipy.linalg.inv(m.cpu().numpy())) 34 | return result 35 | 36 | 37 | def inv_core(a, cholesky=False): 38 | """Computes the inverse of a matrix. 39 | This function computes matrix ``a_inv`` from n-dimensional regular matrix 40 | ``a`` such that ``dot(a, a_inv) == eye(n)``. 41 | Args: 42 | a (cupy.ndarray): The regular matrix 43 | b (Boolean): Use cholesky decomposition 44 | Returns: 45 | cupy.ndarray: The inverse of a matrix. 46 | .. seealso:: :func:`numpy.linalg.inv` 47 | """ 48 | 49 | xp = cupy.get_array_module(a) 50 | if xp == numpy: 51 | if cholesky: 52 | warnings.warn( 53 | "Current fast-inv using cholesky doesn't support numpy.ndarray.") 54 | return numpy.linalg.inv(a) 55 | 56 | if not cuda.cusolver_enabled: 57 | raise RuntimeError('Current cupy only supports cusolver in CUDA 8.0') 58 | 59 | # to prevent `a` to be overwritten 60 | a = a.copy() 61 | 62 | util._assert_cupy_array(a) 63 | util._assert_rank2(a) 64 | util._assert_nd_squareness(a) 65 | 66 | if a.dtype.char == 'f' or a.dtype.char == 'd': 67 | dtype = a.dtype.char 68 | else: 69 | dtype = numpy.find_common_type((a.dtype.char, 'f'), ()).char 70 | 71 | cusolver_handle = device.get_cusolver_handle() 72 | dev_info = cupy.empty(1, dtype=cupy.int) 73 | m = a.shape[0] 74 | 75 | b = cupy.eye(m, dtype=dtype) 76 | 77 | if not cholesky: 78 | if dtype == 'f': 79 | getrf = cusolver.sgetrf 80 | getrf_bufferSize = cusolver.sgetrf_bufferSize 81 | getrs = cusolver.sgetrs 82 | else: # dtype == 'd' 83 | getrf = cusolver.dgetrf 84 | getrf_bufferSize = cusolver.dgetrf_bufferSize 85 | getrs = cusolver.dgetrs 86 | 87 | buffersize = getrf_bufferSize(cusolver_handle, m, m, a.data.ptr, m) 88 | 89 | # TODO(y1r): cache buffer to avoid malloc 90 | workspace = cupy.empty(buffersize, dtype=dtype) 91 | ipiv = cupy.empty((a.shape[0], 1), dtype=dtype) 92 | 93 | # LU Decomposition 94 | getrf(cusolver_handle, m, m, a.data.ptr, m, 95 | workspace.data.ptr, ipiv.data.ptr, dev_info.data.ptr) 96 | 97 | # TODO(y1r): check dev_info status 98 | 99 | # solve for the inverse 100 | getrs(cusolver_handle, 0, m, m, a.data.ptr, m, 101 | ipiv.data.ptr, b.data.ptr, m, dev_info.data.ptr) 102 | 103 | # TODO(y1r): check dev_info status 104 | else: 105 | if dtype == 'f': 106 | potrf = cusolver.spotrf 107 | potrf_bufferSize = cusolver.spotrf_bufferSize 108 | potrs = cusolver.spotrs 109 | else: # dtype == 'd' 110 | potrf = cusolver.dpotrf 111 | potrf_bufferSize = cusolver.dpotrf_bufferSize 112 | potrs = cusolver.dpotrs 113 | 114 | buffersize = potrf_bufferSize( 115 | cusolver_handle, cublas.CUBLAS_FILL_MODE_UPPER, m, a.data.ptr, m) 116 | 117 | # TODO(y1r): cache buffer to avoid malloc 118 | workspace = cupy.empty(buffersize, dtype=dtype) 119 | 120 | # Cholesky Decomposition 121 | potrf(cusolver_handle, cublas.CUBLAS_FILL_MODE_UPPER, m, 122 | a.data.ptr, m, workspace.data.ptr, buffersize, dev_info.data.ptr) 123 | 124 | # TODO(y1r): check dev_info status 125 | 126 | # solve for the inverse 127 | potrs(cusolver_handle, cublas.CUBLAS_FILL_MODE_UPPER, m, 128 | m, a.data.ptr, m, b.data.ptr, m, dev_info.data.ptr) 129 | 130 | # TODO(y1r): check dev_info status 131 | 132 | return b 133 | -------------------------------------------------------------------------------- /torchsso/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import shutil 5 | 6 | 7 | # Select the best-resolution timer function 8 | try: 9 | _get_time = time.perf_counter 10 | except AttributeError: 11 | if os.name == 'nt': 12 | _get_time = time.clock 13 | else: 14 | _get_time = time.time 15 | 16 | 17 | class Logger(object): 18 | 19 | def __init__(self, out, logname): 20 | self.out = out 21 | self.logname = logname 22 | self._log = [] 23 | self._start_at = None 24 | 25 | if not os.path.isdir(self.out): 26 | os.makedirs(self.out) 27 | 28 | def start(self): 29 | self._start_at = _get_time() 30 | 31 | @property 32 | def elapsed_time(self): 33 | if self._start_at is None: 34 | raise RuntimeError('training has not been started yet') 35 | return _get_time() - self._start_at 36 | 37 | def write(self, log): 38 | self._log.append(log) 39 | tmp_path = os.path.join(self.out, 'tmp') 40 | with open(tmp_path, 'w') as f: 41 | json.dump(self._log, f, indent=4) 42 | 43 | path = os.path.join(self.out, self.logname) 44 | shutil.move(tmp_path, path) 45 | 46 | --------------------------------------------------------------------------------