├── .gitignore ├── LICENSE ├── MegaAdversarial ├── README.md └── src │ ├── __init__.py │ └── attacks │ ├── __init__.py │ ├── attack.py │ ├── base.py │ ├── fgsm.py │ └── pgd.py ├── README.md ├── examples ├── cifar10 │ ├── Build the model.ipynb │ ├── Evaluate the model.ipynb │ ├── assets │ │ ├── fgsm_random_train.png │ │ ├── fgsm_random_train_clean_test.png │ │ ├── fgsm_random_train_fgsm_eps_8_255_test.png │ │ └── fgsm_random_train_pgd_eps_8_255_lr_2_255_iters_7_test.png │ ├── checkpoints │ │ ├── accuracy │ │ │ ├── fgsm_random_8_255_clean.pkl │ │ │ ├── fgsm_random_8_255_fgsm_8_255.pkl │ │ │ ├── fgsm_random_8_255_pgd_8_255_2_255_7.pkl │ │ │ ├── fgsm_random_8_255_smoothing_00125_clean.pkl │ │ │ ├── fgsm_random_8_255_smoothing_00125_fgsm_8_255.pkl │ │ │ └── fgsm_random_8_255_smoothing_00125_pgd_8_255_2_255_7.pkl │ │ ├── fgsm_random_8_255.pth │ │ ├── fgsm_random_8_255_smoothing_00125.pth │ │ └── fgsm_random_8_255_switching_05_05215_04875.pth │ └── train_and_attack.py └── mnist │ ├── Build the model.ipynb │ ├── Evaluate the model.ipynb │ ├── assets │ ├── mnist_adv.pdf │ └── mnist_adv.png │ ├── checkpoints │ └── checkpoint_15444.pth │ └── train_and_attack.py └── sopa ├── __init__.py └── src ├── __init__.py ├── models ├── __init__.py ├── odenet_cifar10 │ ├── __init__.py │ ├── data.py │ ├── layers.py │ └── utils.py ├── odenet_mnist │ ├── attacks_runner.py │ ├── attacks_utils.py │ ├── data.py │ ├── layers.py │ ├── metrics.py │ ├── runner.py │ ├── runner_new.py │ ├── runner_old.py │ ├── train_validate.py │ └── utils.py └── utils.py └── solvers ├── __init__.py ├── euler.py ├── rk_parametric.py ├── rk_parametric_order2stage2.py ├── rk_parametric_order3stage3.py ├── rk_parametric_order4stage4.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.log 3 | .data/ 4 | data/* 5 | data/ 6 | *.pyc 7 | events/* 8 | .ipynb_checkpoints/ 9 | .ipynb_checkpoints/* 10 | gifs/* 11 | wandb/ 12 | tests/ 13 | tests/* 14 | 
checkpoints/accuracy 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, MetaSolver 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# 30 |
# --------------------------------------------------------------------------------
# /MegaAdversarial/README.md:
# --------------------------------------------------------------------------------
# 1 | ### Make neural networks robust again
# --------------------------------------------------------------------------------
# /MegaAdversarial/src/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/MegaAdversarial/src/__init__.py
# --------------------------------------------------------------------------------
# /MegaAdversarial/src/attacks/__init__.py:
# --------------------------------------------------------------------------------
# 1 | from .base import Clean, Clean2Ensemble
# 2 | from .pgd import PGD
# 3 | from .fgsm import FGSM, FGSMRandom, FGSM2Ensemble
# 4 |
# --------------------------------------------------------------------------------
# /MegaAdversarial/src/attacks/attack.py:
# --------------------------------------------------------------------------------
"""Base classes for the adversarial attacks in this package."""
import torch
import torch.nn as nn

# Device on which attacks place their helper modules (e.g. loss functions).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class _AttackBase(nn.Module):
    """Helpers shared by single-model and ensemble attacks.

    Subclasses implement ``forward``; the attacks in this package return
    ``(x_adv, y)`` from it.
    """

    def __init__(self):
        super().__init__()
        self.device = device

    def _project(self, x):
        """Clip ``x`` element-wise into the valid pixel range [0, 1]."""
        return torch.clamp(x, 0, 1)

    # Parameter names ``min``/``max`` are kept for backward compatibility with
    # keyword callers, even though they shadow the builtins inside this method.
    def _clamp(self, x, min, max):
        """Clip ``x`` element-wise into ``[min, max]``.

        Args:
            x: tensor to clip.
            min: lower bound, a tensor broadcastable to ``x`` or a Python number.
            max: upper bound, a tensor broadcastable to ``x`` or a Python number.

        Returns:
            A new tensor with every element of ``x`` inside the bounds.
        """
        # torch.min/torch.max need tensor bounds here; an int would otherwise be
        # interpreted as a reduction *dim*. Convert numbers to 0-d tensors.
        if not isinstance(max, torch.Tensor):
            max = torch.as_tensor(max, device=x.device, dtype=x.dtype)
        if not isinstance(min, torch.Tensor):
            min = torch.as_tensor(min, device=x.device, dtype=x.dtype)
        return torch.max(torch.min(x, max), min)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Attack(_AttackBase):
    """Base class for attacks against a single ``model``."""

    def __init__(self, model):
        super().__init__()
        self.model = model


class Attack2Ensemble(_AttackBase):
    """Base class for attacks against a collection of ``models``."""

    def __init__(self, models):
        super().__init__()
        self.models = models
# --------------------------------------------------------------------------------
# /MegaAdversarial/src/attacks/base.py:
# --------------------------------------------------------------------------------
from .attack import Attack, Attack2Ensemble


class Clean(Attack):
    """No-op attack: returns the inputs unchanged (clean baseline)."""

    def forward(self, x, y, kwargs):
        # ``kwargs`` is accepted only to match the attack call signature.
        return x, y


class Clean2Ensemble(Attack2Ensemble):
    """No-op ensemble attack: returns the inputs unchanged (clean baseline)."""

    def forward(self, x, y, kwargs_arr):
        # ``kwargs_arr`` is accepted only to match the ensemble call signature.
        return x, y

# --------------------------------------------------------------------------------
# /MegaAdversarial/src/attacks/fgsm.py:
# --------------------------------------------------------------------------------
from .attack import Attack, Attack2Ensemble
import torch
import torch.nn as nn
import torchvision.transforms as transforms


def _identity(t):
    """Return ``t`` unchanged."""
    return t


def make_normalizers(mean, std):
    """Build ``(normalize, denormalize)`` callables for per-channel statistics.

    ``normalize`` maps [0, 1] pixels to ``(x - mean) / std`` and
    ``denormalize`` is its inverse. When every mean is 0 and every std is 1
    both callables are the identity. This lets the default statistics work for
    inputs with any channel count (e.g. 1-channel MNIST). Otherwise a
    3-valued Normalize fails on an in-place broadcast.
    """
    if all(m == 0 for m in mean) and all(s == 1 for s in std):
        return _identity, _identity
    normalize = transforms.Normalize(mean=mean, std=std)
    denormalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)], std=[1 / s for s in std]
    )
    return normalize, denormalize


class FGSM(Attack):
    """
    The standard FGSM attack. Assumes (0, 1) normalization.

    Inputs ``x`` are given in the model's normalized space (``mean``/``std``).
    They are mapped back to [0, 1], moved by ``eps * sign(grad)``, clipped to
    [0, 1] and normalized again.
    """

    def __init__(self, model, eps=None, mean=None, std=None):
        super(FGSM, self).__init__(model)
        self.eps = eps
        self.loss_fn = nn.CrossEntropyLoss().to(self.device)
        self.mean = mean if mean is not None else (0., 0., 0.)
        self.std = std if std is not None else (1., 1., 1.)

    def forward(self, x, y, kwargs):
        """Return ``(x_adv, y)``; ``kwargs`` are forwarded to the model call.

        Raises:
            ValueError: if ``eps`` was not set.
        """
        if self.eps is None:
            raise ValueError("FGSM requires `eps` to be set")
        normalize, denormalize = make_normalizers(self.mean, self.std)

        training = self.model.training
        if training:
            self.model.eval()
        try:
            x_attacked = denormalize(x).clone().detach()  # in [0, 1]
            x_attacked.requires_grad_(True)
            loss = self.loss_fn(self.model(normalize(x_attacked), **kwargs), y)
            grad = torch.autograd.grad([loss], [x_attacked])[0]
            with torch.no_grad():
                x_attacked = self._project(x_attacked + self.eps * grad.sign())
                x_attacked = normalize(x_attacked)
        finally:
            # Restore the caller's mode even if the forward pass raised.
            if training:
                self.model.train()
        return x_attacked.detach(), y


def clamp(X, lower_limit, upper_limit):
    """Clip ``X`` element-wise into ``[lower_limit, upper_limit]``.

    The limits may be tensors broadcastable to ``X`` or Python numbers.
    """
    if not isinstance(upper_limit, torch.Tensor):
        upper_limit = torch.tensor(upper_limit, device=X.device, dtype=X.dtype)
    if not isinstance(lower_limit, torch.Tensor):
        lower_limit = torch.tensor(lower_limit, device=X.device, dtype=X.dtype)
    return torch.max(torch.min(X, upper_limit), lower_limit)


class FGSMRandom(Attack):
    """
    FGSM with a uniformly random start ("fast" adversarial training).
    Operates directly in the model's normalized input space.
    This implementation is inspired by the implementation from here:
    https://github.com/locuslab/fast_adversarial/blob/master/CIFAR10/train_fgsm.py
    """

    def __init__(self, model, alpha, epsilon=None, mu=None, std=None):
        '''
        Args:
            model: the neural network model
            alpha: the step size (in [0, 1] pixel units)
            epsilon: the radius of the random noise and of the L-inf ball
            mu: the per-channel mean value of all dataset samples
            std: the per-channel std value of all dataset samples
        '''
        super(FGSMRandom, self).__init__(model)
        self.epsilon = epsilon
        self.alpha = alpha
        if (mu is not None) and (std is not None):
            # One statistic per channel; -1 supports any channel count.
            mu = torch.tensor(mu, device=self.device).view(1, -1, 1, 1)
            std = torch.tensor(std, device=self.device).view(1, -1, 1, 1)

            # Valid pixel range [0, 1] expressed in normalized space.
            self.lower_limit = (0. - mu) / std
            self.upper_limit = (1. - mu) / std

            # Convert radius and step from pixel units to normalized units.
            self.epsilon = self.epsilon / std
            self.alpha = self.alpha / std
        else:
            # NOTE: statistics are ignored unless both mu and std are given.
            self.lower_limit = 0.
            self.upper_limit = 1.

        self.loss_fn = nn.CrossEntropyLoss().to(self.device)

    def forward(self, x, y, kwargs):
        """Return ``(x_adv, y)``; ``kwargs`` are forwarded to the model call.

        Raises:
            ValueError: if ``epsilon`` was not set.
        """
        if self.epsilon is None:
            raise ValueError("FGSMRandom requires `epsilon` to be set")
        training = self.model.training
        if training:
            self.model.eval()
        try:
            # Random start: delta ~ Uniform[-eps, eps], kept inside the valid range.
            delta = self.epsilon - (2 * self.epsilon) * torch.rand_like(x)
            delta = clamp(delta, self.lower_limit - x, self.upper_limit - x).detach()
            delta.requires_grad_(True)
            loss = self.loss_fn(self.model(x + delta, **kwargs), y)
            # autograd.grad instead of loss.backward(): the attack must not
            # accumulate gradients into the model parameters' ``.grad``.
            grad = torch.autograd.grad([loss], [delta])[0]
            with torch.no_grad():
                delta = clamp(delta + self.alpha * torch.sign(grad), -self.epsilon, self.epsilon)
                delta = clamp(delta, self.lower_limit - x, self.upper_limit - x)
        finally:
            if training:
                self.model.train()
        return x + delta, y


class FGSM2Ensemble(Attack2Ensemble):
    """
    FGSM against an ensemble. Assumes (0, 1) normalization.

    The loss is the NLL of the averaged softmax probabilities of all models.
    """

    def __init__(self, models, eps=None, mean=None, std=None):
        super(FGSM2Ensemble, self).__init__(models)
        self.eps = eps
        self.loss_fn = nn.NLLLoss().to(self.device)
        self.mean = mean if mean is not None else (0., 0., 0.)
        self.std = std if std is not None else (1., 1., 1.)

    def forward(self, x, y, kwargs_arr):
        """Return ``(x_adv, y)``; ``kwargs_arr[i]`` is forwarded to ``models[i]``.

        Raises:
            ValueError: if there are no models or ``eps`` was not set.
        """
        if not self.models:
            raise ValueError("FGSM2Ensemble requires at least one model")
        if self.eps is None:
            raise ValueError("FGSM2Ensemble requires `eps` to be set")
        normalize, denormalize = make_normalizers(self.mean, self.std)

        # Remember each model's own mode so it is restored exactly afterwards.
        modes = [model.training for model in self.models]
        for model in self.models:
            model.eval()
        try:
            x_attacked = denormalize(x).clone().detach()  # in [0, 1]
            x_attacked.requires_grad_(True)

            probs_ensemble = 0
            for model, kwargs in zip(self.models, kwargs_arr):
                logits = model(normalize(x_attacked), **kwargs)
                probs_ensemble = probs_ensemble + torch.softmax(logits, dim=1)
            probs_ensemble = probs_ensemble / len(self.models)

            # Clamp exact zeros so log() cannot produce -inf loss / NaN gradients.
            tiny = torch.finfo(probs_ensemble.dtype).tiny
            loss = self.loss_fn(torch.log(probs_ensemble.clamp_min(tiny)), y)
            grad = torch.autograd.grad([loss], [x_attacked])[0]
            with torch.no_grad():
                x_attacked = self._project(x_attacked + self.eps * grad.sign())
                x_attacked = normalize(x_attacked)
        finally:
            for model, was_training in zip(self.models, modes):
                model.train(was_training)
        return x_attacked.detach(), y

# --------------------------------------------------------------------------------
# /MegaAdversarial/src/attacks/pgd.py:
# --------------------------------------------------------------------------------
from .attack import Attack
from .fgsm import make_normalizers
import torch
import torch.nn as nn
import torchvision.transforms as transforms


class PGD(Attack):
    """
    The standard L-inf PGD attack. Assumes (0, 1) normalization.

    Inputs ``x`` are given in the model's normalized space (``mean``/``std``).
    The attack runs ``n_iter`` signed-gradient steps of size ``lr`` in [0, 1]
    pixel space. After each step the result is projected onto the
    ``eps``-ball around ``x`` and onto [0, 1]. The result is then normalized
    back.
    """

    def __init__(self, model, eps=None, lr=None, n_iter=None, randomized_start=True, mean=None, std=None):
        super(PGD, self).__init__(model)
        self.eps = eps
        self.lr = lr
        self.n_iter = n_iter
        self.randomized_start = randomized_start
        self.loss_fn = nn.CrossEntropyLoss().to(self.device)
        self.mean = mean if mean is not None else (0., 0., 0.)
        self.std = std if std is not None else (1., 1., 1.)

    def forward(self, x, y, kwargs):
        """Return ``(x_adv, y)``; ``kwargs`` are forwarded to the model call.

        Raises:
            ValueError: if ``eps``, ``lr`` or ``n_iter`` was not set.
        """
        if self.eps is None or self.lr is None or self.n_iter is None:
            raise ValueError("PGD requires `eps`, `lr` and `n_iter` to be set")
        normalize, denormalize = make_normalizers(self.mean, self.std)

        training = self.model.training
        if training:
            self.model.eval()
        try:
            x = denormalize(x)  # in [0, 1]
            if self.randomized_start:
                noise = torch.zeros_like(x).uniform_(-self.eps, self.eps)
                x_attacked = self._project(x + noise).clone().detach()
            else:
                x_attacked = x.clone().detach()

            for _ in range(self.n_iter):
                x_attacked.requires_grad_(True)
                loss = self.loss_fn(self.model(normalize(x_attacked), **kwargs), y)
                grad = torch.autograd.grad([loss], [x_attacked])[0]
                # Update without tracking history so the graph does not grow
                # across iterations; x_attacked becomes a fresh leaf each step.
                with torch.no_grad():
                    x_attacked = self._clamp(
                        x_attacked + self.lr * grad.sign(), x - self.eps, x + self.eps
                    )
                    x_attacked = self._project(x_attacked)

            # Normalize after the loop so n_iter == 0 also returns model-space input.
            with torch.no_grad():
                x_attacked = normalize(x_attacked)
        finally:
            if training:
                self.model.train()
        return x_attacked.detach(), y

# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# 1 | # Meta-Solver for Neural Ordinary Differential Equations
# 2 | Towards robust neural ODEs using parametrized solvers.
3 | 4 | # Main idea 5 | Each *Runge-Kutta (RK)* solver with `s` stages and of the `p`-th order is defined by a table of coefficients (*Butcher tableau*). For `s=p=2`, `s=p=3` and `s=p=4` all coefficients in the table can be parametrized with no more than two variables [1]. 6 | 7 | Usually, during neural ODE training, an RK solver with a fixed Butcher tableau is used, and only the *right-hand side (RHS)* function is trained. We propose to use the whole parametric family of RK solvers to improve the robustness of neural ODEs. 8 | 9 | # Requirements 10 | - pytorch==1.7 11 | - apex==0.1 (for training) 12 | 13 | # Examples 14 | For CIFAR-10 and MNIST demos, please check the `examples` folder. 15 | 16 | # Meta Solver Regimes 17 | In the notebook `examples/cifar10/Evaluate the model.ipynb` we show how to perform the forward pass through the Neural ODE using different types of Meta Solver regimes, namely 18 | - Standalone 19 | - Solver switching/smoothing 20 | - Solver ensembling 21 | - Model ensembling 22 | 23 | In more detail, the regimes work as follows: 24 | - **Standalone** 25 | - Use one solver during inference. 26 | - This regime is applied in the training and testing stages. 27 | 28 | 29 | 30 | - **Solver switching / smoothing** 31 | - For each batch, one solver is chosen from a group of solvers with a finite (in switching regime) or infinite (in smoothing regime) number of candidates. 32 | - This regime is applied in the training stage. 33 | 34 | 35 | - **Solver ensembling** 36 | - Use several solvers during inference. 37 | - Outputs of the ODE block (obtained with different solvers) are averaged before propagating through the next layer. 38 | - This regime is applied in the training and testing stages. 39 | 40 | 41 | - **Model ensembling** 42 | - Use several solvers during inference. 43 | - Model probabilities obtained via propagation with different solvers are averaged to get the final result. 44 | - This regime is applied in the training and testing stages.
45 | 46 | # Selected results 47 | ## Different solver parameterizations yield different robustness 48 | We have trained a neural ODE model several times, using different ``u`` values in the parametrization of the 2nd-order Runge-Kutta solver. The image below depicts robust accuracies for the MNIST classification task. We use the PGD attack (eps=0.3, lr=2/255 and iters=7). The mean values of robust accuracy (bold lines) and ± the standard error of the mean (shaded region), computed across 9 random seeds, are shown in this image. 49 | 50 | 51 | 52 | ## Solver smoothing improves robustness 53 | We compare results of neural ODE adversarial training on the CIFAR-10 dataset with the meta-solver in standalone, switching or smoothing regimes. We choose 8-step RK2 solvers for this experiment. 54 | - We perform training using the FGSM random technique described in https://arxiv.org/abs/2001.03994 (with eps=8/255, alpha=10/255). 55 | - We use a cyclic learning rate schedule with one cycle (36 epochs, max_lr=0.1, base_lr=1e-7). 56 | - We measure the robust accuracy of the resulting models after FGSM (eps=8/255) and PGD (eps=8/255, lr=2/255, iters=7) attacks. 57 | - We use the `premetanode10` architecture from `sopa/src/models/odenet_cifar10/layers.py`, which has the following form: 58 | `Conv -> PreResNet block -> ODE block -> PreResNet block -> ODE block -> GeLU -> Average Pooling -> Fully Connected` 59 | - We compute the mean and standard error across 3 random seeds. 60 | 61 | 65 | 66 | ![](examples/cifar10/assets/fgsm_random_train.png) 67 | 68 | # References 69 | [1] [Hairer, E., Nørsett, S. P., & Wanner, G. (1993). Solving ordinary differential equations I.
Springer Berlin Heidelberg](https://www.springer.com/gp/book/9783540566700) 70 | -------------------------------------------------------------------------------- /examples/cifar10/Build the model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 11, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import argparse\n", 10 | "import copy\n", 11 | "import sys\n", 12 | "\n", 13 | "sys.path.append('../../')\n", 14 | "import sopa.src.models.odenet_cifar10.layers as cifar10_models\n", 15 | "from sopa.src.models.odenet_cifar10.utils import *" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 12, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "parser = argparse.ArgumentParser()\n", 25 | "# Architecture params\n", 26 | "parser.add_argument('--is_odenet', type=eval, default=True, choices=[True, False])\n", 27 | "parser.add_argument('--network', type=str, choices=['metanode34', 'metanode18', 'metanode10', 'metanode6', 'metanode4',\n", 28 | " 'premetanode34', 'premetanode18', 'premetanode10', 'premetanode6',\n", 29 | " 'premetanode4'],\n", 30 | " default='premetanode10')\n", 31 | "parser.add_argument('--in_planes', type=int, default=64)\n", 32 | "\n", 33 | "# Type of layer's output normalization\n", 34 | "parser.add_argument('--normalization_resblock', type=str, default='NF',\n", 35 | " choices=['BN', 'GN', 'LN', 'IN', 'NF'])\n", 36 | "parser.add_argument('--normalization_odeblock', type=str, default='NF',\n", 37 | " choices=['BN', 'GN', 'LN', 'IN', 'NF'])\n", 38 | "parser.add_argument('--normalization_bn1', type=str, default='NF',\n", 39 | " choices=['BN', 'GN', 'LN', 'IN', 'NF'])\n", 40 | "parser.add_argument('--num_gn_groups', type=int, default=32, help='Number of groups for GN normalization')\n", 41 | "\n", 42 | "# Type of layer's weights normalization\n", 43 | "parser.add_argument('--param_normalization_resblock', 
type=str, default='PNF',\n", 44 | " choices=['WN', 'SN', 'PNF'])\n", 45 | "parser.add_argument('--param_normalization_odeblock', type=str, default='PNF',\n", 46 | " choices=['WN', 'SN', 'PNF'])\n", 47 | "parser.add_argument('--param_normalization_bn1', type=str, default='PNF',\n", 48 | " choices=['WN', 'SN', 'PNF'])\n", 49 | "# Type of activation\n", 50 | "parser.add_argument('--activation_resblock', type=str, default='ReLU',\n", 51 | " choices=['ReLU', 'GeLU', 'Softsign', 'Tanh', 'AF'])\n", 52 | "parser.add_argument('--activation_odeblock', type=str, default='ReLU',\n", 53 | " choices=['ReLU', 'GeLU', 'Softsign', 'Tanh', 'AF'])\n", 54 | "parser.add_argument('--activation_bn1', type=str, default='ReLU',\n", 55 | " choices=['ReLU', 'GeLU', 'Softsign', 'Tanh', 'AF'])\n", 56 | "\n", 57 | "args, unknown_args = parser.parse_known_args()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 13, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# Initialize Neural ODE model\n", 67 | "config = copy.deepcopy(args)\n", 68 | "\n", 69 | "norm_layers = (get_normalization(config.normalization_resblock),\n", 70 | " get_normalization(config.normalization_odeblock),\n", 71 | " get_normalization(config.normalization_bn1))\n", 72 | "param_norm_layers = (get_param_normalization(config.param_normalization_resblock),\n", 73 | " get_param_normalization(config.param_normalization_odeblock),\n", 74 | " get_param_normalization(config.param_normalization_bn1))\n", 75 | "act_layers = (get_activation(config.activation_resblock),\n", 76 | " get_activation(config.activation_odeblock),\n", 77 | " get_activation(config.activation_bn1))\n", 78 | "\n", 79 | "model = getattr(cifar10_models, config.network)(norm_layers, param_norm_layers, act_layers,\n", 80 | " config.in_planes, is_odenet=config.is_odenet)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 14, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/plain": [ 91 | 
"MetaNODE(\n", 92 | " (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 93 | " (bn1): Identity()\n", 94 | " (layer1): MetaLayer(\n", 95 | " (blocks_res): Sequential(\n", 96 | " (0): PreBasicBlock(\n", 97 | " (bn1): Identity()\n", 98 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 99 | " (bn2): Identity()\n", 100 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 101 | " (shortcut): Sequential()\n", 102 | " )\n", 103 | " )\n", 104 | " (blocks_ode): ModuleList(\n", 105 | " (0): MetaODEBlock(\n", 106 | " (rhs_func): PreBasicBlock2(\n", 107 | " (bn1): Identity()\n", 108 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 109 | " (bn2): Identity()\n", 110 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 111 | " (shortcut): Sequential()\n", 112 | " )\n", 113 | " )\n", 114 | " )\n", 115 | " )\n", 116 | " (layer2): MetaLayer(\n", 117 | " (blocks_res): Sequential(\n", 118 | " (0): PreBasicBlock(\n", 119 | " (bn1): Identity()\n", 120 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 121 | " (bn2): Identity()\n", 122 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 123 | " (shortcut): Sequential(\n", 124 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 125 | " )\n", 126 | " )\n", 127 | " )\n", 128 | " (blocks_ode): ModuleList(\n", 129 | " (0): MetaODEBlock(\n", 130 | " (rhs_func): PreBasicBlock2(\n", 131 | " (bn1): Identity()\n", 132 | " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 133 | " (bn2): Identity()\n", 134 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 135 | " (shortcut): Sequential()\n", 136 | " )\n", 137 | " )\n", 138 | " )\n", 139 | 
" )\n", 140 | " (fc_layers): Sequential(\n", 141 | " (0): AdaptiveAvgPool2d(output_size=(1, 1))\n", 142 | " (1): Flatten()\n", 143 | " (2): Linear(in_features=128, out_features=10, bias=True)\n", 144 | " )\n", 145 | ")" 146 | ] 147 | }, 148 | "execution_count": 14, 149 | "metadata": {}, 150 | "output_type": "execute_result" 151 | } 152 | ], 153 | "source": [ 154 | "model" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [] 163 | } 164 | ], 165 | "metadata": { 166 | "kernelspec": { 167 | "display_name": "Python 3", 168 | "language": "python", 169 | "name": "python3" 170 | }, 171 | "language_info": { 172 | "codemirror_mode": { 173 | "name": "ipython", 174 | "version": 3 175 | }, 176 | "file_extension": ".py", 177 | "mimetype": "text/x-python", 178 | "name": "python", 179 | "nbconvert_exporter": "python", 180 | "pygments_lexer": "ipython3", 181 | "version": "3.6.8" 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 2 186 | } 187 | -------------------------------------------------------------------------------- /examples/cifar10/assets/fgsm_random_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/assets/fgsm_random_train.png -------------------------------------------------------------------------------- /examples/cifar10/assets/fgsm_random_train_clean_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/assets/fgsm_random_train_clean_test.png -------------------------------------------------------------------------------- /examples/cifar10/assets/fgsm_random_train_fgsm_eps_8_255_test.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/assets/fgsm_random_train_fgsm_eps_8_255_test.png -------------------------------------------------------------------------------- /examples/cifar10/assets/fgsm_random_train_pgd_eps_8_255_lr_2_255_iters_7_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/assets/fgsm_random_train_pgd_eps_8_255_lr_2_255_iters_7_test.png -------------------------------------------------------------------------------- /examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_clean.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_clean.pkl -------------------------------------------------------------------------------- /examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_fgsm_8_255.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_fgsm_8_255.pkl -------------------------------------------------------------------------------- /examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_pgd_8_255_2_255_7.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_pgd_8_255_2_255_7.pkl 
-------------------------------------------------------------------------------- /examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_clean.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_clean.pkl -------------------------------------------------------------------------------- /examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_fgsm_8_255.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_fgsm_8_255.pkl -------------------------------------------------------------------------------- /examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_pgd_8_255_2_255_7.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_pgd_8_255_2_255_7.pkl -------------------------------------------------------------------------------- /examples/cifar10/checkpoints/fgsm_random_8_255.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/fgsm_random_8_255.pth -------------------------------------------------------------------------------- /examples/cifar10/checkpoints/fgsm_random_8_255_smoothing_00125.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/fgsm_random_8_255_smoothing_00125.pth -------------------------------------------------------------------------------- /examples/cifar10/checkpoints/fgsm_random_8_255_switching_05_05215_04875.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/fgsm_random_8_255_switching_05_05215_04875.pth -------------------------------------------------------------------------------- /examples/mnist/Build the model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import argparse\n", 11 | "from argparse import Namespace\n", 12 | "import torch\n", 13 | "import sys\n", 14 | "\n", 15 | "sys.path.append('../../')\n", 16 | "from sopa.src.models.odenet_mnist.layers import MetaNODE" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "parser = argparse.ArgumentParser()\n", 26 | "parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')\n", 27 | "parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])\n", 28 | "parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu')\n", 29 | "parser.add_argument('--in_channels', type=int, default=1)\n", 30 | "args, unknown_args = parser.parse_known_args()" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "is_odenet = args.network == 'odenet'\n", 40 | "\n", 41 | 
"model = MetaNODE(downsampling_method=args.downsampling_method,\n", 42 | " is_odenet=is_odenet,\n", 43 | " activation_type=args.activation,\n", 44 | " in_channels=args.in_channels)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 4, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "MetaNODE(\n", 56 | " (downsampling_layers): Sequential(\n", 57 | " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))\n", 58 | " (1): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 59 | " (2): ReLU(inplace=True)\n", 60 | " (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", 61 | " (4): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 62 | " (5): ReLU(inplace=True)\n", 63 | " (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", 64 | " )\n", 65 | " (fc_layers): Sequential(\n", 66 | " (0): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 67 | " (1): ReLU(inplace=True)\n", 68 | " (2): AdaptiveAvgPool2d(output_size=(1, 1))\n", 69 | " (3): Flatten()\n", 70 | " (4): Linear(in_features=64, out_features=10, bias=True)\n", 71 | " )\n", 72 | " (blocks): ModuleList(\n", 73 | " (0): MetaODEBlock(\n", 74 | " (rhs_func): ODEfunc(\n", 75 | " (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 76 | " (relu): ReLU(inplace=True)\n", 77 | " (conv1): ConcatConv2d(\n", 78 | " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 79 | " )\n", 80 | " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 81 | " (conv2): ConcatConv2d(\n", 82 | " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 83 | " )\n", 84 | " (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 85 | " )\n", 86 | " )\n", 87 | " )\n", 88 | ")" 89 | ] 90 | }, 91 | "execution_count": 4, 92 | "metadata": {}, 93 | "output_type": "execute_result" 94 | } 95 | ], 96 | "source": [ 97 | "model" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | 
"metadata": {}, 104 | "outputs": [], 105 | "source": [] 106 | } 107 | ], 108 | "metadata": { 109 | "kernelspec": { 110 | "display_name": "Python 3", 111 | "language": "python", 112 | "name": "python3" 113 | }, 114 | "language_info": { 115 | "codemirror_mode": { 116 | "name": "ipython", 117 | "version": 3 118 | }, 119 | "file_extension": ".py", 120 | "mimetype": "text/x-python", 121 | "name": "python", 122 | "nbconvert_exporter": "python", 123 | "pygments_lexer": "ipython3", 124 | "version": "3.8.3" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 4 129 | } 130 | -------------------------------------------------------------------------------- /examples/mnist/Evaluate the model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "! python3 -m pip install wandb -q" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "import argparse\n", 20 | "from argparse import Namespace\n", 21 | "import time\n", 22 | "import numpy as np\n", 23 | "import torch\n", 24 | "import torch.nn as nn\n", 25 | "import torch.optim as optim\n", 26 | "\n", 27 | "from decimal import Decimal\n", 28 | "import wandb\n", 29 | "import sys\n", 30 | "\n", 31 | "sys.path.append('../../')\n", 32 | "\n", 33 | "from sopa.src.solvers.utils import create_solver\n", 34 | "from sopa.src.models.utils import fix_seeds, RunningAverageMeter\n", 35 | "from sopa.src.models.odenet_mnist.layers import MetaNODE\n", 36 | "from sopa.src.models.odenet_mnist.utils import makedirs, learning_rate_with_decay\n", 37 | "from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator\n", 38 | "from MegaAdversarial.src.attacks import (\n", 39 | " Clean,\n", 40 | " PGD,\n", 41 | " FGSM\n", 42 | ")" 43 | ] 44 | }, 45 | { 46 | 
"cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "parser = argparse.ArgumentParser()\n", 52 | "parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')\n", 53 | "parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])\n", 54 | "parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu')\n", 55 | "parser.add_argument('--in_channels', type=int, default=1)\n", 56 | "\n", 57 | "parser.add_argument('--solvers',\n", 58 | " type=lambda s: [tuple(map(lambda iparam: str(iparam[1]) if iparam[0] <= 1 else (\n", 59 | " int(iparam[1]) if iparam[0] == 2 else (\n", 60 | " float(iparam[1]) if iparam[0] == 3 else Decimal(iparam[1]))),\n", 61 | " enumerate(item.split(',')))) for item in s.strip().split(';')],\n", 62 | " default=None,\n", 63 | " help='Each solver is represented with (method,parameterization,n_steps,step_size,u0,v0) \\n' +\n", 64 | " 'If the solver has only one parameter u0, set v0 to -1; \\n' +\n", 65 | " 'n_steps and step_size are exclusive parameters, only one of them can be != -1, \\n'\n", 66 | " 'If n_steps = step_size = -1, automatic time grid_constructor is used \\n;'\n", 67 | " 'For example, --solvers rk4,uv,2,-1,0.3,0.6;rk3,uv,-1,0.1,0.4,0.6;rk2,u,4,-1,0.3,-1')\n", 68 | "\n", 69 | "parser.add_argument('--solver_mode', type=str, choices=['switch', 'ensemble', 'standalone'], default='standalone')\n", 70 | "parser.add_argument('--val_solver_modes',\n", 71 | " type=lambda s: s.strip().split(','),\n", 72 | " default=['standalone'],\n", 73 | " help='Solver modes to use for validation step')\n", 74 | "\n", 75 | "parser.add_argument('--switch_probs', type=lambda s: [float(item) for item in s.split(',')], default=None,\n", 76 | " help=\"--switch_probs 0.8,0.1,0.1\")\n", 77 | "parser.add_argument('--ensemble_weights', type=lambda s: [float(item) for item in s.split(',')], 
default=None,\n", 78 | " help=\"ensemble_weights 0.6,0.2,0.2\")\n", 79 | "parser.add_argument('--ensemble_prob', type=float, default=1.)\n", 80 | "\n", 81 | "parser.add_argument('--noise_type', type=str, choices=['cauchy', 'normal'], default=None)\n", 82 | "parser.add_argument('--noise_sigma', type=float, default=0.001)\n", 83 | "parser.add_argument('--noise_prob', type=float, default=0.)\n", 84 | "parser.add_argument('--minimize_rk2_error', type=eval, default=False, choices=[True, False])\n", 85 | "\n", 86 | "parser.add_argument('--nepochs_nn', type=int, default=50)\n", 87 | "parser.add_argument('--nepochs_solver', type=int, default=0)\n", 88 | "parser.add_argument('--nstages', type=int, default=1)\n", 89 | "\n", 90 | "parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])\n", 91 | "parser.add_argument('--lr', type=float, default=0.01)\n", 92 | "parser.add_argument('--weight_decay', type=float, default=0.0005)\n", 93 | "parser.add_argument('--batch_size', type=int, default=128)\n", 94 | "parser.add_argument('--test_batch_size', type=int, default=1000)\n", 95 | "parser.add_argument('--base_lr', type=float, default=1e-5, help='base_lr for CyclicLR scheduler')\n", 96 | "parser.add_argument('--max_lr', type=float, default=1e-3, help='max_lr for CyclicLR scheduler')\n", 97 | "parser.add_argument('--step_size_up', type=int, default=2000, help='step_size_up for CyclicLR scheduler')\n", 98 | "parser.add_argument('--cyclic_lr_mode', type=str, default='triangular2', help='mode for CyclicLR scheduler')\n", 99 | "parser.add_argument('--lr_uv', type=float, default=1e-3)\n", 100 | "parser.add_argument('--torch_dtype', type=str, default='float32')\n", 101 | "parser.add_argument('--wandb_name', type=str, default='find_best_solver')\n", 102 | "\n", 103 | "parser.add_argument('--data_root', type=str, default='./')\n", 104 | "parser.add_argument('--save_dir', type=str, default='./')\n", 105 | "parser.add_argument('--debug', action='store_true')\n", 106 | 
"parser.add_argument('--gpu', type=int, default=0)\n", 107 | "\n", 108 | "parser.add_argument('--seed', type=int, default=502)\n", 109 | "# Noise and adversarial attacks parameters:\n", 110 | "parser.add_argument('--data_noise_std', type=float, default=0.,\n", 111 | " help='Applies Norm(0, std) gaussian noise to each training batch')\n", 112 | "parser.add_argument('--eps_adv_training', type=float, default=0.3,\n", 113 | " help='Epsilon for adversarial training')\n", 114 | "parser.add_argument(\n", 115 | " \"--adv_training_mode\",\n", 116 | " default=\"clean\",\n", 117 | " choices=[\"clean\", \"fgsm\", \"at\"], # , \"at_ls\", \"av\", \"fs\", \"nce\", \"nce_moco\", \"moco\", \"av_extra\", \"meta\"],\n", 118 | " help='''Adverarial training method/mode, by default there is no adversarial training (clean).\n", 119 | " For further details see MegaAdversarial/train in this repository.\n", 120 | " '''\n", 121 | ")\n", 122 | "parser.add_argument('--use_wandb', type=eval, default=True, choices=[True, False])\n", 123 | "parser.add_argument('--ss_loss', type=eval, default=False, choices=[True, False])\n", 124 | "parser.add_argument('--ss_loss_reg', type=float, default=0.1)\n", 125 | "parser.add_argument('--timestamp', type=int, default=int(1e6 * time.time()))\n", 126 | "\n", 127 | "parser.add_argument('--eps_adv_testing', type=float, default=0.3,\n", 128 | " help='Epsilon for adversarial testing')\n", 129 | "parser.add_argument('--adv_testing_mode',\n", 130 | " default=\"clean\",\n", 131 | " choices=[\"clean\", \"fgsm\", \"at\"],\n", 132 | " help='''Adversarsarial testing mode''')\n", 133 | "\n", 134 | "args = parser.parse_args(['--solvers', 'rk4,u3,4,-1,0.3,-1', '--seed', '902', '--adv_testing_mode', 'at', \n", 135 | " '--max_lr', '0.001', '--base_lr', '1e-05'])" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 4, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "name": "stderr", 145 | "output_type": "stream", 146 | "text": [ 147 | 
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtalgat\u001b[0m (use `wandb login --relogin` to force relogin)\n", 148 | "\u001b[34m\u001b[1mwandb\u001b[0m: wandb version 0.10.19 is available! To upgrade, please run:\n", 149 | "\u001b[34m\u001b[1mwandb\u001b[0m: $ pip install wandb --upgrade\n" 150 | ] 151 | }, 152 | { 153 | "data": { 154 | "text/html": [ 155 | "\n", 156 | " Tracking run with wandb version 0.10.7
\n", 157 | " Syncing run breezy-violet-42 to Weights & Biases (Documentation).
\n", 158 | " Project page: https://wandb.ai/sopa_node/find_best_solver
\n", 159 | " Run page: https://wandb.ai/sopa_node/find_best_solver/runs/16j2vx1d
\n", 160 | " Run data is saved locally in wandb/run-20210216_151114-16j2vx1d

\n", 161 | " " 162 | ], 163 | "text/plain": [ 164 | "" 165 | ] 166 | }, 167 | "metadata": {}, 168 | "output_type": "display_data" 169 | } 170 | ], 171 | "source": [ 172 | "makedirs(args.save_dir)\n", 173 | "if args.use_wandb:\n", 174 | " wandb.init(project=args.wandb_name, anonymous=\"allow\", entity=\"sopa_node\")\n", 175 | " wandb.config.update(args)\n", 176 | " wandb.config.update({'u': float(args.solvers[0][-2])})\n", 177 | " makedirs(wandb.config.save_dir)\n", 178 | " makedirs(os.path.join(wandb.config.save_dir, wandb.run.path))" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "### Load the model" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 5, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "# Load a checkpoint\n", 195 | "checkpoint_name = './checkpoints/checkpoint_15444.pth'\n", 196 | "device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'\n", 197 | "model = torch.load(checkpoint_name, map_location=device)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 6, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "data": { 207 | "text/plain": [ 208 | "MetaNODE(\n", 209 | " (downsampling_layers): Sequential(\n", 210 | " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))\n", 211 | " (1): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 212 | " (2): ReLU(inplace=True)\n", 213 | " (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", 214 | " (4): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 215 | " (5): ReLU(inplace=True)\n", 216 | " (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", 217 | " )\n", 218 | " (fc_layers): Sequential(\n", 219 | " (0): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 220 | " (1): ReLU(inplace=True)\n", 221 | " (2): AdaptiveAvgPool2d(output_size=(1, 1))\n", 222 | " (3): Flatten()\n", 223 | " (4): Linear(in_features=64, out_features=10, bias=True)\n", 
224 | " )\n", 225 | " (blocks): ModuleList(\n", 226 | " (0): MetaODEBlock(\n", 227 | " (rhs_func): ODEfunc(\n", 228 | " (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 229 | " (relu): ReLU(inplace=True)\n", 230 | " (conv1): ConcatConv2d(\n", 231 | " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 232 | " )\n", 233 | " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 234 | " (conv2): ConcatConv2d(\n", 235 | " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 236 | " )\n", 237 | " (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)\n", 238 | " )\n", 239 | " )\n", 240 | " )\n", 241 | ")" 242 | ] 243 | }, 244 | "execution_count": 6, 245 | "metadata": {}, 246 | "output_type": "execute_result" 247 | } 248 | ], 249 | "source": [ 250 | "model" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "### Build a data loader" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 7, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "train_loader, test_loader, train_eval_loader = get_mnist_loaders(args.data_aug,\n", 267 | " args.batch_size,\n", 268 | " args.test_batch_size,\n", 269 | " data_root=args.data_root)\n", 270 | "data_gen = inf_generator(train_loader)\n", 271 | "batches_per_epoch = len(train_loader)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "### Evaluate the model" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 8, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "if args.torch_dtype == 'float64':\n", 288 | " dtype = torch.float64\n", 289 | "elif args.torch_dtype == 'float32':\n", 290 | " dtype = torch.float32\n", 291 | " \n", 292 | "solvers = [create_solver(*solver_params, dtype=dtype, device=device) for solver_params in args.solvers]\n", 293 | "for solver in solvers:\n", 294 | " 
solver.freeze_params()\n", 295 | " \n", 296 | "solver_options = Namespace(**{key: vars(args)[key] for key in ['solver_mode', 'switch_probs',\n", 297 | " 'ensemble_prob', 'ensemble_weights']})" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": { 304 | "scrolled": true 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "def one_hot(x, K):\n", 309 | " return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)\n", 310 | "\n", 311 | "def accuracy(model, dataset_loader, device, solvers=None, solver_options=None):\n", 312 | " model.eval()\n", 313 | " total_correct = 0\n", 314 | " for x, y in dataset_loader:\n", 315 | " x = x.to(device)\n", 316 | " y = one_hot(np.array(y.numpy()), 10)\n", 317 | " target_class = np.argmax(y, axis=1)\n", 318 | " with torch.no_grad():\n", 319 | " if solver is not None:\n", 320 | " out = model(x, solvers, solver_options).cpu().detach().numpy()\n", 321 | " else:\n", 322 | " out = model(x).cpu().detach().numpy()\n", 323 | " predicted_class = np.argmax(out, axis=1)\n", 324 | " total_correct += np.sum(predicted_class == target_class)\n", 325 | " return total_correct / len(dataset_loader.dataset)\n", 326 | "\n", 327 | "\n", 328 | "def adversarial_accuracy(model, dataset_loader, device, solvers=None, solver_options=None, args=None):\n", 329 | " model.eval()\n", 330 | " total_correct = 0\n", 331 | " if args.adv_testing_mode == \"clean\":\n", 332 | " test_attack = Clean(model)\n", 333 | " elif args.adv_testing_mode == \"fgsm\":\n", 334 | " test_attack = FGSM(model, mean=[0.], std=[1.], **CONFIG_FGSM_TEST)\n", 335 | " elif args.adv_testing_mode == \"at\":\n", 336 | " test_attack = PGD(model, mean=[0.], std=[1.], **CONFIG_PGD_TEST)\n", 337 | " else:\n", 338 | " raise ValueError(\"Attack type not understood.\")\n", 339 | " for x, y in dataset_loader:\n", 340 | " x, y = x.to(device), y.to(device)\n", 341 | " x, y = test_attack(x, y, {\"solvers\": solvers, \"solver_options\": solver_options})\n", 
342 | " y = one_hot(np.array(y.cpu().numpy()), 10)\n", 343 | " target_class = np.argmax(y, axis=1)\n", 344 | " with torch.no_grad():\n", 345 | " if solver is not None:\n", 346 | " out = model(x, solvers, solver_options).cpu().detach().numpy()\n", 347 | " else:\n", 348 | " out = model(x).cpu().detach().numpy()\n", 349 | " predicted_class = np.argmax(out, axis=1)\n", 350 | " total_correct += np.sum(predicted_class == target_class)\n", 351 | " return total_correct / len(dataset_loader.dataset)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "accuracy_test = accuracy(model, test_loader, device=device,\n", 361 | " solvers=solvers, solver_options=solver_options)\n", 362 | "accuracy_test" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": null, 368 | "metadata": {}, 369 | "outputs": [], 370 | "source": [ 371 | "CONFIG_PGD_TEST = {\"eps\": 0.3, \"lr\": 2.0 / 255, \"n_iter\": 7}\n", 372 | "adv_accuracy_test = adversarial_accuracy(model, test_loader, device,\n", 373 | " solvers=solvers, solver_options=solver_options, args=args\n", 374 | " )\n", 375 | "adv_accuracy_test" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [] 384 | } 385 | ], 386 | "metadata": { 387 | "kernelspec": { 388 | "display_name": "Python 3", 389 | "language": "python", 390 | "name": "python3" 391 | }, 392 | "language_info": { 393 | "codemirror_mode": { 394 | "name": "ipython", 395 | "version": 3 396 | }, 397 | "file_extension": ".py", 398 | "mimetype": "text/x-python", 399 | "name": "python", 400 | "nbconvert_exporter": "python", 401 | "pygments_lexer": "ipython3", 402 | "version": "3.8.3" 403 | } 404 | }, 405 | "nbformat": 4, 406 | "nbformat_minor": 4 407 | } 408 | -------------------------------------------------------------------------------- /examples/mnist/assets/mnist_adv.pdf: 
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/mnist/assets/mnist_adv.pdf
# --------------------------------------------------------------------------------
# /examples/mnist/assets/mnist_adv.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/mnist/assets/mnist_adv.png
# --------------------------------------------------------------------------------
# /examples/mnist/checkpoints/checkpoint_15444.pth:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/mnist/checkpoints/checkpoint_15444.pth
# --------------------------------------------------------------------------------
# /examples/mnist/train_and_attack.py:
# --------------------------------------------------------------------------------
"""Train a MetaNODE on MNIST (optionally adversarially) and evaluate its robustness.

All configuration is parsed from the command line when this module is loaded.
"""
import os
import argparse
from argparse import Namespace
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import copy

from decimal import Decimal
import wandb
import sys

# Make the repository root importable (``sopa`` and ``MegaAdversarial`` packages).
sys.path.append('../../')

from sopa.src.solvers.utils import create_solver
from sopa.src.models.utils import fix_seeds, RunningAverageMeter
from sopa.src.models.odenet_mnist.layers import MetaNODE
from sopa.src.models.odenet_mnist.utils import makedirs, learning_rate_with_decay
from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator
from MegaAdversarial.src.attacks import (
    Clean,
    PGD,
    FGSM
)


def _parse_solvers(spec):
    """Parse the ``--solvers`` value into a list of solver-parameter tuples.

    Solvers are separated by ';' and their fields by ','. Fields 0-1
    (method, parameterization) stay ``str``, field 2 (n_steps) becomes ``int``,
    field 3 (step_size) becomes ``float`` and every later field (u0, v0, ...)
    becomes ``Decimal``.
    """
    def convert(position, value):
        if position <= 1:
            return str(value)
        if position == 2:
            return int(value)
        if position == 3:
            return float(value)
        return Decimal(value)

    return [tuple(convert(position, value) for position, value in enumerate(item.split(',')))
            for item in spec.strip().split(';')]


# NOTE(review): the boolean flags below use ``type=eval``, which evaluates arbitrary
# command-line text; acceptable for a local research script, never for untrusted input.
parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu')
parser.add_argument('--in_channels', type=int, default=1)

parser.add_argument('--solvers',
                    type=_parse_solvers,
                    default=None,
                    help='Each solver is represented with (method,parameterization,n_steps,step_size,u0,v0) \n' +
                         'If the solver has only one parameter u0, set v0 to -1; \n' +
                         'n_steps and step_size are exclusive parameters, only one of them can be != -1, \n'
                         'If n_steps = step_size = -1, automatic time grid_constructor is used; \n'
                         'For example, --solvers rk4,uv,2,-1,0.3,0.6;rk3,uv,-1,0.1,0.4,0.6;rk2,u,4,-1,0.3,-1')

parser.add_argument('--solver_mode', type=str, choices=['switch', 'ensemble', 'standalone'], default='standalone')
parser.add_argument('--val_solver_modes',
                    type=lambda s: s.strip().split(','),
                    default=['standalone'],
                    help='Solver modes to use for validation step')

parser.add_argument('--switch_probs', type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="--switch_probs 0.8,0.1,0.1")
parser.add_argument('--ensemble_weights', type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="ensemble_weights 0.6,0.2,0.2")
parser.add_argument('--ensemble_prob', type=float, default=1.)

parser.add_argument('--noise_type', type=str, choices=['cauchy', 'normal'], default=None)
parser.add_argument('--noise_sigma', type=float, default=0.001)
parser.add_argument('--noise_prob', type=float, default=0.)
parser.add_argument('--minimize_rk2_error', type=eval, default=False, choices=[True, False])

parser.add_argument('--nepochs_nn', type=int, default=50)
parser.add_argument('--nepochs_solver', type=int, default=0)
parser.add_argument('--nstages', type=int, default=1)

parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--base_lr', type=float, default=1e-5, help='base_lr for CyclicLR scheduler')
parser.add_argument('--max_lr', type=float, default=1e-3, help='max_lr for CyclicLR scheduler')
parser.add_argument('--step_size_up', type=int, default=2000, help='step_size_up for CyclicLR scheduler')
parser.add_argument('--cyclic_lr_mode', type=str, default='triangular2', help='mode for CyclicLR scheduler')
parser.add_argument('--lr_uv', type=float, default=1e-3)
parser.add_argument('--torch_dtype', type=str, default='float32')
parser.add_argument('--wandb_name', type=str, default='find_best_solver')

parser.add_argument('--data_root', type=str, default='./')
parser.add_argument('--save_dir', type=str, default='./')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)

parser.add_argument('--seed', type=int, default=502)
# Noise and adversarial attacks parameters:
parser.add_argument('--data_noise_std', type=float, default=0.,
                    help='Applies Norm(0, std) gaussian noise to each training batch')
parser.add_argument('--eps_adv_training', type=float, default=0.3,
                    help='Epsilon for adversarial training')
parser.add_argument(
    "--adv_training_mode",
    default="clean",
    choices=["clean", "fgsm", "at"],  # , "at_ls", "av", "fs", "nce", "nce_moco", "moco", "av_extra", "meta"],
    help='Adversarial training method/mode, by default there is no adversarial training (clean). '
         'For further details see MegaAdversarial/train in this repository.'
)
parser.add_argument('--use_wandb', type=eval, default=True, choices=[True, False])
parser.add_argument('--ss_loss', type=eval, default=False, choices=[True, False])
parser.add_argument('--ss_loss_reg', type=float, default=0.1)
parser.add_argument('--timestamp', type=int, default=int(1e6 * time.time()))

parser.add_argument('--eps_adv_testing', type=float, default=0.3,
                    help='Epsilon for adversarial testing')
parser.add_argument('--adv_testing_mode',
                    default="clean",
                    choices=["clean", "fgsm", "at"],
                    help='Adversarial testing mode')

args = parser.parse_args()
# Every code path below indexes ``args.solvers``; fail early with a clear message.
if args.solvers is None:
    parser.error('--solvers is required, e.g. --solvers rk4,u3,4,-1,0.3,-1')

makedirs(args.save_dir)
if args.use_wandb:
    wandb.init(project=args.wandb_name, entity="sopa_node", anonymous="allow")
    wandb.config.update(args)
    wandb.config.update({'u': float(args.solvers[0][-2])})  # a dirty way to extract u from the rk2 solver
    # Path to save checkpoints locally in /// [Julia style]
    makedirs(wandb.config.save_dir)
    makedirs(os.path.join(wandb.config.save_dir, wandb.run.path))

if args.torch_dtype == 'float64':
    dtype = torch.float64
elif args.torch_dtype == 'float32':
    dtype = torch.float32
else:
    raise ValueError('torch_type should be either float64 or float32')

# I've decided to copy and modify functions from src/models/odenet_mnist/train_validate.py

# The PGD radius follows --eps_adv_training (its default 0.3 matches the former constant).
CONFIG_PGD_TRAIN = {"eps": args.eps_adv_training, "lr": 2.0 / 255, "n_iter": 7}
# NOTE(review): it is not visible here which FGSM key is the perturbation radius,
# so this config stays hard-coded.
CONFIG_FGSM_TRAIN = {"alpha": 0.3, "epsilon": 0.05}


def one_hot(x, K):
    """Return the one-hot encoding of the 1-D integer array ``x`` as a ``(len(x), K)`` int array."""
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader, device, solvers=None, solver_options=None):
    """Return the top-1 accuracy of ``model`` on ``dataset_loader``.

    :param model: network to evaluate; it is switched to eval mode.
    :param dataset_loader: iterable of ``(x, y)`` batches exposing ``.dataset``
        (e.g. a ``torch.utils.data.DataLoader``); ``y`` holds integer class labels.
    :param device: device the inputs are moved to before the forward pass.
    :param solvers: ODE solvers forwarded to the model; when None the model is
        called with the inputs only.
    :param solver_options: options forwarded together with ``solvers``.
    :return: fraction of correctly classified samples in ``dataset_loader.dataset``.
    """
    model.eval()
    total_correct = 0
    with torch.no_grad():
        for x, y in dataset_loader:
            x = x.to(device)
            # Dispatch on the ``solvers`` argument, not on any module-level name.
            if solvers is not None:
                out = model(x, solvers, solver_options)
            else:
                out = model(x)
            predicted_class = out.argmax(dim=1).cpu()
            total_correct += int((predicted_class == y.cpu()).sum())
    return total_correct / len(dataset_loader.dataset)


def adversarial_accuracy(model, dataset_loader, device, solvers=None, solver_options=None, args=None):
    """Return the accuracy of ``model`` on adversarially perturbed batches.

    The attack is chosen by ``args.adv_testing_mode``: "clean" (no attack),
    "fgsm" or "at" (PGD, with radius ``args.eps_adv_testing``).

    :param args: parsed command-line namespace; required.
    :return: fraction of correctly classified perturbed samples in
        ``dataset_loader.dataset``.
    :raises ValueError: if ``args`` is missing or the attack mode is unknown.
    """
    if args is None:
        raise ValueError("adversarial_accuracy requires the parsed command-line `args`.")
    model.eval()
    if args.adv_testing_mode == "clean":
        test_attack = Clean(model)
    elif args.adv_testing_mode == "fgsm":
        # NOTE(review): FGSM reuses the training config; which of its keys is the
        # radius is not visible here, so --eps_adv_testing is not applied to it.
        test_attack = FGSM(model, mean=[0.], std=[1.], **CONFIG_FGSM_TRAIN)
    elif args.adv_testing_mode == "at":
        pgd_config = dict(CONFIG_PGD_TRAIN, eps=args.eps_adv_testing)
        test_attack = PGD(model, mean=[0.], std=[1.], **pgd_config)
    else:
        raise ValueError("Attack type not understood.")

    total_correct = 0
    for x, y in dataset_loader:
        x, y = x.to(device), y.to(device)
        # Crafting the attack needs gradients, so it runs outside torch.no_grad().
        x, y = test_attack(x, y, {"solvers": solvers, "solver_options": solver_options})
        with torch.no_grad():
            if solvers is not None:
                out = model(x, solvers, solver_options)
            else:
                out = model(x)
        predicted_class = out.argmax(dim=1)
        total_correct += int((predicted_class == y).sum())
    return total_correct / len(dataset_loader.dataset)


def train(model,
          data_gen,
          solvers,
          solver_options,
          criterion,
          optimizer,
          device,
          is_odenet=True,
          args=None):
    """Run one optimisation step on the next batch drawn from ``data_gen``.

    The batch is optionally replaced by an adversarial one (``args.adv_training_mode``)
    and perturbed with Gaussian noise of std ``args.data_noise_std``.

    :return: dict with the batch cross-entropy under ``'xentropy'`` and, when
        ``args.ss_loss`` is set, the auxiliary loss under ``'ss_loss'``.
    :raises ValueError: if the adversarial training mode is unknown.
    """
    model.train()
    x, y = next(data_gen)
    x = x.to(device)
    y = y.to(device)

    if args.adv_training_mode == "clean":
        train_attack = Clean(model)
    elif args.adv_training_mode == "fgsm":
        # NOTE(review): unlike adversarial_accuracy, no mean/std is passed here, so
        # the attack's own defaults apply -- confirm they suit the MNIST inputs.
        train_attack = FGSM(model, **CONFIG_FGSM_TRAIN)
    elif args.adv_training_mode == "at":
        train_attack = PGD(model, **CONFIG_PGD_TRAIN)
    else:
        raise ValueError("Attack type not understood.")
    x, y = train_attack(x, y, {"solvers": solvers, "solver_options": solver_options})

    # Clear gradients only after the attack: crafting the adversarial batch may
    # back-propagate through the model, and such gradients must not leak into
    # the parameter update below.
    optimizer.zero_grad()

    if args.data_noise_std > 1e-12:
        with torch.no_grad():
            x = x + args.data_noise_std * torch.randn_like(x)

    if is_odenet:
        logits = model(x, solvers, solver_options, Namespace(ss_loss=args.ss_loss))
    else:
        logits = model(x)

    xentropy = criterion(logits, y)
    if args.ss_loss:
        ss_loss = model.get_ss_loss()
        loss = xentropy + args.ss_loss_reg * ss_loss
    else:
        loss = xentropy

    loss.backward()
    optimizer.step()
    if args.ss_loss:
        return {'xentropy': xentropy.item(), 'ss_loss': ss_loss.item()}
    return {'xentropy': xentropy.item()}


if __name__ == "__main__":
    print(f'CUDA is available: {torch.cuda.is_available()}')
    print(args.solvers)
    fix_seeds(args.seed)

    # ``dtype`` is resolved (and validated) from --torch_dtype at module load.
    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    ########### Create train solvers
    train_solvers = [create_solver(*solver_params, dtype=dtype, device=device) for solver_params in args.solvers]
    for train_solver in train_solvers:
        train_solver.freeze_params()

    train_solver_options = Namespace(**{key: vars(args)[key] for key in ['solver_mode', 'switch_probs',
                                                                      'ensemble_prob', 'ensemble_weights']})

    ########## Build the model
    is_odenet = args.network == 'odenet'

    model = MetaNODE(downsampling_method=args.downsampling_method,
                     is_odenet=is_odenet,
                     activation_type=args.activation,
                     in_channels=args.in_channels)
    model.to(device)
    if args.use_wandb:
        wandb.watch(model)

    ########### Create data loaders
    train_loader, test_loader, train_eval_loader = get_mnist_loaders(args.data_aug,
                                                                     args.batch_size,
                                                                     args.test_batch_size,
                                                                     data_root=args.data_root)
    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    ########### Create criterion and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    # The learning rate is driven by a cyclic schedule stepped once per batch.
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=args.base_lr,
                                            max_lr=args.max_lr, step_size_up=args.step_size_up,
                                            mode=args.cyclic_lr_mode)

    ########### Train the model
    for itr in range(args.nepochs_nn * batches_per_epoch):
        train_loss = train(model,
                           data_gen,
                           solvers=train_solvers,
                           solver_options=train_solver_options,
                           criterion=criterion,
                           optimizer=optimizer,
                           device=device,
                           is_odenet=is_odenet,
                           args=args)
        scheduler.step()

        if itr % batches_per_epoch == 0:
            train_acc = accuracy(model, train_loader, device, solvers=train_solvers,
                                 solver_options=train_solver_options)
            test_acc = accuracy(model, test_loader, device, solvers=train_solvers,
                                solver_options=train_solver_options)
            adv_test_acc = adversarial_accuracy(model, test_loader, device, solvers=train_solvers,
                                                solver_options=train_solver_options, args=args)
            adv_train_acc = adversarial_accuracy(model, train_loader, device, solvers=train_solvers,
                                                 solver_options=train_solver_options, args=args)
            metrics = {'train_acc': train_acc,
                       'test_acc': test_acc,
                       'adv_test_acc': adv_test_acc,
                       'adv_train_acc': adv_train_acc,
                       'train_loss': train_loss['xentropy']}

            # wandb.config / wandb.run only exist after wandb.init (use_wandb=True).
            if args.use_wandb:
                checkpoint_dir = os.path.join(wandb.config.save_dir, wandb.run.path)
            else:
                checkpoint_dir = args.save_dir
            makedirs(checkpoint_dir)
            save_path = os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(itr))
            print(save_path)
            torch.save(model, save_path)

            if args.use_wandb:
                wandb.save(save_path)
                wandb.log(metrics)
            else:
                print(metrics)
# --------------------------------------------------------------------------------
# /sopa/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/sopa/__init__.py
# --------------------------------------------------------------------------------
# /sopa/src/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/sopa/src/__init__.py
# --------------------------------------------------------------------------------
# /sopa/src/models/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/sopa/src/models/__init__.py
# --------------------------------------------------------------------------------
# /sopa/src/models/odenet_cifar10/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/sopa/src/models/odenet_cifar10/__init__.py
# --------------------------------------------------------------------------------
# /sopa/src/models/odenet_cifar10/data.py:
# --------------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np

# Per-channel mean/std passed to transforms.Normalize for every CIFAR10 split in this module.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def _cifar10_eval_transform():
    '''Deterministic transform (tensor conversion + normalization) used for validation/test data.'''
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)
    ])


def get_cifar10_train_val_loaders(data_aug=False,
                                  batch_size=128,
                                  val_perc=0.1,
                                  data_root=None,
                                  num_workers=1,
                                  pin_memory=True,
                                  shuffle=True,
                                  random_seed=None,
                                  download=False):
    ''' Returns iterators through train/val CIFAR10 datasets.

    Both loaders read the CIFAR10 *train* split; they sample disjoint index subsets of it.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    :param data_aug: bool
        Whether to add random cropping (32x32 with 4-pixel padding) to the train transform.
        A random horizontal flip is applied to the train split in either case; the validation
        split is never augmented.
    :param batch_size: int
        How many samples per batch to load.
    :param val_perc: float
        Fraction of the training set used for the validation set. Must be in the range [0, 1].
    :param data_root: str
        Path to the directory with the dataset.
    :param num_workers: int
        Number of subprocesses to use when loading the dataset.
    :param pin_memory: bool
        Whether to copy tensors into CUDA pinned memory. Set it to True if using GPU.
    :param shuffle: bool
        Whether to shuffle the train/validation indices before splitting them.
    :param random_seed: int or None
        If given, the global NumPy RNG is seeded with it before shuffling, which makes the split
        reproducible. If None, the global RNG is used in its current state (it is not re-seeded).
    :param download: bool
        Passed to torchvision's CIFAR10 dataset: download the data into data_root if missing.
    :return: (train_loader, val_loader)
        DataLoaders over the two index subsets; both drop the last incomplete batch.
    :raises ValueError: if val_perc is outside [0, 1].
    '''
    if not 0 <= val_perc <= 1:
        raise ValueError('val_perc should be in the range [0, 1], got {}'.format(val_perc))

    train_steps = [transforms.RandomCrop(32, padding=4)] if data_aug else []
    train_steps += [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)
    ]
    transform_train = transforms.Compose(train_steps)
    transform_test = _cifar10_eval_transform()

    train_dataset = datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform_train)
    val_dataset = datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform_test)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(val_perc * num_train))

    if shuffle:
        if random_seed is not None:
            # Seeding the global RNG is kept so seeded runs stay reproducible exactly as before.
            np.random.seed(random_seed)
        # With random_seed=None the global RNG is NOT re-seeded: np.random.seed(None) would pull
        # fresh OS entropy and discard any seeding done by the caller, making every later
        # np.random draw (e.g. solver switching via np.random.choice) irreproducible.
        np.random.shuffle(indices)

    train_idx, val_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers,
                              drop_last=True, pin_memory=pin_memory,)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, num_workers=num_workers,
                            drop_last=True, pin_memory=pin_memory,)

    return train_loader, val_loader


def get_cifar10_test_loader(batch_size=128, data_root=None, num_workers=1, pin_memory=True, shuffle=False, download=False):
    ''' Returns iterator through CIFAR10 test dataset

    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    :param batch_size: int
        How many samples per batch to load.
    :param data_root: str
        Path to the directory with the dataset.
    :param num_workers: int
        Number of subprocesses to use when loading the dataset.
    :param pin_memory: bool
        Whether to copy tensors into CUDA pinned memory. Set it to True if using GPU.
    :param shuffle: bool
        Whether to shuffle the dataset after every epoch.
    :param download: bool
        Passed to torchvision's CIFAR10 dataset: download the data into data_root if missing.
    :return: DataLoader
        Loader over the CIFAR10 test split (normalized, not augmented). The last incomplete
        batch is dropped.
    '''
    test_dataset = datasets.CIFAR10(root=data_root, train=False, download=download,
                                    transform=_cifar10_eval_transform())

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                             drop_last=True, pin_memory=pin_memory)
    return test_loader


def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
    for i, (x, y) in enumerate(inf_generator(train_loader)):

    A new iterator is created from `iterable` each time the previous one is exhausted.
    """
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)


# --------------------------------------------------------------------------------
# /sopa/src/models/odenet_cifar10/layers.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

import numpy as np
from copy import deepcopy

__all__ = ['MetaNODE', 'metanode4', 'metanode6', 'metanode10', 'metanode18', 'metanode34',
           'premetanode4', 'premetanode6', 'premetanode10', 'premetanode18', 'premetanode34']


class Flatten(nn.Module):
    '''Flatten every dimension except the first (batch) one.'''
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class BasicBlock(nn.Module):
    '''Standard ResNet block: conv-norm-act-conv-norm, plus shortcut, then act.

    The shortcut is the identity unless the stride or channel count changes, in which case it is
    a 1x1 convolution followed by normalization.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1,
                 norm_layer=None, act_layer=None, param_norm=lambda x: x
                 ):
        super(BasicBlock, self).__init__()
        self.conv1 = param_norm(nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False))
        self.bn1 = norm_layer(planes)

        self.conv2 = param_norm(nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False))
        self.bn2 = norm_layer(planes)

        self.act = act_layer

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                param_norm(nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)),
                norm_layer(self.expansion * planes)
            )

    def forward(self, x):
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.act(out)
        return out


class PreBasicBlock(nn.Module):
    '''Standard PreResNet block: norm-act-conv twice, plus shortcut (no final activation).

    The shortcut is the identity unless the stride or channel count changes, in which case it is
    a bare 1x1 convolution (no normalization, as is usual for pre-activation blocks).
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1,
                 norm_layer=None, act_layer=None, param_norm=lambda x: x
                 ):
        super(PreBasicBlock, self).__init__()
        self.bn1 = norm_layer(in_planes)
        self.conv1 = param_norm(nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False))

        self.bn2 = norm_layer(planes)
        self.conv2 = param_norm(nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False))

        self.act = act_layer

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                param_norm(nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)),
            )

    def forward(self, x):
        out = self.conv1(self.act(self.bn1(x)))
        out = self.conv2(self.act(self.bn2(out)))
        out += self.shortcut(x)
        return out


class BasicBlock2(nn.Module):
    '''Odefunc to use inside MetaODEBlock (post-activation variant).

    Computes f(t, x) = act(norm(conv(act(norm(conv(x)))))) with a constant channel count `dim`;
    the time argument t is ignored. Counts its own evaluations in `self.nfe`.
    '''
    expansion = 1

    def __init__(self, dim,
                 norm_layer=None, act_layer=None, param_norm=lambda x: x):
        super(BasicBlock2, self).__init__()
        in_planes = dim
        planes = dim
        stride = 1
        self.nfe = 0

        self.conv1 = param_norm(nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False))
        # The normalization layer is configurable: BN does not work well with this method,
        # so GN is used in its place.
        self.bn1 = norm_layer(planes)

        self.conv2 = param_norm(nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False))
        self.bn2 = norm_layer(planes)

        self.act = act_layer

        # Unused in forward; kept so the module layout of saved models is unchanged.
        self.shortcut = nn.Sequential()

    def forward(self, t, x, ss_loss = False):
        '''Evaluate the right-hand side; with ss_loss=True return its absolute value.'''
        self.nfe += 1
        # Solvers may pass the state wrapped in a tuple; only its first element is used.
        if isinstance(x, tuple):
            x = x[0]
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.act(out)

        if ss_loss:
            out = torch.abs(out)
        return out


class PreBasicBlock2(nn.Module):
    '''Odefunc to use inside MetaODEBlock (pre-activation variant).

    Computes f(t, x) = conv(act(norm(conv(act(norm(x)))))) with a constant channel count `dim`;
    the time argument t is ignored. Counts its own evaluations in `self.nfe`.
    '''
    expansion = 1

    def __init__(self, dim,
                 norm_layer=None, act_layer=None, param_norm=lambda x: x):
        super(PreBasicBlock2, self).__init__()
        in_planes = dim
        planes = dim
        stride = 1
        self.nfe = 0

        # The normalization layer is configurable: BN does not work well with this method,
        # so GN is used in its place.
        self.bn1 = norm_layer(in_planes)
        self.conv1 = param_norm(nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False))

        self.bn2 = norm_layer(planes)
        self.conv2 = param_norm(nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False))

        self.act = act_layer

        # Unused in forward; kept so the module layout of saved models is unchanged.
        self.shortcut = nn.Sequential()

    def forward(self, t, x, ss_loss=False):
        '''Evaluate the right-hand side; with ss_loss=True return its absolute value.'''
        self.nfe += 1
        # Solvers may pass the state wrapped in a tuple; only its first element is used.
        if isinstance(x, tuple):
            x = x[0]
        out = self.bn1(x)
        out = self.act(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.act(out)
        out = self.conv2(out)

        if ss_loss:
            out = torch.abs(out)
        return out

class MetaODEBlock(nn.Module):
    '''The same as MetaODEBlock for MNIST. Only difference is that odefunc is passed as a keyword argument.

    Integrates `odefunc` over `self.integration_time` (t from 0 to 1) with the solver(s) selected
    by `solver_options.solver_mode`:

    * 'standalone' -- always solvers[0];
    * 'switch'     -- one solver sampled per forward pass with probabilities
                      `solver_options.switch_probs` (uniform when None);
    * 'ensemble'   -- with probability `solver_options.ensemble_prob`, a weighted sum of all
                      solvers' outputs (weights `solver_options.ensemble_weights`, uniform when
                      None); otherwise solvers[0].

    The sampled choice is stored on `solver_options` (`switch_solver_id` / `ensemble_coin_flip`)
    so that `ss_loss` reuses the same solver(s) as the preceding forward pass.
    '''
    def __init__(self, odefunc=None):
        super(MetaODEBlock, self).__init__()

        self.rhs_func = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def _integrate(self, func, x, t, solvers, solver_options, resample):
        '''Integrate `func` from state `x` over times `t` according to `solver_options.solver_mode`.

        :param resample: bool
            If True, draw a new solver id ('switch') / coin flip ('ensemble') and store it on
            `solver_options`; if False, reuse the value stored by the last resampling call.
        :raises ValueError: if `solvers` is empty, the ensemble weights are empty,
            or the solver mode is unknown.
        '''
        nsolvers = len(solvers)
        if nsolvers == 0:
            raise ValueError('At least one solver is required')
        mode = solver_options.solver_mode

        if mode == 'standalone':
            return solvers[0].integrate(func, x = x, t = t)

        if mode == 'switch':
            if resample:
                switch_probs = solver_options.switch_probs
                if switch_probs is None:
                    switch_probs = [1./nsolvers for _ in range(nsolvers)]
                solver_options.switch_solver_id = np.random.choice(range(nsolvers), p = switch_probs)
            return solvers[solver_options.switch_solver_id].integrate(func, x = x, t = t)

        if mode == 'ensemble':
            if resample:
                solver_options.ensemble_coin_flip = torch.bernoulli(torch.tensor((1,)),
                                                                    solver_options.ensemble_prob)
            if not solver_options.ensemble_coin_flip:
                return solvers[0].integrate(func, x = x, t = t)

            ensemble_weights = solver_options.ensemble_weights
            if ensemble_weights is None:
                ensemble_weights = [1./nsolvers for _ in range(nsolvers)]

            y = None
            for wi, solver in zip(ensemble_weights, solvers):
                term = wi * solver.integrate(func, x = x, t = t)
                y = term if y is None else y + term
            if y is None:
                raise ValueError('ensemble_weights must not be empty')
            return y

        raise ValueError('Unknown solver_mode: {}'.format(mode))

    def forward(self, x, solvers, solver_options):
        '''Return the state at the final integration time.

        NOTE(review): the indexing below requires the solver output to be 5-D with time first,
        presumably (time, batch, C, H, W) -- confirm against the solvers.
        '''
        y = self._integrate(self.rhs_func, x, self.integration_time, solvers, solver_options,
                            resample=True)
        return y[-1,:,:,:,:]

    def ss_loss(self, y, solvers, solver_options):
        '''Mean per-sample L2 norm of the displacement obtained by integrating |f| from `y` over t in [1, 2].

        Uses the same solver(s) that the preceding `forward` call selected.
        '''
        from functools import partial  # not imported at the top of this module

        z0 = y
        # Pass the partial itself: its `.func` attribute is the bare odefunc, which would drop
        # ss_loss=True and integrate the ordinary dynamics instead of |f|.
        # NOTE(review): assumes solver.integrate only calls func(t, x); confirm it needs no
        # nn.Module attributes of func.
        rhs_func_ss = partial(self.rhs_func, ss_loss = True)
        integration_time_ss = self.integration_time + 1

        z = self._integrate(rhs_func_ss, y, integration_time_ss, solvers, solver_options,
                            resample=False)

        z = z[-1,:,:,:,:] - z0
        z = torch.norm(z.reshape((z.shape[0], -1)), dim = 1)
        z = torch.mean(z)

        return z


class MetaLayer(nn.Module):
    '''
    One stage of MetaNODE: `num_resblocks` residual blocks followed by `num_odeblocks` MetaODEBlocks.

    planes: output channels of the residual blocks
    num_blocks: (num_resblocks, num_odeblocks)
    stride: stride of the first residual block (ODE blocks never downsample)
    norm_layers_: tuple of normalization layers for (BasicBlock, BasicBlock2, bn1)
    param_norm_layers_: tuple of normalizations for weights in (BasicBlock, BasicBlock2, conv1)
    act_layers_: tuple of activation layers for (BasicBlock, BasicBlock2, activation after bn1)
    in_planes: input channels; after construction `self.in_planes` holds the output channels
    resblock: BasicBlock or PreBasicBlock
    odefunc: BasicBlock2 or PreBasicBlock2

    '''
    def __init__(self, planes, num_blocks, stride, norm_layers_, param_norm_layers_, act_layers_,
                 in_planes, resblock=None, odefunc=None):
        super(MetaLayer, self).__init__()

        num_resblocks, num_odeblocks = num_blocks

        layers_res = []
        layers_ode = []

        self.in_planes = in_planes
        # Only the first residual block uses `stride`. A stage without residual blocks therefore
        # does not downsample, whatever `stride` is.
        for block_idx in range(num_resblocks):
            layers_res.append(resblock(self.in_planes, planes, stride if block_idx == 0 else 1,
                                       norm_layer = norm_layers_[0],
                                       param_norm=param_norm_layers_[0],
                                       act_layer = act_layers_[0]))
            self.in_planes = planes * resblock.expansion

        # Exactly num_odeblocks ODE blocks, each operating on the current channel count.
        for _ in range(num_odeblocks):
            layers_ode.append(MetaODEBlock(odefunc(self.in_planes,
                                                   norm_layer=norm_layers_[1],
                                                   param_norm=param_norm_layers_[1],
                                                   act_layer=act_layers_[1])))

        self.blocks_res = nn.Sequential(*layers_res)
        self.blocks_ode = nn.ModuleList(layers_ode)

    def forward(self, x, solvers=None, solver_options=None, loss_options = None):
        '''Run the residual blocks, then the ODE blocks; accumulate the ss loss into `self.ss_loss`.'''
        x = self.blocks_res(x)

        self.ss_loss = 0

        for block in self.blocks_ode:
            x = block(x, solvers, solver_options)

            if (loss_options is not None) and loss_options.ss_loss:
                z = block.ss_loss(x, solvers, solver_options)
                self.ss_loss += z

        return x

    def get_ss_loss(self):
        return self.ss_loss

    @property
    def nfe(self):
        '''Total number of function evaluations recorded by the ODE blocks' right-hand sides.'''
        # The counter lives on the odefunc (BasicBlock2 / PreBasicBlock2), not on MetaODEBlock.
        return sum(getattr(block.rhs_func, 'nfe', 0) for block in self.blocks_ode)

    @nfe.setter
    def nfe(self, value):
        for block in self.blocks_ode:
            block.rhs_func.nfe = value


class MetaNODE(nn.Module):
    def __init__(self, num_blocks, num_classes=10,
                 norm_layers_=(None, None, None),
                 param_norm_layers_=(lambda x: x, lambda x: x, lambda x: x),
                 act_layers_=(None, None, None),
                 in_planes_=64,
                 resblock=None,
                 odefunc=None):
        '''
        num_blocks: list of (num_resblocks, num_odeblocks) pairs, one per stage. Stage k (1-based)
            has in_planes_ * 2**(k-1) planes; every stage after the first uses stride 2.
        num_classes: number of outputs of the final linear layer
        norm_layers_: tuple of normalization layers for (BasicBlock, BasicBlock2, bn1)
        param_norm_layers_: tuple of normalizations for weights in (BasicBlock, BasicBlock2, conv1)
        act_layers_: tuple of activation layers for (BasicBlock, BasicBlock2, activation after bn1)
        in_planes_: channels produced by the stem convolution
        resblock: BasicBlock or PreBasicBlock
        odefunc: BasicBlock2 or PreBasicBlock2
        '''

        super(MetaNODE, self).__init__()
        self.in_planes = in_planes_

        self.n_layers = len(num_blocks)
        self.n_features_linear = in_planes_

        # NOTE(review): the factory functions pass resblock/odefunc as *classes*, so these
        # isinstance checks are always False and is_preactivation is never set. Switching to
        # issubclass would also require sizing bn1 for the last stage's channel count (it is
        # sized for in_planes_), and it would change the behaviour of already-trained
        # premetanode* checkpoints. Left unchanged on purpose.
        self.is_preactivation = False
        if (((resblock is not None) and isinstance(resblock, PreBasicBlock)) or
                ((odefunc is not None) and isinstance(odefunc, PreBasicBlock2))):
            self.is_preactivation = True

        self.conv1 = param_norm_layers_[2](nn.Conv2d(3, in_planes_, kernel_size=3, stride=1, padding=1, bias=False))
        self.bn1 = norm_layers_[2](in_planes_)
        self.act = act_layers_[2]

        stage_kwargs = dict(norm_layers_=norm_layers_[:2],
                            param_norm_layers_=param_norm_layers_[:2],
                            act_layers_=act_layers_[:2],
                            resblock=resblock,
                            odefunc=odefunc)

        self.layer1 = MetaLayer(in_planes_, num_blocks[0], stride=1,
                                in_planes=self.in_planes, **stage_kwargs)

        # Stages 2..n double the width and halve the resolution; each feeds on the previous
        # stage's output channels.
        for idx in range(2, self.n_layers + 1):
            self.n_features_linear *= 2
            prev_layer = getattr(self, 'layer{}'.format(idx - 1))
            setattr(self, 'layer{}'.format(idx),
                    MetaLayer(in_planes_ * 2 ** (idx - 1), num_blocks[idx - 1], stride=2,
                              in_planes=prev_layer.in_planes, **stage_kwargs))

        self.fc_layers = nn.Sequential(*[nn.AdaptiveAvgPool2d((1, 1)),
                                         Flatten(),
                                         nn.Linear(self.n_features_linear * resblock.expansion, num_classes)])
        self.nfe = 0

    @property
    def nfe(self):
        '''Total number of ODE function evaluations over all stages.'''
        return sum(getattr(self, 'layer{}'.format(idx)).nfe for idx in range(1, self.n_layers + 1))

    @nfe.setter
    def nfe(self, value):
        for idx in range(1, self.n_layers + 1):
            getattr(self, 'layer{}'.format(idx)).nfe = value

    def forward(self, x, solvers=None, solver_options=None, loss_options = None):
        '''Return class logits; the summed ss loss of all stages is left in `self.ss_loss`.'''
        self.ss_loss = 0

        out = self.conv1(x)
        if not self.is_preactivation:
            out = self.act(self.bn1(out))

        for idx in range(1, self.n_layers + 1):
            layer = getattr(self, 'layer{}'.format(idx))
            out = layer(out, solvers=solvers, solver_options=solver_options,
                        loss_options = loss_options)
            self.ss_loss += layer.ss_loss

        if self.is_preactivation:
            out = self.act(self.bn1(out))

        out = self.fc_layers(out)
        return out


def _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                   odenet_blocks, resnet_blocks, resblock, odefunc):
    '''Build a MetaNODE with the ODE-block layout (is_odenet) or the plain residual layout.'''
    return MetaNODE(odenet_blocks if is_odenet else resnet_blocks,
                    norm_layers_=norm_layers, param_norm_layers_=param_norm_layers, act_layers_=act_layers,
                    in_planes_=in_planes,
                    resblock=resblock,
                    odefunc=odefunc
                    )


def metanode4(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(0, 1)], [(1, 0)], BasicBlock, BasicBlock2)


def metanode6(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(1, 1)], [(2, 0)], BasicBlock, BasicBlock2)


def metanode10(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(1, 1), (1, 1)], [(2, 0), (2, 0)], BasicBlock, BasicBlock2)


def metanode18(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(1, 1), (1, 1), (1, 1), (1, 1)], [(2, 0), (2, 0), (2, 0), (2, 0)],
                          BasicBlock, BasicBlock2)


def metanode34(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(1, 2), (1, 3), (1, 5), (1, 2)], [(3, 0), (4, 0), (6, 0), (3, 0)],
                          BasicBlock, BasicBlock2)


def premetanode4(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(0, 1)], [(1, 0)], PreBasicBlock, PreBasicBlock2)


def premetanode6(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(1, 1)], [(2, 0)], PreBasicBlock, PreBasicBlock2)


def premetanode10(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(1, 1), (1, 1)], [(2, 0), (2, 0)], PreBasicBlock, PreBasicBlock2)


def premetanode18(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(1, 1), (1, 1), (1, 1), (1, 1)], [(2, 0), (2, 0), (2, 0), (2, 0)],
                          PreBasicBlock, PreBasicBlock2)


def premetanode34(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    return _make_metanode(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet,
                          [(1, 2), (1, 3), (1, 5), (1, 2)], [(3, 0), (4, 0), (6, 0), (3, 0)],
                          PreBasicBlock, PreBasicBlock2)

# --------------------------------------------------------------------------------
# /sopa/src/models/odenet_cifar10/utils.py:
# --------------------------------------------------------------------------------
import math

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.utils import spectral_norm, weight_norm
from functools import partial


class Identity(nn.Module):
    '''No-op module; accepts (and ignores) any constructor arguments so it can stand in for a norm layer.'''
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x

def get_normalization(key, num_groups=32):
    '''Create a normalization layer given name of layer type.

    :param key: str
        Type of normalization layer to use after convolutional layer (e.g. type of layer output normalization).
        Can be one of BN (batch normalization), GN (group normalization), LN (layer normalization),
        IN (instance normalization), NF(no normalization)
    :param num_groups: int
        Number of groups for GN normalization
    :return: callable
        Layer constructor that takes the number of channels (a class or a functools.partial).
    :raises NameError: for an unknown key.
    '''
    if key == 'BN':
        return nn.BatchNorm2d
    elif key == 'LN':
        # Layer norm over (C, H, W) expressed as group norm with a single group.
        return partial(nn.GroupNorm, 1)
    elif key == 'GN':
        return partial(nn.GroupNorm, num_groups)
    elif key == 'IN':
        return nn.InstanceNorm2d
    elif key == 'NF':
        return Identity
    else:
        raise NameError('Unknown layer normalization type')

def get_param_normalization(key):
    '''Create a function to normalize layer weights given name of normalization.

    :param key: str
        Type of normalization applied to layer's weights.
        Can be one of SN (spectral normalization), WN (weight normalization), PNF (no weight normalization).
    :return: function
        Takes a module and returns it (wrapped, for SN/WN).
    :raises NameError: for an unknown key.
    '''
    if key == 'SN':
        return spectral_norm
    elif key == 'WN':
        return weight_norm
    elif key == 'PNF':
        return lambda x: x
    else:
        raise NameError('Unknown param normalization type')

def get_activation(key):
    '''Create an activation function given name of function type.

    :param key: str
        Type of activation layer.
        Can be one of: ReLU, GeLU, Softsign, Tanh, AF (no activation/linear activation)
    :return: function
    :raises NameError: for an unknown key.
    '''
    if key == 'ReLU':
        return F.relu
    elif key == 'GeLU':
        return F.gelu
    elif key == 'Softsign':
        return F.softsign
    elif key == 'Tanh':
        return F.tanh
    elif key == 'AF':
        # leaky_relu with slope 1 is the identity map.
        return partial(F.leaky_relu, negative_slope=1)
    else:
        raise NameError('Unknown activation type')

def conv_init(m):
    '''Initialize module `m` in place (intended for use with `nn.Module.apply`).

    Convolutions that have a bias get Xavier-uniform weights with gain sqrt(2) and a zero bias;
    convolutions without a bias are left untouched. Affine BatchNorm layers get weight 1, bias 0.
    '''
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1 and m.bias is not None:
        # math.sqrt: numpy is not imported in this module.
        init.xavier_uniform_(m.weight, gain=math.sqrt(2))
        init.constant_(m.bias, 0)
    elif class_name.find('BatchNorm') != -1 and m.weight is not None:
        # affine=False BatchNorm has weight=None and bias=None; nothing to initialize.
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)

def conv_init_orthogonal(m):
    '''Orthogonally initialize the weight of `m` if it is an nn.Conv2d.'''
    if isinstance(m, nn.Conv2d):
        init.orthogonal_(m.weight)

def fc_init_orthogonal(m):
    '''Orthogonally initialize an nn.Linear weight and set its bias (if any) to 1e-3.'''
    if isinstance(m, nn.Linear):
        init.orthogonal_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 1e-3)

# --------------------------------------------------------------------------------
# /sopa/src/models/odenet_mnist/attacks_runner.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets


import numpy as np
import glob
import pandas as pd
from collections import defaultdict
from argparse import Namespace

import os
import sys
sys.path.append('/workspace/home/jgusak/neural-ode-sopa')

import dataloaders

from sopa.src.models.odenet_mnist.layers import MetaNODE
from sopa.src.solvers.utils import create_solver

from sopa.src.models.utils import load_model
from sopa.src.models.odenet_mnist.attacks_utils import run_attack

from MegaAdversarial.src.utils.runner import fix_seeds

import argparse
parser = argparse.ArgumentParser() 30 | parser.add_argument('--data_root', type=str, default='') 31 | parser.add_argument('--models_root', type=str, default='') 32 | parser.add_argument('--save_path', type=str, default='') 33 | parser.add_argument('--key_path_word', type=str, default='') 34 | 35 | parser.add_argument('--min_eps', type=float, default=0.) 36 | parser.add_argument('--max_eps', type=float, default=0.3) 37 | parser.add_argument('--num_eps', type=int, default=20) 38 | parser.add_argument('--epsilons', type=lambda s: [float(eps) for eps in s.split(',')], default=None) 39 | 40 | args = parser.parse_args() 41 | 42 | 43 | def run_set_of_attacks(epsilons, attack_modes, loaders, models_root, save_path = None, device = 'cuda', key_path_word = ''): 44 | fix_seeds() 45 | 46 | df = pd.DataFrame() 47 | 48 | if key_path_word == 'u2': 49 | meta_model_path = "{}/*/*/*.pth".format(models_root) 50 | elif key_path_word == 'euler': 51 | meta_model_path = "{}/*/*.pth".format(models_root) 52 | else: 53 | meta_model_path = "{}/*/*.pth".format(models_root) 54 | 55 | 56 | i = 0 57 | for state_dict_path in glob.glob(meta_model_path, recursive = True): 58 | if not (key_path_word in state_dict_path): 59 | continue 60 | 61 | model, model_args = load_model(state_dict_path) 62 | model.eval() 63 | model.cuda() 64 | 65 | val_solver = create_solver(*model_args.solvers[0], dtype = torch.float32, device = device) 66 | val_solver_options = Namespace(solver_mode = 'standalone') 67 | solvers_kwargs = {'solvers':[val_solver], 'solver_options':val_solver_options} 68 | 69 | robust_accuracies = run_attack(model, epsilons, attack_modes, loaders, device, solvers_kwargs=solvers_kwargs) 70 | robust_accuracies = {k : np.array(v) for k,v in robust_accuracies.items()} 71 | 72 | data = [list(dict(model_args._get_kwargs()).values()) +\ 73 | list(robust_accuracies.values()) +\ 74 | [epsilons] 75 | ] 76 | columns = list(dict(model_args._get_kwargs()).keys()) +\ 77 | list(robust_accuracies.keys()) +\ 78 | 
['epsilons'] 79 | 80 | df_tmp = pd.DataFrame(data = data, columns = columns) 81 | df = df.append(df_tmp) 82 | 83 | if save_path is not None: 84 | df.to_csv(save_path, index = False) 85 | 86 | i += 1 87 | print('{} models have been processed'.format(i)) 88 | 89 | 90 | 91 | if __name__=="__main__": 92 | loaders = dataloaders.get_loader(batch_size=256, 93 | data_name='mnist', 94 | data_root=args.data_root, 95 | num_workers = 4, 96 | train=False, val=True) 97 | device = 'cuda' 98 | 99 | if args.epsilons is not None: 100 | epsilons = args.epsilons 101 | else: 102 | epsilons = np.linspace(args.min_eps, args.max_eps, num=args.num_eps) 103 | 104 | run_set_of_attacks(epsilons=epsilons, 105 | attack_modes = ["fgsm", "at", "at_ls", "av", "fs"][:1], 106 | loaders = loaders, 107 | models_root = args.models_root, 108 | save_path = args.save_path, 109 | key_path_word = args.key_path_word, 110 | device = device) 111 | 112 | 113 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=4 python3 attacks_runner.py --data_root "/workspace/raid/data/datasets" --models_root "/workspace/raid/data/jgusak/neural-ode-sopa/odenet_mnist_noise/" --save_path '/workspace/home/jgusak/neural-ode-sopa/experiments/odenet_mnist/results/robust_accuracies_fgsm/test_part1.csv' --epsilons 0.15,0.3,0.5 114 | -------------------------------------------------------------------------------- /sopa/src/models/odenet_mnist/attacks_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import DataLoader 5 | import torchvision.datasets as datasets 6 | 7 | 8 | import numpy as np 9 | from collections import defaultdict 10 | 11 | from MegaAdversarial.src.utils.runner import test, test_ensemble, fix_seeds 12 | from MegaAdversarial.src.attacks import ( 13 | Clean, 14 | FGSM, 15 | PGD, 16 | LabelSmoothing, 17 | AdversarialVertex, 18 | AdversarialVertexExtra, 19 | 
FeatureScatter, 20 | NCEScatter, 21 | NCEScatterWithBuffer, 22 | MetaAttack, 23 | FGSM2Ensemble 24 | ) 25 | 26 | 27 | 28 | def run_attack(model, epsilons, attack_modes, loaders, device='cuda', solvers_kwargs = None, 29 | at_lr = None, at_n_iter = None): 30 | robust_accuracies = defaultdict(list) 31 | 32 | for mode in attack_modes: 33 | for epsilon in epsilons: 34 | # CONFIG_PGD_TEST = {"eps": epsilon, "lr": 2.0 / 255 * 10, "n_iter": 20} 35 | CONFIG_PGD_TEST = {"eps": epsilon, "lr": at_lr, "n_iter": at_n_iter} 36 | CONFIG_FGSM_TEST = {"eps": epsilon} 37 | 38 | if mode == "clean": 39 | test_attack = Clean(model) 40 | elif mode == "fgsm": 41 | test_attack = FGSM(model, **CONFIG_FGSM_TEST) 42 | 43 | elif mode == "at": 44 | test_attack = PGD(model, **CONFIG_PGD_TEST) 45 | 46 | # elif mode == "at_ls": 47 | # test_attack = PGD(model, **CONFIG_PGD_TEST) # wrong func, fix this 48 | 49 | # elif mode == "av": 50 | # test_attack = PGD(model, **CONFIG_PGD_TEST) # wrong func, fix this 51 | 52 | # elif mode == "fs": 53 | # test_attack = PGD(model, **CONFIG_PGD_TEST) # wrong func, fix this 54 | 55 | print("Attack {}".format(mode)) 56 | test_metrics = test(loaders["val"], model, test_attack, device, show_progress=True, solvers_kwargs = solvers_kwargs) 57 | test_log = f"Test: | " + " | ".join( 58 | map(lambda x: f"{x[0]}: {x[1]:.6f}", test_metrics.items()) 59 | ) 60 | print(test_log) 61 | 62 | robust_accuracies['accuracy_{}'.format(mode)].append(test_metrics['accuracy_adv']) 63 | 64 | return robust_accuracies 65 | 66 | 67 | 68 | def run_attack2ensemble(models, epsilons, attack_modes, loaders, device='cuda', solvers_kwargs_arr = None, 69 | at_lr = None, at_n_iter = None): 70 | robust_accuracies = defaultdict(list) 71 | 72 | for mode in attack_modes: 73 | for epsilon in epsilons: 74 | # CONFIG_PGD_TEST = {"eps": epsilon, "lr": 2.0 / 255 * 10, "n_iter": 20} 75 | CONFIG_PGD_TEST = {"eps": epsilon, "lr": at_lr, "n_iter": at_n_iter} 76 | CONFIG_FGSM_TEST = {"eps": epsilon} 77 | 78 | if 
mode == "fgsm": 79 | test_attack2ensemble = FGSM2Ensemble(models, **CONFIG_FGSM_TEST) 80 | else: 81 | raise NotImplementedError 82 | 83 | print("Attack {}".format(mode)) 84 | test_metrics = test_ensemble(loaders["val"], 85 | models, 86 | test_attack2ensemble, 87 | device, 88 | show_progress=True, 89 | solvers_kwargs_arr = solvers_kwargs_arr) 90 | 91 | test_log = f"Test: | " + " | ".join( 92 | map(lambda x: f"{x[0]}: {x[1]:.6f}", test_metrics.items()) 93 | ) 94 | print(test_log) 95 | 96 | robust_accuracies['accuracy_{}'.format(mode)].append(test_metrics['accuracy_adv']) 97 | 98 | return robust_accuracies -------------------------------------------------------------------------------- /sopa/src/models/odenet_mnist/data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torchvision.datasets as datasets 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0, data_root=None): 7 | if data_aug: 8 | transform_train = transforms.Compose([ 9 | transforms.RandomCrop(28, padding=4), 10 | transforms.ToTensor(), 11 | ]) 12 | else: 13 | transform_train = transforms.Compose([ 14 | transforms.ToTensor(), 15 | ]) 16 | 17 | transform_test = transforms.Compose([ 18 | transforms.ToTensor(), 19 | ]) 20 | 21 | train_loader = DataLoader( 22 | datasets.MNIST(root=data_root, train=True, download=True, transform=transform_train), batch_size=batch_size, 23 | shuffle=True, num_workers=2, drop_last=True 24 | ) 25 | 26 | train_eval_loader = DataLoader( 27 | datasets.MNIST(root=data_root, train=True, download=True, transform=transform_test), 28 | batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True 29 | ) 30 | 31 | test_loader = DataLoader( 32 | datasets.MNIST(root=data_root, train=False, download=True, transform=transform_test), 33 | batch_size=test_batch_size, shuffle=False, num_workers=2, 
drop_last=True 34 | ) 35 | 36 | return train_loader, test_loader, train_eval_loader 37 | 38 | 39 | def get_svhn_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0, data_root=None): 40 | if data_aug: 41 | transform_train = transforms.Compose([ 42 | transforms.RandomCrop(32, padding=4), 43 | transforms.ToTensor(), 44 | ]) 45 | else: 46 | transform_train = transforms.Compose([ 47 | transforms.ToTensor(), 48 | ]) 49 | 50 | transform_test = transforms.Compose([ 51 | transforms.ToTensor(), 52 | ]) 53 | 54 | train_loader = DataLoader( 55 | datasets.SVHN(root=data_root, split='train', download=True, transform=transform_train), batch_size=batch_size, 56 | shuffle=True, num_workers=2, drop_last=True 57 | ) 58 | 59 | train_eval_loader = DataLoader( 60 | datasets.SVHN(root=data_root, split='train', download=True, transform=transform_test), 61 | batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True 62 | ) 63 | 64 | test_loader = DataLoader( 65 | datasets.SVHN(root=data_root, split='test', download=True, transform=transform_test), 66 | batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True 67 | ) 68 | 69 | return train_loader, test_loader, train_eval_loader 70 | 71 | 72 | def inf_generator(iterable): 73 | """Allows training with DataLoaders in a single infinite loop: 74 | for i, (x, y) in enumerate(inf_generator(train_loader)): 75 | """ 76 | iterator = iterable.__iter__() 77 | while True: 78 | try: 79 | yield iterator.__next__() 80 | except StopIteration: 81 | iterator = iterable.__iter__() 82 | -------------------------------------------------------------------------------- /sopa/src/models/odenet_mnist/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from functools import partial 6 | 7 | 8 | class MetaODEBlock(nn.Module): 9 | def __init__(self, activation_type = 'relu'): 10 
| super(MetaODEBlock, self).__init__() 11 | 12 | self.rhs_func = ODEfunc(64, activation_type) 13 | self.integration_time = torch.tensor([0, 1]).float() 14 | 15 | 16 | def forward(self, x, solvers, solver_options): 17 | nsolvers = len(solvers) 18 | 19 | if solver_options.solver_mode == 'standalone': 20 | y = solvers[0].integrate(self.rhs_func, x = x, t = self.integration_time) 21 | 22 | elif solver_options.solver_mode == 'switch': 23 | if solver_options.switch_probs is not None: 24 | switch_probs = solver_options.switch_probs 25 | else: 26 | switch_probs = [1./nsolvers for _ in range(nsolvers)] 27 | solver_id = np.random.choice(range(nsolvers), p = switch_probs) 28 | solver_options.switch_solver_id = solver_id 29 | 30 | y = solvers[solver_id].integrate(self.rhs_func, x = x, t = self.integration_time) 31 | 32 | elif solver_options.solver_mode == 'ensemble': 33 | coin_flip = torch.bernoulli(torch.tensor((1,)), solver_options.ensemble_prob) 34 | solver_options.ensemble_coin_flip = coin_flip 35 | 36 | if coin_flip : 37 | if solver_options.ensemble_weights is not None: 38 | ensemble_weights = solver_options.ensemble_weights 39 | else: 40 | ensemble_weights = [1./nsolvers for _ in range(nsolvers)] 41 | 42 | for i, (wi, solver) in enumerate(zip(ensemble_weights, solvers)): 43 | if i == 0: 44 | y = wi * solver.integrate(self.rhs_func, x = x, t = self.integration_time) 45 | else: 46 | y += wi * solver.integrate(self.rhs_func, x = x, t = self.integration_time) 47 | else: 48 | y = solvers[0].integrate(self.rhs_func, x = x, t = self.integration_time) 49 | 50 | return y[-1,:,:,:,:] 51 | 52 | 53 | def ss_loss(self, y, solvers, solver_options): 54 | z0 = y 55 | rhs_func_ss = partial(self.rhs_func, ss_loss = True) 56 | integration_time_ss = self.integration_time + 1 57 | 58 | nsolvers = len(solvers) 59 | 60 | if solver_options.solver_mode == 'standalone': 61 | z = solvers[0].integrate(rhs_func_ss.func, x = y, t = integration_time_ss) 62 | 63 | elif solver_options.solver_mode == 
'switch': 64 | if solver_options.switch_probs is not None: 65 | switch_probs = solver_options.switch_probs 66 | else: 67 | switch_probs = [1./nsolvers for _ in range(nsolvers)] 68 | solver_id = solver_options.switch_solver_id 69 | 70 | z = solvers[solver_id].integrate(rhs_func_ss.func, x = y, t = integration_time_ss) 71 | 72 | elif solver_options.solver_mode == 'ensemble': 73 | coin_flip = solver_options.ensemble_coin_flip 74 | 75 | if coin_flip : 76 | if solver_options.ensemble_weights is not None: 77 | ensemble_weights = solver_options.ensemble_weights 78 | else: 79 | ensemble_weights = [1./nsolvers for _ in range(nsolvers)] 80 | 81 | for i, (wi, solver) in enumerate(zip(ensemble_weights, solvers)): 82 | if i == 0: 83 | z = wi * solver.integrate(rhs_func_ss.func, x = y, t = integration_time_ss) 84 | else: 85 | z += wi * solver.integrate(rhs_func_ss.func, x = y, t = integration_time_ss) 86 | else: 87 | z = solvers[0].integrate(rhs_func_ss.func, x = y, t = integration_time_ss) 88 | 89 | z = z[-1,:,:,:,:] - z0 90 | z = torch.norm(z.reshape((z.shape[0], -1)), dim = 1) 91 | z = torch.mean(z) 92 | 93 | return z 94 | 95 | 96 | class MetaNODE(nn.Module): 97 | 98 | def __init__(self, downsampling_method = 'conv', is_odenet = True, activation_type = 'relu', in_channels = 1): 99 | super(MetaNODE, self).__init__() 100 | 101 | self.is_odenet = is_odenet 102 | 103 | self.downsampling_layers = nn.Sequential(*build_downsampling_layers(downsampling_method, in_channels)) 104 | self.fc_layers = nn.Sequential(*build_fc_layers()) 105 | 106 | if is_odenet: 107 | self.blocks = nn.ModuleList([MetaODEBlock(activation_type)]) 108 | else: 109 | self.blocks = nn.ModuleList([ResBlock(64, 64) for _ in range(6)]) 110 | 111 | 112 | def forward(self, x, solvers=None, solver_options=None, loss_options = None): 113 | self.ss_loss = 0 114 | 115 | x = self.downsampling_layers(x) 116 | 117 | for block in self.blocks: 118 | if self.is_odenet: 119 | x = block(x, solvers, solver_options) 120 | 121 | if 
(loss_options is not None) and loss_options.ss_loss: 122 | z = block.ss_loss(x, solvers, solver_options) 123 | self.ss_loss += z 124 | else: 125 | x = block(x) 126 | 127 | x = self.fc_layers(x) 128 | return x 129 | 130 | def get_ss_loss(self): 131 | return self.ss_loss 132 | 133 | 134 | class ODEfunc(nn.Module): 135 | 136 | def __init__(self, dim, activation_type = 'relu'): 137 | super(ODEfunc, self).__init__() 138 | 139 | if activation_type == 'tanh': 140 | activation = nn.Tanh() 141 | elif activation_type == 'softplus': 142 | activation = nn.Softplus() 143 | elif activation_type == 'softsign': 144 | activation = nn.Softsign() 145 | elif activation_type == 'relu': 146 | activation = nn.ReLU() 147 | else: 148 | raise NotImplementedError('{} activation is not implemented'.format(activation_type)) 149 | 150 | self.norm1 = norm(dim) 151 | self.relu = nn.ReLU(inplace=True) 152 | self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1) 153 | self.norm2 = norm(dim) 154 | self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1) 155 | self.norm3 = norm(dim) 156 | self.nfe = 0 157 | 158 | def forward(self, t, x, ss_loss = False): 159 | self.nfe += 1 160 | out = self.norm1(x) 161 | out = self.relu(out) 162 | out = self.conv1(t, out) 163 | out = self.norm2(out) 164 | out = self.relu(out) 165 | out = self.conv2(t, out) 166 | out = self.norm3(out) 167 | 168 | if ss_loss: 169 | out = torch.abs(out) 170 | 171 | return out 172 | 173 | def build_downsampling_layers(downsampling_method = 'conv', in_channels = 1): 174 | if downsampling_method == 'conv': 175 | downsampling_layers = [ 176 | nn.Conv2d(in_channels, 64, 3, 1), 177 | norm(64), 178 | nn.ReLU(inplace=True), 179 | nn.Conv2d(64, 64, 4, 2, 1), 180 | norm(64), 181 | nn.ReLU(inplace=True), 182 | nn.Conv2d(64, 64, 4, 2, 1), 183 | ] 184 | elif downsampling_method == 'res': 185 | downsampling_layers = [ 186 | nn.Conv2d(in_channels, 64, 3, 1), 187 | ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)), 188 | ResBlock(64, 64, stride=2, 
downsample=conv1x1(64, 64, 2)), 189 | ] 190 | return downsampling_layers 191 | 192 | 193 | def build_fc_layers(): 194 | fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)] 195 | return fc_layers 196 | 197 | 198 | def conv3x3(in_planes, out_planes, stride=1): 199 | """3x3 convolution with padding""" 200 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 201 | 202 | 203 | def conv1x1(in_planes, out_planes, stride=1): 204 | """1x1 convolution""" 205 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 206 | 207 | 208 | def norm(dim): 209 | return nn.GroupNorm(min(32, dim), dim) 210 | 211 | 212 | class ResBlock(nn.Module): 213 | expansion = 1 214 | 215 | def __init__(self, inplanes, planes, stride=1, downsample=None): 216 | super(ResBlock, self).__init__() 217 | self.norm1 = norm(inplanes) 218 | self.relu = nn.ReLU(inplace=True) 219 | self.downsample = downsample 220 | self.conv1 = conv3x3(inplanes, planes, stride) 221 | self.norm2 = norm(planes) 222 | self.conv2 = conv3x3(planes, planes) 223 | 224 | def forward(self, x): 225 | shortcut = x 226 | 227 | out = self.relu(self.norm1(x)) 228 | 229 | if self.downsample is not None: 230 | shortcut = self.downsample(out) 231 | 232 | out = self.conv1(out) 233 | out = self.norm2(out) 234 | out = self.relu(out) 235 | out = self.conv2(out) 236 | 237 | return out + shortcut 238 | 239 | 240 | class ConcatConv2d(nn.Module): 241 | 242 | def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False): 243 | super(ConcatConv2d, self).__init__() 244 | module = nn.ConvTranspose2d if transpose else nn.Conv2d 245 | self._layer = module( 246 | dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups, 247 | bias=bias 248 | ) 249 | 250 | def forward(self, t, x): 251 | tt = torch.ones_like(x[:, :1, :, :]) * t 252 | ttx 
= torch.cat([tt, x], 1) 253 | return self._layer(ttx) 254 | 255 | 256 | class Flatten(nn.Module): 257 | 258 | def __init__(self): 259 | super(Flatten, self).__init__() 260 | 261 | def forward(self, x): 262 | shape = torch.prod(torch.tensor(x.shape[1:])).item() 263 | return x.view(-1, shape) 264 | 265 | 266 | -------------------------------------------------------------------------------- /sopa/src/models/odenet_mnist/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def one_hot(x, K): 5 | return np.array(x[:, None] == np.arange(K)[None, :], dtype=int) 6 | 7 | 8 | ######### device inside !!! 9 | def accuracy(model, dataset_loader, device, solver = None, solver_options = None): 10 | total_correct = 0 11 | for x, y in dataset_loader: 12 | x = x.to(device) 13 | y = one_hot(np.array(y.numpy()), 10) 14 | 15 | target_class = np.argmax(y, axis=1) 16 | 17 | if solver is not None: 18 | out = model(x, solver, solver_options).cpu().detach().numpy() 19 | else: 20 | out = model(x).cpu().detach().numpy() 21 | 22 | predicted_class = np.argmax(out, axis=1) 23 | total_correct += np.sum(predicted_class == target_class) 24 | return total_correct / len(dataset_loader.dataset) 25 | 26 | 27 | def sn_test(model, test_loader, device, solvers, solver_options, nsteps_grid): 28 | model.eval() 29 | for solver in solvers: 30 | solver.freeze_params() 31 | 32 | accs = [] 33 | for nsteps in nsteps_grid: 34 | for solver in solvers: 35 | solver.grid_constructor = lambda t: torch.linspace(t[0], t[-1], nsteps + 1) 36 | 37 | with torch.no_grad(): 38 | acc = accuracy(model, test_loader, device, solvers, solver_options) 39 | accs.append(acc) 40 | 41 | return accs -------------------------------------------------------------------------------- /sopa/src/models/odenet_mnist/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from argparse 
import Namespace 4 | import logging 5 | import time 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import copy 11 | 12 | from decimal import Decimal 13 | 14 | import sys 15 | sys.path.append('/workspace/home/jgusak/neural-ode-sopa/') 16 | 17 | from sopa.src.solvers.utils import create_solver, noise_params 18 | from sopa.src.models.utils import fix_seeds, RunningAverageMeter 19 | from sopa.src.models.odenet_mnist.layers import MetaNODE 20 | from sopa.src.models.odenet_mnist.utils import makedirs, get_logger, count_parameters, learning_rate_with_decay 21 | from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator 22 | from sopa.src.models.odenet_mnist.metrics import accuracy 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet') 27 | parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res']) 28 | parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu') 29 | parser.add_argument('--in_channels', type=int, default=1) 30 | 31 | parser.add_argument('--solvers', 32 | type=lambda s:[tuple(map(lambda iparam: str(iparam[1]) if iparam[0] <= 1 else ( 33 | int(iparam[1]) if iparam[0]==2 else ( 34 | float(iparam[1]) if iparam[0] == 3 else Decimal(iparam[1]))), 35 | enumerate(item.split(',')))) for item in s.strip().split(';')], 36 | default = None, 37 | help='Each solver is represented with (method,parameterization,n_steps,step_size,u0,v0) \n' + 38 | 'If the solver has only one parameter u0, set v0 to -1; \n' + 39 | 'n_steps and step_size are exclusive parameters, only one of them can be != -1, \n' 40 | 'If n_steps = step_size = -1, automatic time grid_constructor is used \n;' 41 | 'For example, --solvers rk4,uv,2,-1,0.3,0.6;rk3,uv,-1,0.1,0.4,0.6;rk2,u,4,-1,0.3,-1') 42 | 43 | parser.add_argument('--solver_mode', type=str, 
choices=['switch', 'ensemble', 'standalone'], default='standalone') 44 | parser.add_argument('--switch_probs',type=lambda s: [float(item) for item in s.split(',')], default=None, 45 | help="--switch_probs 0.8,0.1,0.1") 46 | parser.add_argument('--ensemble_weights', type=lambda s: [float(item) for item in s.split(',')], default=None, 47 | help="ensemble_weights 0.6,0.2,0.2") 48 | parser.add_argument('--ensemble_prob', type=float, default=1.) 49 | 50 | parser.add_argument('--noise_type', type=str, choices=['cauchy', 'normal'], default=None) 51 | parser.add_argument('--noise_sigma', type=float, default = 0.001) 52 | parser.add_argument('--noise_prob', type=float, default = 0.) 53 | parser.add_argument('--minimize_rk2_error', type=eval, default=False, choices=[True, False]) 54 | 55 | parser.add_argument('--nepochs_nn', type=int, default=160) 56 | parser.add_argument('--nepochs_solver', type=int, default=0) 57 | parser.add_argument('--nstages', type=int, default=1) 58 | 59 | parser.add_argument('--ss_loss', type=eval, default=False, choices=[True, False]) 60 | parser.add_argument('--ss_loss_reg', type=float, default=0.1) 61 | 62 | parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False]) 63 | parser.add_argument('--lr', type=float, default=0.1) 64 | parser.add_argument('--weight_decay', type=float, default=0.0005) 65 | parser.add_argument('--batch_size', type=int, default=128) 66 | parser.add_argument('--test_batch_size', type=int, default=1000) 67 | 68 | parser.add_argument('--lr_uv', type=float, default=1e-3) 69 | parser.add_argument('--torch_dtype', type=str, default='float32') 70 | 71 | parser.add_argument('--data_root', type=str, default='/workspace/home/jgusak/neural-ode-sopa/.data/mnist') 72 | parser.add_argument('--save', type=str, default='./experiment2') 73 | parser.add_argument('--debug', action='store_true') 74 | parser.add_argument('--gpu', type=int, default=0) 75 | 76 | parser.add_argument('--seed', type=int, default=502) 77 | 78 | 
79 | args = parser.parse_args() 80 | 81 | 82 | if args.torch_dtype == 'float64': 83 | dtype = torch.float64 84 | elif args.torch_dtype == 'float32': 85 | dtype = torch.float32 86 | else: 87 | raise ValueError('torch_type should be either float64 or float32') 88 | 89 | 90 | if __name__=="__main__": 91 | print(args.solvers) 92 | fix_seeds(args.seed) 93 | 94 | if args.torch_dtype == 'float64': 95 | dtype = torch.float64 96 | elif args.torch_dtype == 'float32': 97 | dtype = torch.float32 98 | else: 99 | raise ValueError('torch_type should be either float64 or float32') 100 | 101 | makedirs(args.save) 102 | logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=None) 103 | logger.info(args) 104 | 105 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 106 | 107 | ########### Create train / val solvers 108 | train_solvers = [create_solver(*solver_params, dtype =dtype, device = device) for solver_params in args.solvers] 109 | for solver in train_solvers: 110 | print(solver.u) 111 | solver.freeze_params() 112 | 113 | train_solver_options = Namespace(**{key:vars(args)[key] for key in ['solver_mode','switch_probs', 114 | 'ensemble_prob','ensemble_weights']}) 115 | val_solver_options = Namespace(solver_mode = 'standalone') 116 | 117 | ########## Build the model 118 | is_odenet = args.network == 'odenet' 119 | 120 | model = MetaNODE(downsampling_method = args.downsampling_method, is_odenet=is_odenet, 121 | activation_type=args.activation, in_channels = args.in_channels) 122 | model.to(device) 123 | 124 | logger.info(model) 125 | 126 | ########### Create data loaders 127 | train_loader, test_loader, train_eval_loader = get_mnist_loaders( 128 | args.data_aug, args.batch_size, args.test_batch_size, data_root = args.data_root) 129 | 130 | data_gen = inf_generator(train_loader) 131 | batches_per_epoch = len(train_loader) 132 | 133 | ########### Creare criterion and optimizer 134 | criterion = nn.CrossEntropyLoss().to(device) 135 | 
loss_options = Namespace(ss_loss=args.ss_loss) 136 | 137 | lr_fn = learning_rate_with_decay( 138 | args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140], 139 | decay_rates=[1, 0.1, 0.01, 0.001], lr0 = args.lr) 140 | 141 | optimizer = optim.RMSprop([{"params" : model.parameters(), 'lr' : args.lr},], lr=args.lr, weight_decay=args.weight_decay) 142 | 143 | ########### Train the model 144 | nsolvers = len(train_solvers) 145 | best_acc = [0]*nsolvers 146 | 147 | batch_time_meter = RunningAverageMeter() 148 | f_nfe_meter = RunningAverageMeter() 149 | b_nfe_meter = RunningAverageMeter() 150 | end = time.time() 151 | 152 | 153 | for itr in range(args.nepochs_nn * batches_per_epoch): 154 | 155 | for param_group in optimizer.param_groups: 156 | param_group['lr'] = lr_fn(itr) 157 | 158 | optimizer.zero_grad() 159 | x, y = data_gen.__next__() 160 | x = x.to(device) 161 | y = y.to(device) 162 | 163 | ##### Noise params 164 | if args.noise_type is not None: 165 | for i in range(len(train_solvers)): 166 | train_solvers[i].u, train_solvers[i].v = noise_params(train_solvers[i].u0, 167 | train_solvers[i].v0, 168 | std = args.noise_sigma, 169 | bernoulli_p = args.noise_prob, 170 | noise_type = args.noise_type) 171 | train_solvers[i].build_ButcherTableau() 172 | 173 | ##### Forward pass 174 | if is_odenet: 175 | logits = model(x, train_solvers, train_solver_options, loss_options) 176 | else: 177 | logits = model(x) 178 | 179 | loss = criterion(logits, y) 180 | if (loss_options is not None) and loss_options.ss_loss: 181 | loss += args.ss_loss_reg * model.get_ss_loss() 182 | 183 | ##### Compute NFE-forward 184 | if is_odenet: 185 | nfe_forward = 0 186 | for i in range(len(model.blocks)): 187 | nfe_forward += model.blocks[i].rhs_func.nfe 188 | model.blocks[i].rhs_func.nfe = 0 189 | 190 | loss.backward() 191 | optimizer.step() 192 | 193 | ##### Compute NFE-backward 194 | if is_odenet: 195 | nfe_backward = 0 196 | for i in 
range(len(model.blocks)): 197 | nfe_backward += model.blocks[i].rhs_func.nfe 198 | model.blocks[i].rhs_func.nfe = 0 199 | 200 | ##### Denoise params 201 | if args.noise_type is not None: 202 | for i in range(len(train_solvers)): 203 | train_solvers[i].u, train_solvers[i].v = train_solvers[i].u0, train_solvers[i].v0 204 | train_solvers[i].build_ButcherTableau() 205 | 206 | batch_time_meter.update(time.time() - end) 207 | if is_odenet: 208 | f_nfe_meter.update(nfe_forward) 209 | b_nfe_meter.update(nfe_backward) 210 | end = time.time() 211 | 212 | if itr % batches_per_epoch == 0: 213 | with torch.no_grad(): 214 | train_acc = [0]*nsolvers 215 | val_acc = [0]*nsolvers 216 | 217 | for solver_id, val_solver in enumerate(train_solvers): 218 | train_acc_id = accuracy(model, train_eval_loader, device, [val_solver], val_solver_options) 219 | val_acc_id = accuracy(model, test_loader, device, [val_solver], val_solver_options) 220 | 221 | train_acc[solver_id] = train_acc_id 222 | val_acc[solver_id] = val_acc_id 223 | 224 | if val_acc_id > best_acc[solver_id]: 225 | torch.save({'state_dict': model.state_dict(), 'args': args, 'solver_id':solver_id}, 226 | os.path.join(args.save,'model_best_{}.pth'.format(solver_id))) 227 | best_acc[solver_id] = val_acc_id 228 | del train_acc_id, val_acc_id 229 | 230 | if is_odenet: 231 | for i in range(len(model.blocks)): 232 | model.blocks[i].rhs_func.nfe = 0 233 | 234 | logger.info( 235 | "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | " 236 | "TrainAcc {} | TestAcc {} ".format( 237 | itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg, 238 | b_nfe_meter.avg, train_acc, val_acc)) 239 | 240 | 241 | ## How to run 242 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=3 python3 runner.py 243 | # --data_root "/workspace/raid/data/datasets" 244 | # --save "./experiment1" 245 | # --network odenet 246 | # --downsampling_method conv 247 | # --solvers rk2,u,-1,-1,0.6,-1;rk2,u,-1,-1,0.5,-1 248 | # 
--solver_mode switch 249 | # --switch_probs 0.8,0.2 250 | # --nepochs_nn 160 251 | # --nepochs_solver 0 252 | # --nstages 1 253 | # --lr 0.1 254 | # --seed 502 255 | 256 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=4 python3 runner.py --data_root /workspace/raid/data/datasets/mnist --save ./experiment2_new --network 'odenet' --downsampling-method 'conv' --solvers "rk2,u,4,-1,0.5,-1;rk2,u,4,-1,1.0,-1" --solver_mode "switch" --activation "relu" --seed 702 --nepochs_nn 160 --nepochs_solver 0 --nstages 1 --lr 0.01 --ss_loss True 257 | 258 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=4 python3 runner.py --data_root /workspace/raid/data/datasets/mnist --save ./experiment2_new --network 'odenet' --downsampling-method 'conv' --solvers "rk2,u,1,-1,0.66666666,-1" --solver_mode standalone --activation relu --seed 702 --nepochs_nn 160 --nepochs_solver 0 --nstages 1 --lr 0.1 --noise_type 'cauchy' --noise_sigma 0.001 --noise_prob 1. 259 | -------------------------------------------------------------------------------- /sopa/src/models/odenet_mnist/runner_new.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from argparse import Namespace 4 | import logging 5 | import time 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import copy 11 | 12 | from decimal import Decimal 13 | import wandb 14 | import sys 15 | 16 | sys.path.append('../../../../') 17 | 18 | from sopa.src.solvers.utils import create_solver 19 | from sopa.src.models.utils import fix_seeds, RunningAverageMeter 20 | from sopa.src.models.odenet_mnist.layers import MetaNODE 21 | from sopa.src.models.odenet_mnist.utils import makedirs, get_logger, count_parameters, learning_rate_with_decay 22 | from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator 23 | # from sopa.src.models.odenet_mnist.metrics import accuracy 24 | from 
sopa.src.models.odenet_mnist.train_validate import train, validate 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet') 28 | parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res']) 29 | parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu') 30 | parser.add_argument('--in_channels', type=int, default=1) 31 | 32 | parser.add_argument('--solvers', 33 | type=lambda s: [tuple(map(lambda iparam: str(iparam[1]) if iparam[0] <= 1 else ( 34 | int(iparam[1]) if iparam[0] == 2 else ( 35 | float(iparam[1]) if iparam[0] == 3 else Decimal(iparam[1]))), 36 | enumerate(item.split(',')))) for item in s.strip().split(';')], 37 | default=None, 38 | help='Each solver is represented with (method,parameterization,n_steps,step_size,u0,v0) \n' + 39 | 'If the solver has only one parameter u0, set v0 to -1; \n' + 40 | 'n_steps and step_size are exclusive parameters, only one of them can be != -1, \n' 41 | 'If n_steps = step_size = -1, automatic time grid_constructor is used \n;' 42 | 'For example, --solvers rk4,uv,2,-1,0.3,0.6;rk3,uv,-1,0.1,0.4,0.6;rk2,u,4,-1,0.3,-1') 43 | 44 | parser.add_argument('--solver_mode', type=str, choices=['switch', 'ensemble', 'standalone'], default='standalone') 45 | parser.add_argument('--val_solver_modes', 46 | type=lambda s: s.strip().split(','), 47 | default=['standalone', 'ensemble', 'switch'], 48 | help='Solver modes to use for validation step') 49 | 50 | parser.add_argument('--switch_probs', type=lambda s: [float(item) for item in s.split(',')], default=None, 51 | help="--switch_probs 0.8,0.1,0.1") 52 | parser.add_argument('--ensemble_weights', type=lambda s: [float(item) for item in s.split(',')], default=None, 53 | help="ensemble_weights 0.6,0.2,0.2") 54 | parser.add_argument('--ensemble_prob', type=float, default=1.) 
55 | 56 | parser.add_argument('--noise_type', type=str, choices=['cauchy', 'normal'], default=None) 57 | parser.add_argument('--noise_sigma', type=float, default=0.001) 58 | parser.add_argument('--noise_prob', type=float, default=0.) 59 | parser.add_argument('--minimize_rk2_error', type=eval, default=False, choices=[True, False]) 60 | 61 | parser.add_argument('--nepochs_nn', type=int, default=160) 62 | parser.add_argument('--nepochs_solver', type=int, default=0) 63 | parser.add_argument('--nstages', type=int, default=1) 64 | 65 | parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False]) 66 | parser.add_argument('--lr', type=float, default=0.01) 67 | parser.add_argument('--weight_decay', type=float, default=0.0005) 68 | parser.add_argument('--batch_size', type=int, default=128) 69 | parser.add_argument('--test_batch_size', type=int, default=1000) 70 | 71 | parser.add_argument('--lr_uv', type=float, default=1e-3) 72 | parser.add_argument('--torch_dtype', type=str, default='float32') 73 | parser.add_argument('--wandb_name', type=str, default='mnist_tmp') 74 | 75 | parser.add_argument('--data_root', type=str, default='/gpfs/gpfs0/t.daulbaev/data/MNIST') 76 | parser.add_argument('--save', type=str, default='../../../rk2_tmp') 77 | parser.add_argument('--debug', action='store_true') 78 | parser.add_argument('--gpu', type=int, default=0) 79 | 80 | parser.add_argument('--seed', type=int, default=502) 81 | # Noise and adversarial attacks parameters: 82 | parser.add_argument('--data_noise_std', type=float, default=0., 83 | help='Applies Norm(0, std) gaussian noise to each batch entry') 84 | parser.add_argument('--eps_adv_training', type=float, default=0.3, 85 | help='Epsilon for adversarial training') 86 | parser.add_argument( 87 | "--adv_training_mode", 88 | default="clean", 89 | choices=["clean", "fgsm", "at"], #, "at_ls", "av", "fs", "nce", "nce_moco", "moco", "av_extra", "meta"], 90 | help='''Adverarial training method/mode, by default there is no 
adversarial training (clean). 91 | For further details see MegaAdversarial/train in this repository. 92 | ''' 93 | ) 94 | parser.add_argument('--use_wandb', type=eval, default=True, choices=[True, False]) 95 | parser.add_argument('--use_logger', type=eval, default=False, choices=[True, False]) 96 | parser.add_argument('--ss_loss', type=eval, default=False, choices=[True, False]) 97 | parser.add_argument('--ss_loss_reg', type=float, default=0.1) 98 | parser.add_argument('--timestamp', type=int, default=int(1e6 * time.time())) 99 | 100 | args = parser.parse_args() 101 | 102 | sys.path.append('../../') 103 | 104 | if args.use_logger: 105 | makedirs(args.save) 106 | logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=None) 107 | logger.info(args) 108 | if args.use_wandb: 109 | wandb.init(project=args.wandb_name, entity="sopa_node") 110 | makedirs(args.save) 111 | wandb.config.update(args) 112 | os.makedirs(os.path.join(args.save, str(args.timestamp))) 113 | 114 | if args.torch_dtype == 'float64': 115 | dtype = torch.float64 116 | elif args.torch_dtype == 'float32': 117 | dtype = torch.float32 118 | else: 119 | raise ValueError('torch_type should be either float64 or float32') 120 | 121 | if __name__ == "__main__": 122 | print(args.solvers) 123 | fix_seeds(args.seed) 124 | 125 | if args.torch_dtype == 'float64': 126 | dtype = torch.float64 127 | elif args.torch_dtype == 'float32': 128 | dtype = torch.float32 129 | else: 130 | raise ValueError('torch_type should be either float64 or float32') 131 | 132 | makedirs(args.save) 133 | logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=None) 134 | logger.info(args) 135 | 136 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 137 | 138 | ########### Create train / val solvers 139 | train_solvers = [create_solver(*solver_params, dtype=dtype, device=device) for solver_params in args.solvers] 140 | for solver in train_solvers: 141 | solver.freeze_params() 
142 | 143 | train_solver_options = Namespace(**{key: vars(args)[key] for key in ['solver_mode', 'switch_probs', 144 | 'ensemble_prob', 'ensemble_weights']}) 145 | 146 | val_solver_modes = args.val_solver_modes 147 | 148 | ########## Build the model 149 | is_odenet = args.network == 'odenet' 150 | 151 | model = MetaNODE(downsampling_method=args.downsampling_method, 152 | is_odenet=is_odenet, 153 | activation_type=args.activation, 154 | in_channels=args.in_channels) 155 | model.to(device) 156 | if args.use_wandb: 157 | wandb.watch(model) 158 | if args.use_logger: 159 | logger.info(model) 160 | 161 | ########### Create data loaders 162 | train_loader, test_loader, train_eval_loader = get_mnist_loaders(args.data_aug, 163 | args.batch_size, 164 | args.test_batch_size, 165 | data_root=args.data_root) 166 | 167 | data_gen = inf_generator(train_loader) 168 | batches_per_epoch = len(train_loader) 169 | 170 | ########### Creare criterion and optimizer 171 | criterion = nn.CrossEntropyLoss().to(device) 172 | loss_options = Namespace(ss_loss=args.ss_loss) 173 | 174 | lr_fn = learning_rate_with_decay( 175 | args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140], 176 | decay_rates=[1, 0.1, 0.01, 0.001], lr0=args.lr) 177 | 178 | optimizer = optim.RMSprop([{"params": model.parameters(), 'lr': args.lr}, ], lr=args.lr, 179 | weight_decay=args.weight_decay) 180 | 181 | ########### Train the model 182 | best_acc = {'standalone': [0] * len(train_solvers), 183 | 'ensemble': 0, 184 | 'switch': 0} 185 | 186 | batch_time_meter = RunningAverageMeter() 187 | f_nfe_meter = RunningAverageMeter() 188 | b_nfe_meter = RunningAverageMeter() 189 | 190 | for itr in range(args.nepochs_nn * batches_per_epoch): 191 | 192 | for param_group in optimizer.param_groups: 193 | param_group['lr'] = lr_fn(itr) 194 | 195 | if itr % batches_per_epoch != 0: 196 | train(itr, 197 | model, 198 | data_gen, 199 | solvers=train_solvers, 200 | 
solver_options=train_solver_options, 201 | criterion=criterion, 202 | optimizer=optimizer, 203 | batch_time_meter=batch_time_meter, 204 | f_nfe_meter=f_nfe_meter, 205 | b_nfe_meter=b_nfe_meter, 206 | device=device, 207 | dtype=dtype, 208 | is_odenet=is_odenet, 209 | args=args, 210 | logger=None, 211 | wandb_logger=None) 212 | else: 213 | train(itr, 214 | model, 215 | data_gen, 216 | solvers=train_solvers, 217 | solver_options=train_solver_options, 218 | criterion=criterion, 219 | optimizer=optimizer, 220 | batch_time_meter=batch_time_meter, 221 | f_nfe_meter=f_nfe_meter, 222 | b_nfe_meter=b_nfe_meter, 223 | device=device, 224 | dtype=dtype, 225 | is_odenet=is_odenet, 226 | args=args, 227 | logger=logger, 228 | wandb_logger=wandb) 229 | 230 | best_acc = validate(best_acc, 231 | itr, 232 | model, 233 | train_eval_loader, 234 | test_loader, 235 | batches_per_epoch, 236 | solvers=train_solvers, 237 | val_solver_modes=val_solver_modes, 238 | batch_time_meter=batch_time_meter, 239 | f_nfe_meter=f_nfe_meter, 240 | b_nfe_meter=b_nfe_meter, 241 | device=device, 242 | dtype=dtype, 243 | args=args, 244 | logger=logger, 245 | wandb_logger=wandb) 246 | 247 | # # How to run 248 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=3 python3 runner.py 249 | # --data_root "/workspace/raid/data/datasets" 250 | # --save "./experiment1" 251 | # --network odenet 252 | # --downsampling_method conv 253 | # --solvers rk2,u,-1,-1,0.6,-1;rk2,u,-1,-1,0.5,-1 254 | # --solver_mode switch 255 | # --switch_probs 0.8,0.2 256 | # --nepochs_nn 160 257 | # --nepochs_solver 0 258 | # --nstages 1 259 | # --lr 0.1 260 | # --seed 502 261 | 262 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=1 python3 runner_new.py --data_root /workspace/raid/data/datasets/mnist --save ./experiment2_new --network 'odenet' --downsampling-method 'conv' --solvers "rk2,u,1,-1,0.5,-1;rk2,u,1,-1,1.0,-1" --solver_mode "switch" --activation "relu" --seed 702 --nepochs_nn 160 --nepochs_solver 0 --nstages 1 --lr 0.1 263 | 
# CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=4 python3 runner_new.py --data_root /workspace/raid/data/datasets/mnist --save ./experiment2_new --network 'odenet' --downsampling-method 'conv' --solvers "rk2,u,1,-1,0.66666666,-1" --solver_mode standalone --activation relu --seed 702 --nepochs_nn 160 --nepochs_solver 0 --nstages 1 --lr 0.1 --noise_type 'cauchy' --noise_sigma 0.001 --noise_prob 1.

# Recompute MNIST

# --------------------------------------------------------------------------------
# /sopa/src/models/odenet_mnist/runner_old.py:
# --------------------------------------------------------------------------------
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torchdiffeq._impl.rk_common import _ButcherTableau

import sys
# Make the repository root importable so the absolute `sopa.*` imports below resolve.
# This file lives in <root>/sopa/src/models/odenet_mnist/, so the root is four levels up;
# deriving it from __file__ replaces a hard-coded, machine-specific absolute path.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', '..')))
# NOTE(review): sopa/src/solvers/ in this repository has no rk_parametric_old module;
# this legacy runner needs it restored (or ported to the current solvers) to run.
from sopa.src.solvers.rk_parametric_old import rk_param_tableau, odeint_plus

from sopa.src.models.odenet_mnist.layers import ResBlock, ODEfunc, build_downsampling_layers, build_fc_layers
from sopa.src.models.odenet_mnist.utils import makedirs, get_logger, count_parameters, learning_rate_with_decay
# RunningAverageMeter is defined in sopa/src/models/utils.py, not in odenet_mnist/utils.py.
from sopa.src.models.utils import RunningAverageMeter
from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator
from sopa.src.models.odenet_mnist.metrics import accuracy

parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--method', type=str, choices=['dopri5', 'adams', 'rk4', 'rk4_param', 'rk3_param','euler'], default='rk4')
parser.add_argument('--step_size', type=float, default=None)
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=160)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

parser.add_argument('--lr_uv', type=float, default=1e-3)
parser.add_argument('--parameterization', type=str, choices=['uv', 'u1', 'u2', 'u3'], default='uv')
parser.add_argument('--u0', type=float, default=1/3.)
parser.add_argument('--v0', type=float, default=2/3.)
parser.add_argument('--fix_param', action='store_true')

parser.add_argument('--data_root', type=str, default='/workspace/home/jgusak/neural-ode-sopa/.data/mnist')
parser.add_argument('--save', type=str, default='./experiment1')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)

parser.add_argument('--torch_seed', type=int, default=502)
parser.add_argument('--numpy_seed', type=int, default=502)

# Parse the real command line. `parser.parse_args([])` discarded every flag, so the usage
# example at the bottom of this script (--method, --fix_param, ...) ran with the defaults.
args = parser.parse_args()

# Set random seed for reproducibility
np.random.seed(args.numpy_seed)
torch.manual_seed(args.torch_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if args.adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

# Methods whose Butcher tableau is built by rk_param_tableau from the (u, v) parameters.
_PARAMETRIC_METHODS = ('rk4_param', 'rk3_param')


class ODEBlock(nn.Module):
    """ODE feature block: integrates ``odefunc`` over t in [0, 1] and returns the state at t=1.

    For the parametric methods ('rk4_param', 'rk3_param') the step uses a Butcher tableau
    built by ``rk_param_tableau`` from ``u`` (and ``v`` for the 'uv' parameterization).
    With ``args.fix_param`` these are constant tensors; otherwise the raw values
    ``u_``/``v_`` are trainable ``nn.Parameter``s and ``u``/``v`` are their clamps to
    [eps, 1 - eps]. All other methods go through torchdiffeq's ``odeint``.

    Note: for parametric methods ``__init__`` logs through the module-level ``logger``
    that is created in ``__main__``.
    """

    def __init__(self, odefunc, args, device = 'cpu'):
        super(ODEBlock, self).__init__()

        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

        # Solver settings copied from args. For learned u/v the tableau is rebuilt by the
        # training loop before every step (see __main__).
        self.step_size = args.step_size
        self.tol = args.tol

        self.method = args.method
        self.fix_param = None
        self.parameterization = None
        self.u0, self.v0 = None, None
        self.u_, self.v_ = None, None  # trainable raw parameters (learned case only)
        self.u, self.v = None, None    # values passed to rk_param_tableau

        self.eps = torch.finfo(torch.float32).eps
        self.device = device

        if self.method in _PARAMETRIC_METHODS:
            self.fix_param = args.fix_param
            self.parameterization = args.parameterization
            self.u0 = args.u0

            if self.fix_param:
                self.u = torch.tensor(self.u0)

                if self.parameterization == 'uv':
                    self.v0 = args.v0
                    self.v = torch.tensor(self.v0)

            else:
                # Create the parameters directly on the target device. Calling .to(device) on
                # an nn.Parameter returns a non-leaf copy on CUDA, which the optimizer rejects:
                # https://discuss.pytorch.org/t/tensor-to-device-changes-is-leaf-causing-cant-optimize-a-non-leaf-tensor/37659
                self.u_ = nn.Parameter(torch.tensor(self.u0, device=self.device))
                # Detached snapshot used only for the initial tableau; the training loop
                # re-derives u from u_ before every step.
                self.u = torch.clamp(self.u_, self.eps, 1. - self.eps).detach().requires_grad_(True)

                if self.parameterization == 'uv':
                    self.v0 = args.v0
                    self.v_ = nn.Parameter(torch.tensor(self.v0, device=self.device))
                    self.v = torch.clamp(self.v_, self.eps, 1. - self.eps).detach().requires_grad_(True)

            logger.info('Init | u {} | v {}'.format(self.u.data, (self.v if self.v is None else self.v.data)))

            self.alpha, self.beta, self.c_sol = rk_param_tableau(self.u, self.v, device = self.device,
                                                                 parameterization=self.parameterization,
                                                                 method = self.method)
            self.tableau = _ButcherTableau(alpha = self.alpha,
                                           beta = self.beta,
                                           c_sol = self.c_sol,
                                           c_error = torch.zeros((len(self.c_sol),), device = self.device))

    def forward(self, x):
        """Integrate from t=0 to t=1 starting at ``x`` and return the final state."""
        self.integration_time = self.integration_time.type_as(x)

        if self.method in _PARAMETRIC_METHODS:
            out = odeint_plus(self.odefunc, x, self.integration_time,
                              method=self.method, options = {'tableau':self.tableau, 'step_size':self.step_size})
        else:
            out = odeint(self.odefunc, x, self.integration_time, rtol=self.tol, atol=self.tol,
                         method=self.method, options = {'step_size':self.step_size})

        return out[1]

    @property
    def nfe(self):
        """Number of function evaluations counted by the wrapped ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


if __name__=="__main__":
    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=None)
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    ########### Build model
    is_odenet = args.network == 'odenet'

    downsampling_layers = build_downsampling_layers(args.downsampling_method)
    fc_layers = build_fc_layers()

    feature_layers = [ODEBlock(ODEfunc(64), args, device)] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    ########### Create data loaders
    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size, data_root = args.data_root)

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    ########### Create criterion and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001], lr0 = args.lr)

    # True when the tableau parameters u/v are trained jointly with the network.
    learn_uv = is_odenet and (args.method in _PARAMETRIC_METHODS) and not args.fix_param

    if learn_uv:
        params_uv = []

        for m in model.modules():
            if isinstance(m, ODEBlock):
                if m.u_ is not None:
                    params_uv.append(m.u_)
                if m.v_ is not None:
                    params_uv.append(m.v_)

        # u_/v_ are registered module parameters, so exclude them from the network group:
        # torch.optim rejects a parameter that appears in more than one group.
        uv_ids = {id(p) for p in params_uv}
        net_params = [p for p in model.parameters() if id(p) not in uv_ids]

        optimizer = optim.SGD([{"params" : net_params},
                               {"params" : params_uv, 'lr' : args.lr_uv}], lr=args.lr, momentum = 0.9)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum = 0.9)

    base_lr = lr_fn(0)

    ########### Train the model
    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    for itr in range(args.nepochs * batches_per_epoch):

        lr_net = lr_fn(itr)
        optimizer.param_groups[0]['lr'] = lr_net
        if learn_uv:
            # Apply the same piecewise decay to lr_uv. Overwriting every group with the
            # network LR (as before) made --lr_uv a no-op.
            # NOTE(review): confirm that decaying lr_uv is the intended schedule.
            optimizer.param_groups[1]['lr'] = args.lr_uv * (lr_net / base_lr if base_lr else 1.)

            ### Iterate along the model, find ODEBlock, recalculate the tableau
            for m in model.modules():
                if isinstance(m, ODEBlock):
                    # Re-derive u/v from the trainable leaves without detaching. The detached
                    # copies from __init__ never changed, so u_/v_ could not be learned.
                    # Gradients reach u_/v_ only if rk_param_tableau/odeint_plus are
                    # differentiable in u/v (not visible here).
                    if m.u_ is not None:
                        m.u = torch.clamp(m.u_, m.eps, 1. - m.eps)
                    if m.v_ is not None:
                        m.v = torch.clamp(m.v_, m.eps, 1. - m.eps)
                    m.alpha, m.beta, m.c_sol = rk_param_tableau(m.u, m.v, device = device,
                                                                parameterization=args.parameterization,
                                                                method = args.method)
                    m.tableau = _ButcherTableau(alpha = m.alpha, beta = m.beta, c_sol = m.c_sol,
                                                c_error = torch.zeros((len(m.c_sol),), device = device))

        optimizer.zero_grad()
        x, y = next(data_gen)
        x = x.to(device)
        y = y.to(device)

        logits = model(x)

        loss = criterion(logits, y)

        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        loss.backward()
        optimizer.step()

        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        batch_time_meter.update(time.time() - end)
        if is_odenet:
            f_nfe_meter.update(nfe_forward)
            b_nfe_meter.update(nfe_backward)
        end = time.time()

        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader, device)
                val_acc = accuracy(model, test_loader, device)
                if val_acc > best_acc:
                    torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
                    best_acc = val_acc

                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "TrainAcc {:.6f} | TestAcc {:.6f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc))

#CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=3 python3 runner_old.py --network 'odenet' --data_root "/workspace/raid/data/datasets" --batch_size 128 --test_batch_size 1000 --nepochs 160 --lr 0.1 --save "./experiment1" --method 'euler' --step_size 1.0 --gpu 0 --torch_seed 502 --numpy_seed 502 --fix_param

# --------------------------------------------------------------------------------
# /sopa/src/models/odenet_mnist/train_validate.py:
# --------------------------------------------------------------------------------
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from argparse import Namespace
from .metrics import accuracy
from sopa.src.solvers.utils import noise_params
from MegaAdversarial.src.attacks import (
    Clean,
    PGD,
    FGSM
)

CONFIG_PGD_TRAIN = {"eps": 0.3, "lr": 2.0 / 255, "n_iter": 7}
CONFIG_FGSM_TRAIN = {"eps": 0.3}


def _pop_nfe(model):
    """Return the NFE accumulated by all ``model.blocks`` and reset their counters to 0."""
    total = 0
    for block in model.blocks:
        total += block.rhs_func.nfe
        block.rhs_func.nfe = 0
    return total


def _save_best_checkpoint(model, args, filename, wandb_logger, **extra):
    """Save ``model`` to ``<args.save>/<args.timestamp>/<filename>`` and upload it to wandb.

    The checkpoint dict holds 'state_dict', 'args' and then ``extra`` in the given order.
    The directory is created if missing: runner_new.py creates it only when --use_wandb
    is set, so wandb-less runs previously failed inside ``torch.save``.
    """
    ckpt_dir = os.path.join(args.save, str(args.timestamp))
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, filename)
    torch.save({'state_dict': model.state_dict(), 'args': args, **extra}, ckpt_path)
    if wandb_logger is not None:
        wandb_logger.save(ckpt_path)


def train(itr,
          model,
          data_gen,
          solvers,
          solver_options,
          criterion,
          optimizer,
          batch_time_meter,
          f_nfe_meter,
          b_nfe_meter,
          device = 'cpu',
          dtype = torch.float32,
          is_odenet = True,
          args = None,
          logger = None,
          wandb_logger = None):
    """Run one optimisation step on the next batch drawn from ``data_gen``.

    Steps:
    1. If ``args.noise_type`` is set, perturb every solver's (u, v) via ``noise_params``.
    2. Build the training batch with the attack chosen by ``args.adv_training_mode``
       ("clean", "fgsm", or "at" = PGD).
    3. Add Gaussian input noise if ``args.data_noise_std`` > 1e-12.
    4. Forward pass, then loss = cross-entropy, plus ``args.ss_loss_reg * model.get_ss_loss()``
       when ``args.ss_loss`` is set.
    5. Backward pass and optimizer step.
    6. Restore every solver's (u0, v0).
    7. Update the batch-time and forward/backward NFE meters.

    ``itr``, ``dtype`` and ``logger`` are currently unused.

    Raises:
        ValueError: if ``args.adv_training_mode`` is not "clean", "fgsm" or "at".
    """
    end = time.time()

    x, y = next(data_gen)
    x = x.to(device)
    y = y.to(device)

    ##### Noise params
    if args.noise_type is not None:
        for solver in solvers:
            solver.u, solver.v = noise_params(solver.u0,
                                              solver.v0,
                                              std = args.noise_sigma,
                                              bernoulli_p = args.noise_prob,
                                              noise_type = args.noise_type)
            solver.build_ButcherTableau()

    if args.adv_training_mode == "clean":
        train_attack = Clean(model)
    elif args.adv_training_mode == "fgsm":
        train_attack = FGSM(model, **CONFIG_FGSM_TRAIN)
    elif args.adv_training_mode == "at":
        train_attack = PGD(model, **CONFIG_PGD_TRAIN)
    else:
        raise ValueError("Attack type not understood.")
    x, y = train_attack(x, y, {"solvers": solvers, "solver_options": solver_options})

    # Clear gradients only after the attack. If the attack back-propagates through the
    # model to obtain input gradients, whatever it leaves in the parameters' .grad must
    # not leak into this training update.
    optimizer.zero_grad()

    # Add noise:
    if args.data_noise_std > 1e-12:
        with torch.no_grad():
            x = x + args.data_noise_std * torch.randn_like(x)

    ##### Forward pass
    if is_odenet:
        logits = model(x, solvers, solver_options, Namespace(ss_loss=args.ss_loss))
    else:
        logits = model(x)

    xentropy = criterion(logits, y)
    if args.ss_loss:
        ss_loss = model.get_ss_loss()
        loss = xentropy + args.ss_loss_reg * ss_loss
    else:
        ss_loss = 0.
        loss = xentropy
    if wandb_logger is not None:
        wandb_logger.log({"xentropy": xentropy.item(),
                          # float() turns a 0-dim tensor into a plain number, like .item() above
                          "ss_loss": float(ss_loss),
                          "loss": loss.item(),
                          "log_func": "train"})
    # TODO: also report these losses through `logger` (placeholder left in the original).

    ##### Compute NFE-forward
    if is_odenet:
        nfe_forward = _pop_nfe(model)

    loss.backward()
    optimizer.step()

    ##### Compute NFE-backward
    if is_odenet:
        nfe_backward = _pop_nfe(model)

    ##### Denoise params
    if args.noise_type is not None:
        for solver in solvers:
            solver.u, solver.v = solver.u0, solver.v0
            solver.build_ButcherTableau()

    batch_time_meter.update(time.time() - end)
    if is_odenet:
        f_nfe_meter.update(nfe_forward)
        b_nfe_meter.update(nfe_backward)


def validate_standalone(best_acc,
                        itr,
                        model,
                        train_eval_loader,
                        test_loader,
                        batches_per_epoch,
                        solvers,
                        solver_options,
                        batch_time_meter,
                        f_nfe_meter,
                        b_nfe_meter,
                        device = 'cpu',
                        dtype = torch.float32,
                        args = None,
                        logger = None,
                        wandb_logger=None):
    """Evaluate every solver on its own and checkpoint the per-solver best models.

    For each ``solver_id``, train/test accuracy is computed with ``[solver]`` only. When
    the test accuracy beats ``best_acc[solver_id]``, the list entry is updated in place
    and ``model_best_<solver_id>.pth`` is saved. Per-solver results go to ``logger`` and
    ``wandb_logger`` when they are given. Afterwards the NFE counters of
    ``model.blocks`` are reset.

    Returns:
        list: ``best_acc`` (the same list object, updated).
    """
    with torch.no_grad():
        for solver_id, solver in enumerate(solvers):

            train_acc_id = accuracy(model, train_eval_loader, device, [solver], solver_options)
            val_acc_id = accuracy(model, test_loader, device, [solver], solver_options)

            if val_acc_id > best_acc[solver_id]:
                best_acc[solver_id] = val_acc_id
                _save_best_checkpoint(model, args, 'model_best_{}.pth'.format(solver_id), wandb_logger,
                                      solver_id=solver_id,
                                      val_solver_mode=solver_options.solver_mode,
                                      acc=val_acc_id)

            if logger is not None:
                logger.info("Epoch {:04d} | SolverMode {} | SolverId {} | "
                            "TrainAcc {:.10f} | TestAcc {:.10f} | BestAcc {:.10f}".format(
                                itr // batches_per_epoch, solver_options.solver_mode, solver_id,
                                train_acc_id, val_acc_id, best_acc[solver_id]))
            if wandb_logger is not None:
                wandb_logger.log({
                    "epoch": itr // batches_per_epoch,
                    "solver_mode": solver_options.solver_mode,
                    "solver_id": solver_id,
                    "train_acc": train_acc_id,
                    "test_acc": val_acc_id,
                    "best_acc": best_acc[solver_id],
                    "log_func": "validate_standalone"
                })

    _pop_nfe(model)  # discard the NFE counted during evaluation

    return best_acc


def validate_ensemble_switch(best_acc,
                             itr,
                             model,
                             train_eval_loader,
                             test_loader,
                             batches_per_epoch,
                             solvers,
                             solver_options,
                             batch_time_meter,
                             f_nfe_meter,
                             b_nfe_meter,
                             device = 'cpu',
                             dtype = torch.float32,
                             args = None,
                             logger=None,
                             wandb_logger=None):
    """Evaluate all solvers jointly (ensemble/switch mode given by ``solver_options``).

    When the test accuracy beats ``best_acc``, ``model_best.pth`` is saved. Results go
    to ``logger`` and ``wandb_logger`` when they are given. Afterwards the NFE counters
    of ``model.blocks`` are reset.

    Returns:
        The updated best test accuracy.
    """
    with torch.no_grad():

        train_acc = accuracy(model, train_eval_loader, device, solvers, solver_options)
        val_acc = accuracy(model, test_loader, device, solvers, solver_options)

        if val_acc > best_acc:
            best_acc = val_acc
            _save_best_checkpoint(model, args, 'model_best.pth', wandb_logger,
                                  solver_id=None,
                                  val_solver_mode=solver_options.solver_mode,
                                  acc=val_acc)

        if logger is not None:
            logger.info("Epoch {:04d} | SolverMode {} | SolverId {} | "
                        "TrainAcc {:.10f} | TestAcc {:.10f} | BestAcc {:.10f}".format(
                            itr // batches_per_epoch, solver_options.solver_mode, None,
                            train_acc, val_acc, best_acc))
        if wandb_logger is not None:
            wandb_logger.log({
                "epoch": itr // batches_per_epoch,
                "solver_mode": solver_options.solver_mode,
                "solver_id": None,
                "train_acc": train_acc,
                "test_acc": val_acc,
                "best_acc": best_acc,
                "log_func": "validate_ensemble_switch"
            })

    _pop_nfe(model)  # discard the NFE counted during evaluation

    return best_acc


def validate(best_acc,
             itr,
             model,
             train_eval_loader,
             test_loader,
             batches_per_epoch,
             solvers,
             val_solver_modes,
             batch_time_meter,
             f_nfe_meter,
             b_nfe_meter,
             device = 'cpu',
             dtype = torch.float32,
             args = None,
             logger = None,
             wandb_logger=None):
    """Validate under every mode in ``val_solver_modes``, then log timing and NFE averages.

    ``best_acc`` maps 'standalone' to a per-solver list and 'ensemble'/'switch' to scalars.
    The entry of each validated mode is updated in place. Modes other than these three
    are skipped.

    Returns:
        dict: ``best_acc`` (the same dict object, updated).
    """
    # Arguments shared by every per-mode validation call.
    common_kwargs = dict(solvers = solvers,
                         batch_time_meter = batch_time_meter,
                         f_nfe_meter = f_nfe_meter,
                         b_nfe_meter = b_nfe_meter,
                         device = device,
                         dtype = dtype,
                         args = args,
                         logger = logger,
                         wandb_logger = wandb_logger)

    for solver_mode in val_solver_modes:

        if solver_mode == 'standalone':
            validate_fn = validate_standalone
            val_solver_options = Namespace(solver_mode = 'standalone')
        elif solver_mode == 'ensemble':
            validate_fn = validate_ensemble_switch
            val_solver_options = Namespace(solver_mode = 'ensemble',
                                           ensemble_weights = args.ensemble_weights,
                                           ensemble_prob = args.ensemble_prob)
        elif solver_mode == 'switch':
            validate_fn = validate_ensemble_switch
            val_solver_options = Namespace(solver_mode = 'switch', switch_probs = args.switch_probs)
        else:
            continue

        best_acc[solver_mode] = validate_fn(best_acc[solver_mode],
                                            itr,
                                            model,
                                            train_eval_loader,
                                            test_loader,
                                            batches_per_epoch,
                                            solver_options = val_solver_options,
                                            **common_kwargs)

    if logger is not None:
        logger.info("Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f}".format(
            itr // batches_per_epoch,
            batch_time_meter.val, batch_time_meter.avg,
            f_nfe_meter.avg, b_nfe_meter.avg))
    if wandb_logger is not None:
        wandb_logger.log({
            "epoch": itr // batches_per_epoch,
            "batch_time_val": batch_time_meter.val,
            "nfe": f_nfe_meter.avg,
            "nbe": b_nfe_meter.avg,
            "log_func": "validate"
        })
    return best_acc

# --------------------------------------------------------------------------------
# /sopa/src/models/odenet_mnist/utils.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import logging 4 | 5 | 6 | ######### args.lr inside !!! 7 | def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates, lr0): 8 | initial_learning_rate = lr0 * batch_size / batch_denom 9 | 10 | boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs] 11 | vals = [initial_learning_rate * decay for decay in decay_rates] 12 | 13 | def learning_rate_fn(itr): 14 | lt = [itr < b for b in boundaries] + [True] 15 | i = np.argmax(lt) 16 | return vals[i] 17 | 18 | return learning_rate_fn 19 | 20 | 21 | def count_parameters(model): 22 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 23 | 24 | 25 | def makedirs(dirname): 26 | if not os.path.exists(dirname): 27 | os.makedirs(dirname) 28 | 29 | 30 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False): 31 | logger = logging.getLogger() 32 | if debug: 33 | level = logging.DEBUG 34 | else: 35 | level = logging.INFO 36 | logger.setLevel(level) 37 | if saving: 38 | info_file_handler = logging.FileHandler(logpath, mode="a") 39 | info_file_handler.setLevel(level) 40 | logger.addHandler(info_file_handler) 41 | if displaying: 42 | console_handler = logging.StreamHandler() 43 | console_handler.setLevel(level) 44 | logger.addHandler(console_handler) 45 | logger.info(filepath) 46 | 47 | if filepath is not None: 48 | with open(filepath, "r") as f: 49 | logger.info(f.read()) 50 | 51 | for f in package_files: 52 | logger.info(f) 53 | with open(f, "r") as package_f: 54 | logger.info(package_f.read()) 55 | 56 | return logger -------------------------------------------------------------------------------- /sopa/src/models/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | from .odenet_mnist.layers 
import MetaNODE 6 | 7 | def fix_seeds(seed=502): 8 | np.random.seed(seed) 9 | random.seed(seed) 10 | torch.manual_seed(seed) 11 | torch.cuda.manual_seed_all(seed) 12 | torch.set_printoptions(precision=10) 13 | 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False 16 | 17 | class RunningAverageMeter(object): 18 | """Computes and stores the average and current value""" 19 | 20 | def __init__(self, momentum=0.99): 21 | self.momentum = momentum 22 | self.reset() 23 | 24 | def reset(self): 25 | self.val = None 26 | self.avg = 0 27 | 28 | def update(self, val): 29 | if self.val is None: 30 | self.avg = val 31 | else: 32 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 33 | self.val = val 34 | 35 | 36 | def load_model(path): 37 | (_, state_dict), (_, model_args), (_, slover_id) = torch.load(path, map_location='cpu').items() 38 | 39 | is_odenet = model_args.network == 'odenet' 40 | 41 | if not hasattr(model_args, 'in_channels'): 42 | model_args.in_channels = 1 43 | 44 | model = MetaNODE(downsampling_method=model_args.downsampling_method, 45 | is_odenet=is_odenet, 46 | in_channels=model_args.in_channels) 47 | model.load_state_dict(state_dict) 48 | 49 | return model, model_args -------------------------------------------------------------------------------- /sopa/src/solvers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /sopa/src/solvers/euler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .rk_parametric import RKParametricSolver 5 | 6 | 7 | class Euler(RKParametricSolver): 8 | def __init__(self, parameterization=None, u0 = None, v0 = None, dtype=None, device=None, **kwargs): 9 | super(Euler, self).__init__(**kwargs) 10 | self.dtype = dtype 11 | self.device = device 12 | 13 | 
self.parameterization = None 14 | 15 | self.u = None 16 | self.u0 = None 17 | self.v = None 18 | self.v0 = None 19 | 20 | self.build_ButcherTableau() 21 | 22 | 23 | def _compute_c(self): 24 | self.c1 = torch.tensor((0.,),dtype=self.dtype, device=self.device) 25 | 26 | 27 | def _compute_b(self): 28 | self.b1 = torch.tensor((1.,),dtype=self.dtype, device=self.device) 29 | 30 | 31 | def _compute_w(self): 32 | self.w11 = torch.tensor((0.,),dtype=self.dtype, device=self.device) 33 | 34 | 35 | def _make_u_valid(self, eps): 36 | pass 37 | 38 | 39 | def _make_params_valid(self): 40 | pass 41 | 42 | 43 | def _get_c(self): 44 | c = torch.tensor([self.c1,]) 45 | return c 46 | 47 | 48 | def _get_w(self): 49 | w = [torch.tensor([self.w11,])] 50 | return w 51 | 52 | 53 | def _get_b(self): 54 | b = torch.tensor([self.b1,]) 55 | return b 56 | 57 | 58 | def _get_t(self, t, dt): 59 | t0 = t 60 | return t0 61 | 62 | 63 | def _make_step(self, rhs_func, x, t, dt): 64 | t0 = self._get_t(t, dt) 65 | 66 | k1 = rhs_func(t0, x) 67 | 68 | return (k1 * self.b1) * dt 69 | 70 | 71 | def freeze_params(self): 72 | pass 73 | 74 | 75 | def unfreeze_params(self): 76 | pass 77 | 78 | 79 | @property 80 | def order(self): 81 | return 1 82 | -------------------------------------------------------------------------------- /sopa/src/solvers/rk_parametric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import abc 4 | 5 | class RKParametricSolver(object, metaclass = abc.ABCMeta): 6 | def __init__(self, n_steps = None, step_size = None, grid_constructor=None): 7 | 8 | ## Compute number of grid related args != None 9 | if sum(1 for _ in filter(None.__ne__, [n_steps, step_size, grid_constructor])) >= 2: 10 | raise ValueError("n_steps, step_size and grid_constructor are pairwise exclusive arguments.") 11 | 12 | ## Initialize time grid 13 | if n_steps is not None: 14 | self.grid_constructor = 
self._grid_constructor_from_n_steps(n_steps) 15 | elif step_size is not None: 16 | self.grid_constructor = self._grid_constructor_from_step_size(step_size) 17 | elif grid_constructor is not None: 18 | self.grid_constructor = grid_constructor 19 | else: 20 | self.grid_constructor = lambda t: t 21 | 22 | 23 | def _grid_constructor_from_step_size(self, step_size): 24 | 25 | def _grid_constructor(t): 26 | start_time = t[0] 27 | end_time = t[-1] 28 | 29 | n_steps = torch.ceil((end_time - start_time) / step_size + 1).item() 30 | t_infer = torch.arange(0, n_steps).to(t) * step_size + start_time 31 | if t_infer[-1] > t[-1]: 32 | t_infer[-1] = t[-1] 33 | return t_infer 34 | 35 | return _grid_constructor 36 | 37 | 38 | def _grid_constructor_from_n_steps(self, n_steps): 39 | 40 | def _grid_constructor(t): 41 | start_time = t[0] 42 | end_time = t[-1] 43 | 44 | t_infer = torch.linspace(start_time, end_time, n_steps + 1).to(t) 45 | return t_infer 46 | 47 | return _grid_constructor 48 | 49 | 50 | @property 51 | @abc.abstractmethod 52 | def order(self): 53 | pass 54 | 55 | @abc.abstractmethod 56 | def freeze_params(self): 57 | pass 58 | 59 | @abc.abstractmethod 60 | def unfreeze_params(self): 61 | pass 62 | 63 | @abc.abstractmethod 64 | def _make_params_valid(self): 65 | pass 66 | 67 | 68 | def build_ButcherTableau(self, return_tableau = False): 69 | self._make_params_valid() 70 | self._compute_c() 71 | self._compute_b() 72 | self._compute_w() 73 | 74 | if return_tableau: 75 | return self._collect_ButcherTableau() 76 | 77 | 78 | def _collect_ButcherTableau(self): 79 | c = self._get_c() 80 | w = self._get_w() 81 | b = self._get_b() 82 | return c, w, b 83 | 84 | 85 | @abc.abstractmethod 86 | def _make_step(self, rhs_func, x, t, dt): 87 | pass 88 | 89 | def integrate(self, rhs_func, x, t): 90 | # _assert_increasing(t) 91 | t = t.type_as(x[0]) 92 | 93 | time_grid = self.grid_constructor(t) 94 | 95 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1] 96 | time_grid = 
time_grid.to(x[0]) 97 | 98 | # print('\ntime_grid (for evaluation):', time_grid) 99 | 100 | solution = [x] 101 | 102 | j = 1 103 | y0 = x # x has shape (batch_size, *x.shape) 104 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]): 105 | dy = self._make_step(rhs_func, y0, t0, t1 - t0) # dy has shape (batch_size, *x.shape) 106 | y1 = y0 + dy 107 | 108 | # interpolate values at intermediate points 109 | while j < len(t) and t1 >= t[j]: 110 | solution.append(self._linear_interp(t0, t1, y0, y1, t[j])) 111 | j += 1 112 | y0 = y1 113 | return torch.stack(solution) # has shape (len(t), batch_size, *x.shape) 114 | 115 | 116 | def _linear_interp(self, t0, t1, y0, y1, t): 117 | if t == t0: 118 | return y0 119 | if t == t1: 120 | return y1 121 | t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0]) 122 | slope = (y1 - y0) / (t1 - t0) # slope has shape (batch_size, *x.shape) 123 | return y0 + slope * (t - t0) 124 | 125 | 126 | def print_is_requires_grad(self): 127 | print('\nIs requires grad? (RK solver)') 128 | 129 | for pname, p in self.__dict__.items(): 130 | if hasattr(p, 'requires_grad'): 131 | print(pname, p.requires_grad) 132 | -------------------------------------------------------------------------------- /sopa/src/solvers/rk_parametric_order2stage2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .rk_parametric import RKParametricSolver 5 | 6 | def build_ButcherTableau_Midpoint(dtype=None, device=None): 7 | c = torch.tensor([0., 1/2.],dtype=dtype, device=device) 8 | w = [torch.tensor([0.,], dtype=dtype, device=device)] + [torch.tensor([1/2., 0.],dtype=dtype, device=device)] 9 | b = torch.tensor([0., 1.],dtype=dtype, device=device) 10 | return c, w, b 11 | 12 | 13 | def build_ButcherTableau_Heun(dtype=None, device=None): 14 | c = torch.tensor([0., 1.],dtype=dtype, device=device) 15 | w = [torch.tensor([0.,],dtype=dtype, device=device)] + [torch.tensor([1., 0.],dtype=dtype, 
device=device)] 16 | b = torch.tensor([1/2., 1/2.],dtype=dtype, device=device) 17 | return c, w, b 18 | 19 | 20 | class RKOrder2Stage2(RKParametricSolver): 21 | def __init__(self, parameterization='u', u0 = None, v0 = None, dtype=None, device=None, **kwargs): 22 | super(RKOrder2Stage2, self).__init__(**kwargs) 23 | self.dtype = dtype 24 | self.device = device 25 | 26 | if parameterization != 'u': 27 | raise ValueError('Unknown parameterization for RKOrder2Stage2 solver') 28 | self.parameterization = parameterization 29 | self.u = nn.Parameter(data = torch.tensor((u0,), dtype=self.dtype, device=self.device)) 30 | self.u0 = torch.tensor((u0,), dtype=self.dtype, device=self.device) 31 | self.v = None 32 | self.v0 = None 33 | 34 | self.build_ButcherTableau() 35 | 36 | 37 | def _compute_c(self): 38 | self.c1 = torch.tensor((0.,),dtype=self.dtype, device=self.device) 39 | self.c2 = self.u_.clone() 40 | 41 | 42 | def _compute_b(self): 43 | self.b2 = 1. / (2 * self.u_) 44 | self.b1 = 1. - self.b2 45 | 46 | 47 | def _compute_w(self): 48 | self.w21 = self.c2 49 | self.w11, self.w22 = [torch.tensor((0.,),dtype=self.dtype, device=self.device) for _ in range(2)] 50 | 51 | 52 | def _make_u_valid(self, eps): 53 | self.u_ = torch.clamp(self.u, eps, 1.) 
54 | 55 | 56 | def _make_params_valid(self): 57 | if self.u.dtype == torch.float64: 58 | eps = torch.finfo(torch.float32).eps 59 | elif self.u.dtype == torch.float32: 60 | eps = torch.finfo(torch.float16).eps 61 | 62 | self._make_u_valid(eps) 63 | 64 | 65 | def _get_c(self): 66 | c = torch.tensor([self.c1, self.c2]) 67 | return c 68 | 69 | 70 | def _get_w(self): 71 | w = [torch.tensor([self.w11,])] + [ 72 | torch.tensor([self.w21, self.w22])] 73 | return w 74 | 75 | 76 | def _get_b(self): 77 | b = torch.tensor([self.b1, self.b2]) 78 | return b 79 | 80 | 81 | def _get_t(self, t, dt): 82 | t0 = t 83 | t1 = t + self.c2 * dt 84 | return (t0, t1) 85 | 86 | 87 | def _make_step(self, rhs_func, x, t, dt): 88 | t0, t1 = self._get_t(t, dt) 89 | 90 | k1 = rhs_func(t0, x) 91 | k2 = rhs_func(t1, x + k1 * self.w21 * dt) 92 | 93 | return (k1 * self.b1 + k2 * self.b2) * dt 94 | 95 | 96 | def freeze_params(self): 97 | self.u.requires_grad = False 98 | if self.v is not None: 99 | self.v.requires_grad = False 100 | 101 | self.build_ButcherTableau() # recompute params to set non leaf requires_grad to False 102 | 103 | 104 | def unfreeze_params(self): 105 | self.u.requires_grad = True 106 | if self.v is not None: 107 | self.v.requires_grad = True 108 | 109 | self.build_ButcherTableau() # recompute params to set non leaf requires_grad to True 110 | 111 | @property 112 | def order(self): 113 | return 2 -------------------------------------------------------------------------------- /sopa/src/solvers/rk_parametric_order3stage3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .rk_parametric import RKParametricSolver 5 | 6 | 7 | class RKOrder3Stage3(RKParametricSolver): 8 | def __init__(self, parameterization = 'uv', u0 = 1/3., v0 = 2/3., dtype=None, device=None, **kwargs): 9 | super(RKOrder3Stage3, self).__init__(**kwargs) 10 | self.dtype = dtype 11 | self.device = device 12 | 13 | if 
parameterization != 'uv': 14 | raise ValueError('Unknown parameterization for RKOrder3Stage3 solver') 15 | self.parameterization = parameterization 16 | self.u = nn.Parameter(data = torch.tensor((u0,), dtype=self.dtype, device=self.device)) 17 | self.u0 = torch.tensor((u0,), dtype=self.dtype, device=self.device) 18 | 19 | self.v = nn.Parameter(data = torch.tensor((v0,), dtype=self.dtype, device=self.device)) 20 | self.v0 = torch.tensor((v0,), dtype=self.dtype, device=self.device) 21 | 22 | self.build_ButcherTableau() 23 | 24 | 25 | def _compute_c(self): 26 | self.c1 = torch.tensor((0.,), dtype=self.dtype, device=self.device) 27 | self.c2 = self.u_.clone() 28 | self.c3 = self.v_.clone() 29 | 30 | 31 | def _compute_b(self): 32 | v_sub_u = self.v_ - self.u_ 33 | 34 | self.b2 = (2. - 3. * self.v_) / (6. * self.u_ * (-v_sub_u)) 35 | self.b3 = (2. - 3. * self.u_) / (6. * self.v_ * v_sub_u) 36 | self.b1 = 1. - self.b2 - self.b3 37 | 38 | 39 | def _compute_w(self): 40 | self.w32 = self.v_ * (self.v_ - self.u_) / (self.u_ * (2. - 3. * self.u_)) 41 | self.w31 = self.c3 - self.w32 42 | self.w21 = self.c2 43 | 44 | self.w11, self.w22, self.w33 = [torch.tensor((0.,), dtype=self.dtype, device=self.device) for _ in range(3)] 45 | 46 | 47 | def _make_u_valid(self, eps): 48 | self.u_ = torch.clamp(self.u, eps, 1.) 49 | 50 | 51 | def _make_v_valid(self, eps): 52 | self.v_ = torch.clamp(self.v, eps, 1.) 53 | 54 | 55 | def _make_params_valid(self): 56 | if self.u.dtype == torch.float64: 57 | eps = torch.finfo(torch.float32).eps 58 | elif self.u.dtype == torch.float32: 59 | eps = torch.finfo(torch.float16).eps 60 | 61 | self._make_u_valid(eps) 62 | self._make_v_valid(eps) 63 | 64 | if self.u_ == self.v_: 65 | if self.u_ < 1. 
- eps: 66 | self.v_ = self.u_ + eps 67 | else: 68 | self.u_ = self.v_ - eps 69 | 70 | 71 | def _get_c(self): 72 | c = torch.tensor([self.c1, self.c2, self.c3]) 73 | return c 74 | 75 | 76 | def _get_w(self): 77 | w = [torch.tensor([self.w11,])] + [ 78 | torch.tensor([self.w21, self.w22])] + [ 79 | torch.tensor([self.w31, self.w32, self.w33])] 80 | return w 81 | 82 | 83 | def _get_b(self): 84 | b = torch.tensor([self.b1, self.b2, self.b3]) 85 | return b 86 | 87 | 88 | def _get_t(self, t, dt): 89 | t0 = t 90 | t1 = t + self.c2 * dt 91 | t2 = t + self.c3 * dt 92 | 93 | return (t0, t1, t2) 94 | 95 | 96 | def _make_step(self, rhs_func, x, t, dt): 97 | t0, t1, t2 = self._get_t(t, dt) 98 | 99 | k1 = rhs_func(t0, x) 100 | k2 = rhs_func(t1, x + k1 * self.w21 * dt) 101 | k3 = rhs_func(t2, x + (k1 * self.w31 + k2 * self.w32) * dt) 102 | 103 | return (k1 * self.b1 + k2 * self.b2 + k3 * self.b3) * dt 104 | 105 | 106 | def freeze_params(self): 107 | self.u.requires_grad = False 108 | if self.v is not None: 109 | self.v.requires_grad = False 110 | 111 | self.build_ButcherTableau() # recompute params to set non leaf requires_grad to False 112 | 113 | 114 | def unfreeze_params(self): 115 | self.u.requires_grad = True 116 | if self.v is not None: 117 | self.v.requires_grad = True 118 | 119 | self.build_ButcherTableau() # recompute params to set non leaf requires_grad to True 120 | 121 | @property 122 | def order(self): 123 | return 3 -------------------------------------------------------------------------------- /sopa/src/solvers/rk_parametric_order4stage4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .rk_parametric import RKParametricSolver 5 | 6 | def build_ButcherTableau_RKStandard(dtype=None, device=None): 7 | c = torch.tensor([0., 1/2., 1/2., 1.], dtype=dtype, device=device) 8 | w = [torch.tensor([0.,], dtype=dtype, device=device)] + [torch.tensor([1/2., 0.], dtype=dtype, 
device=device)] + [torch.tensor(w_i, dtype=dtype, device=device) for w_i in [[0., 1/2., 0.], [0., 0., 1., 0.]]] 9 | b = torch.tensor([1/6., 1/3., 1/3., 1/6.]) 10 | return c, w, b 11 | 12 | 13 | def build_ButcherTableau_RK38(dtype=None, device=None): 14 | c = torch.tensor([0., 1/3., 2/3., 1.], dtype=dtype, device=device) 15 | w = [torch.tensor([0.,], dtype=dtype, device=device)] + [torch.tensor([1/3., 0.], dtype=dtype, device=device)] + [torch.tensor(w_i, dtype=dtype, device=device) for w_i in [[-1/3., 1., 0.], [1., -1., 1., 0.]]] 16 | b = torch.tensor([1/8., 3/8., 3/8., 1/8.], dtype=dtype, device=device) 17 | return c, w, b 18 | 19 | 20 | class RKOrder4Stage4(RKParametricSolver): 21 | def __init__(self, parameterization = 'u2', u0 = 1/3., v0 = 2/3., dtype=torch.float64, device='cpu', **kwargs): 22 | super(RKOrder4Stage4, self).__init__(**kwargs) 23 | self.dtype = dtype 24 | self.device = device 25 | 26 | self.parameterization = parameterization 27 | self.u = nn.Parameter(data = torch.tensor((u0,), dtype=self.dtype, device=self.device)) 28 | self.u0 = torch.tensor((u0,), dtype=self.dtype, device=self.device) 29 | 30 | if self.parameterization == 'uv': 31 | self.v = nn.Parameter(data = torch.tensor((v0,), dtype=self.dtype, device=self.device)) 32 | self.v0 = torch.tensor((v0,), dtype=self.dtype, device=self.device) 33 | else: 34 | self.v = None 35 | self.v0 = None 36 | 37 | self.build_ButcherTableau() 38 | 39 | 40 | def _compute_c(self): 41 | self.c1 = torch.tensor((0.,), dtype=self.dtype, device=self.device) 42 | 43 | if self.parameterization == 'u1': 44 | self.c2 = torch.tensor((0.5,), dtype=self.dtype, device=self.device) 45 | self.c3 = torch.tensor((0.,), dtype=self.dtype, device=self.device) 46 | 47 | elif self.parameterization == 'u2': 48 | self.c2 = torch.tensor((0.5,), dtype=self.dtype, device=self.device) 49 | self.c3 = torch.tensor((0.5,), dtype=self.dtype, device=self.device) 50 | 51 | elif self.parameterization == 'u3': 52 | self.c2 = torch.tensor((1.,), 
dtype=self.dtype, device=self.device) 53 | self.c3 = torch.tensor((0.5,), dtype=self.dtype, device=self.device) 54 | 55 | elif self.parameterization == 'uv': 56 | self.c2 = self.u_.clone() # .clone() is nedeed, because without it self.c2 will be nn.Parameter 57 | self.c3 = self.v_.clone() 58 | 59 | self.c4 = torch.tensor((1.,), dtype=self.dtype, device=self.device) 60 | 61 | 62 | 63 | 64 | def _compute_b(self): 65 | if self.parameterization == 'u1': 66 | self.b1 = torch.tensor((1/6.,),dtype=self.dtype, device=self.device) - self.u_ 67 | self.b2 = torch.tensor((2/3.,),dtype=self.dtype, device=self.device) 68 | self.b3 = self.u_.clone() 69 | self.b4 = torch.tensor((1/6.,),dtype=self.dtype, device=self.device) 70 | 71 | elif self.parameterization == 'u2': 72 | self.b1 = torch.tensor((1/6.,),dtype=self.dtype, device=self.device) 73 | self.b2 = torch.tensor((2/3.,),dtype=self.dtype, device=self.device) - self.u_ 74 | self.b3 = self.u_.clone() 75 | self.b4 = torch.tensor((1/6.,),dtype=self.dtype, device=self.device) 76 | 77 | elif self.parameterization == 'u3': 78 | self.b1 = torch.tensor((1/6.,),dtype=self.dtype, device=self.device) 79 | self.b2 = torch.tensor((1/6.,),dtype=self.dtype, device=self.device) - self.u_ 80 | self.b3 = torch.tensor((2/3.,),dtype=self.dtype, device=self.device) 81 | self.b4 = self.u_.clone() 82 | 83 | elif self.parameterization == 'uv': 84 | sub_u = 1. - self.u_ 85 | sub_v = 1. - self.v_ 86 | v_sub_u = self.v_ - self.u_ 87 | 88 | self.b2 = (2. * self.v_ - 1.) / (12 * self.u_ * sub_u * v_sub_u) 89 | self.b3 = (1. - 2 * self.u_) / (12 * self.v_ * sub_v * v_sub_u) 90 | self.b4 = (6. * self.u_ * self.v_ + 3. - 4. * self.u_ - 4. * self.v_) / (12 * sub_u * sub_v) 91 | self.b1 = 1. 
- self.b2 - self.b3 - self.b4 92 | 93 | 94 | def _compute_w(self): 95 | self.w43 = self.b3 * (1 - self.c3) / self.b4 96 | 97 | A00 = self.b3 * self.c3 * self.c2 98 | A01 = self.b4 * self.c4 * self.c2 99 | A10 = self.b3 100 | A11 = self.b4 101 | 102 | B0 = 0.125 - self.b4 * self.c4 * self.c3 * self.w43 103 | B1 = self.b2 * (1 - self.c2) 104 | 105 | ### Find w32, w42 using Cramer's rule 106 | detA = A00 * A11 - A01 * A10 107 | detA0 = B0 * A11 - B1 * A01 108 | detA1 = A00 * B1 - A10 * B0 109 | 110 | self.w32 = detA0 / detA 111 | self.w42 = detA1 / detA 112 | ### 113 | 114 | # ### Find w32, w42 using torch.solve 115 | # A = torch.cat((A00, A01, A10, A11)).reshape((2,2)) 116 | # B = torch.cat((B0, B1)).reshape((2,1)) 117 | # (self.w32, self.w42), _ = torch.solve(B, A) 118 | # ### 119 | 120 | self.w41 = self.c4 - (self.w42 + self.w43) 121 | self.w31 = self.c3 - self.w32 122 | self.w21 = self.c2 123 | 124 | self.w11, self.w22, self.w33, self.w44 = [torch.tensor((0.,),dtype=self.dtype, device = self.device) for _ in range(4)] 125 | 126 | 127 | def _make_u_valid(self, eps): 128 | if self.v is not None: 129 | if self.u < 0.5: 130 | self.u_ = torch.clamp(self.u, eps, 0.5 - eps) 131 | else: 132 | self.u_ = torch.clamp(self.u, 0.5 + eps, 1. - eps) 133 | else: 134 | self.u_ = torch.clamp(self.u, eps, 1. - eps) 135 | 136 | 137 | def _make_v_valid(self, eps): 138 | self.v_ = torch.clamp(self.v, eps, 1. - eps) 139 | 140 | 141 | def _make_params_valid(self): 142 | if self.u.dtype == torch.float64: 143 | eps = torch.finfo(torch.float32).eps 144 | elif self.u.dtype == torch.float32: 145 | eps = torch.finfo(torch.float16).eps 146 | 147 | self._make_u_valid(eps) 148 | 149 | if self.v is not None: 150 | self._make_v_valid(eps) 151 | 152 | if self.u_ == self.v_: 153 | if self.u_ < 1. 
- eps: 154 | self.v_ = self.u_ + eps 155 | else: 156 | self.u_ = self.v_ - eps 157 | 158 | 159 | def _get_c(self): 160 | c = torch.tensor([self.c1, self.c2, self.c3, self.c4]) 161 | return c 162 | 163 | def _get_w(self): 164 | w = [torch.tensor([self.w11,])] + [ 165 | torch.tensor([self.w21, self.w22])] + [ 166 | torch.tensor([self.w31, self.w32, self.w33])] + [ 167 | torch.tensor([self.w41, self.w42, self.w43, self.w44])] 168 | 169 | return w 170 | 171 | def _get_b(self): 172 | b = torch.tensor([self.b1, self.b2, self.b3, self.b4]) 173 | return b 174 | 175 | def _get_t(self, t, dt): 176 | t0 = t 177 | t1 = t + self.c2 * dt 178 | t2 = t + self.c3 * dt 179 | t3 = t + self.c4 * dt 180 | 181 | return (t0, t1, t2, t3) 182 | 183 | 184 | def _make_step(self, rhs_func, x, t, dt): 185 | t0, t1, t2, t3 = self._get_t(t, dt) 186 | 187 | k1 = rhs_func(t0, x) 188 | k2 = rhs_func(t1, x + k1 * self.w21 * dt) 189 | k3 = rhs_func(t2, x + (k1 * self.w31 + k2 * self.w32) * dt) 190 | k4 = rhs_func(t3, x + (k1 * self.w41 + k2 * self.w42 + k3 * self.w43) * dt) 191 | 192 | return (k1 * self.b1 + k2 * self.b2 + k3 * self.b3 + k4 * self.b4) * dt 193 | 194 | 195 | def freeze_params(self): 196 | self.u.requires_grad = False 197 | if self.v is not None: 198 | self.v.requires_grad = False 199 | 200 | self.build_ButcherTableau() # recompute params to set non leaf requires_grad to False 201 | 202 | def unfreeze_params(self): 203 | self.u.requires_grad = True 204 | if self.v is not None: 205 | self.v.requires_grad = True 206 | 207 | self.build_ButcherTableau() # recompute params to set non leaf requires_grad to True 208 | 209 | @property 210 | def order(self): 211 | return 4 -------------------------------------------------------------------------------- /sopa/src/solvers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import copy 4 | 5 | from torch.distributions.normal import Normal 6 | from 
torch.distributions.cauchy import Cauchy 7 | 8 | from .rk_parametric_order4stage4 import RKOrder4Stage4 9 | from .rk_parametric_order3stage3 import RKOrder3Stage3 10 | from .rk_parametric_order2stage2 import RKOrder2Stage2 11 | from .euler import Euler 12 | 13 | def create_solver(method, parameterization, n_steps, step_size, u0, v0, dtype, device): 14 | ''' 15 | method: str 16 | parameterization: str 17 | n_steps: int 18 | step_size: Decimal 19 | u0: Decimal 20 | v0: Decimal 21 | dtype: torch dtype 22 | ''' 23 | if n_steps == -1: 24 | n_steps = None 25 | 26 | if step_size == -1: 27 | step_size = None 28 | 29 | if dtype == torch.float64: 30 | u0, v0 = map(lambda el : np.float64(el), [u0, v0]) 31 | elif dtype == torch.float32: 32 | u0, v0 = map(lambda el : np.float32(el), [u0, v0]) 33 | 34 | if method == 'euler': 35 | return Euler(n_steps = n_steps, 36 | step_size = step_size, 37 | parameterization=parameterization, 38 | u0 = u0, v0 = v0, 39 | dtype =dtype, device = device) 40 | elif method == 'rk2': 41 | return RKOrder2Stage2(n_steps = n_steps, 42 | step_size = step_size, 43 | parameterization=parameterization, 44 | u0 = u0, v0 = v0, 45 | dtype =dtype, device = device) 46 | elif method == 'rk3': 47 | return RKOrder3Stage3(n_steps = n_steps, 48 | step_size = step_size, 49 | parameterization=parameterization, 50 | u0 = u0, v0 = v0, 51 | dtype =dtype, device = device) 52 | elif method == 'rk4': 53 | return RKOrder4Stage4(n_steps = n_steps, 54 | step_size = step_size, 55 | parameterization=parameterization, 56 | u0 = u0, v0 = v0, 57 | dtype =dtype, device = device) 58 | 59 | 60 | def sample_noise(mu, sigma, noise_type='cauchy', size=1, device='cpu', minimize_rk2_error=False): 61 | if not minimize_rk2_error: 62 | if noise_type == 'cauchy': 63 | d = Cauchy(torch.tensor([mu]), torch.tensor([sigma])) 64 | elif noise_type == 'normal': 65 | d = Normal(torch.tensor([mu]), torch.tensor([sigma])) 66 | else: 67 | if noise_type == 'cauchy': 68 | d = Cauchy(torch.tensor([2/3.]), 
torch.tensor([2/3. * sigma])) 69 | elif noise_type == 'normal': 70 | d = Normal(torch.tensor([2/3.]), torch.tensor([2/3. * sigma])) 71 | 72 | return torch.tensor([d.sample() for _ in range(size)], device=device) 73 | 74 | 75 | def noise_params(mean_u, mean_v=None, std=0.01, bernoulli_p=1.0, noise_type='cauchy', minimize_rk2_error=False): 76 | ''' Noise solver paramers with Cauchy/Normal noise with probability p 77 | ''' 78 | d = torch.distributions.Bernoulli(torch.tensor([bernoulli_p], dtype=torch.float32)) 79 | v = None 80 | device = mean_u.device 81 | eps = torch.finfo(mean_u.dtype).eps 82 | 83 | if d.sample(): 84 | std = torch.abs(torch.tensor(std, device=device)) 85 | 86 | u = sample_noise(mean_u, std, noise_type=noise_type, size=1, device=device, minimize_rk2_error=minimize_rk2_error) 87 | if u <= mean_u - 2*std or u >= mean_u + 2*std: 88 | u = mean_u 89 | # u = min(max(u, mean_u - 2*std,0), mean_u + 2*std, 1.) 90 | 91 | if mean_v is not None: 92 | v = sample_noise(mean_v, std, noise_type=noise_type, size=1, device=device, minimize_rk2_error=minimize_rk2_error) 93 | else: 94 | u = mean_u 95 | if mean_v is not None: 96 | v = mean_v 97 | 98 | return u, v 99 | 100 | def sample_solver_by_noising_params(solver, std=0.01, bernoulli_p=1., noise_type='cauchy', minimize_rk2_error=False): 101 | new_solver = copy.deepcopy(solver) 102 | new_solver.u, new_solver.v = noise_params(mean_u=new_solver.u0, 103 | mean_v=new_solver.v0, 104 | std=std, 105 | bernoulli_p=bernoulli_p, 106 | noise_type=noise_type, 107 | minimize_rk2_error=minimize_rk2_error) 108 | new_solver.build_ButcherTableau() 109 | print(new_solver.u, new_solver.v) 110 | return new_solver 111 | 112 | def create_solver_ensemble_by_noising_params(solver, ensemble_size=1, kwargs_noise={}): 113 | solver_ensemble = [solver] 114 | for _ in range(1, ensemble_size): 115 | new_solver = sample_solver_by_noising_params(solver, **kwargs_noise) 116 | solver_ensemble.append(new_solver) 117 | return solver_ensemble 118 | 119 | 
120 | 121 | 122 | 123 | 124 | --------------------------------------------------------------------------------