├── .gitignore
├── LICENSE
├── MegaAdversarial
├── README.md
└── src
│ ├── __init__.py
│ └── attacks
│ ├── __init__.py
│ ├── attack.py
│ ├── base.py
│ ├── fgsm.py
│ └── pgd.py
├── README.md
├── examples
├── cifar10
│ ├── Build the model.ipynb
│ ├── Evaluate the model.ipynb
│ ├── assets
│ │ ├── fgsm_random_train.png
│ │ ├── fgsm_random_train_clean_test.png
│ │ ├── fgsm_random_train_fgsm_eps_8_255_test.png
│ │ └── fgsm_random_train_pgd_eps_8_255_lr_2_255_iters_7_test.png
│ ├── checkpoints
│ │ ├── accuracy
│ │ │ ├── fgsm_random_8_255_clean.pkl
│ │ │ ├── fgsm_random_8_255_fgsm_8_255.pkl
│ │ │ ├── fgsm_random_8_255_pgd_8_255_2_255_7.pkl
│ │ │ ├── fgsm_random_8_255_smoothing_00125_clean.pkl
│ │ │ ├── fgsm_random_8_255_smoothing_00125_fgsm_8_255.pkl
│ │ │ └── fgsm_random_8_255_smoothing_00125_pgd_8_255_2_255_7.pkl
│ │ ├── fgsm_random_8_255.pth
│ │ ├── fgsm_random_8_255_smoothing_00125.pth
│ │ └── fgsm_random_8_255_switching_05_05215_04875.pth
│ └── train_and_attack.py
└── mnist
│ ├── Build the model.ipynb
│ ├── Evaluate the model.ipynb
│ ├── assets
│ ├── mnist_adv.pdf
│ └── mnist_adv.png
│ ├── checkpoints
│ └── checkpoint_15444.pth
│ └── train_and_attack.py
└── sopa
├── __init__.py
└── src
├── __init__.py
├── models
├── __init__.py
├── odenet_cifar10
│ ├── __init__.py
│ ├── data.py
│ ├── layers.py
│ └── utils.py
├── odenet_mnist
│ ├── attacks_runner.py
│ ├── attacks_utils.py
│ ├── data.py
│ ├── layers.py
│ ├── metrics.py
│ ├── runner.py
│ ├── runner_new.py
│ ├── runner_old.py
│ ├── train_validate.py
│ └── utils.py
└── utils.py
└── solvers
├── __init__.py
├── euler.py
├── rk_parametric.py
├── rk_parametric_order2stage2.py
├── rk_parametric_order3stage3.py
├── rk_parametric_order4stage4.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *.log
3 | .data/
4 | data/*
5 | data/
6 | *.pyc
7 | events/*
8 | .ipynb_checkpoints/
9 | .ipynb_checkpoints/*
10 | gifs/*
11 | wandb/
12 | tests/
13 | tests/*
14 | checkpoints/accuracy
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, MetaSolver
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/MegaAdversarial/README.md:
--------------------------------------------------------------------------------
1 | ### Make neural networks robust again
--------------------------------------------------------------------------------
/MegaAdversarial/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/MegaAdversarial/src/__init__.py
--------------------------------------------------------------------------------
/MegaAdversarial/src/attacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Clean, Clean2Ensemble
2 | from .pgd import PGD
3 | from .fgsm import FGSM, FGSMRandom, FGSM2Ensemble
4 |
--------------------------------------------------------------------------------
/MegaAdversarial/src/attacks/attack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _clamp_between(x, lower, upper):
    """Clamp ``x`` element-wise into ``[lower, upper]``.

    ``lower`` and ``upper`` may be tensors broadcastable to ``x`` or plain
    Python numbers; numbers are converted to tensors on ``x``'s device with
    ``x``'s dtype so that ``torch.min`` / ``torch.max`` accept them. The upper
    bound is applied first, so wherever ``lower > upper`` the result is
    ``lower``.
    """
    if not isinstance(upper, torch.Tensor):
        upper = torch.tensor(upper, device=x.device, dtype=x.dtype)
    if not isinstance(lower, torch.Tensor):
        lower = torch.tensor(lower, device=x.device, dtype=x.dtype)
    return torch.max(torch.min(x, upper), lower)


class Attack(nn.Module):
    """Base class for adversarial attacks against a single model.

    Subclasses implement ``forward`` and return an ``(x_adv, y)`` pair.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.device = device

    def _project(self, x):
        """Clip ``x`` to the pixel range ``[0, 1]``."""
        return torch.clamp(x, 0, 1)

    def _clamp(self, x, min, max):
        """Clamp ``x`` element-wise into ``[min, max]`` (tensor or scalar bounds)."""
        # The parameter names shadow builtins; they are kept so keyword callers keep working.
        return _clamp_between(x, min, max)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Attack2Ensemble(nn.Module):
    """Base class for adversarial attacks against an ensemble of models.

    ``models`` is stored as given; subclasses implement ``forward`` and return
    an ``(x_adv, y)`` pair.
    """

    def __init__(self, models):
        super().__init__()
        self.models = models
        self.device = device

    def _project(self, x):
        """Clip ``x`` to the pixel range ``[0, 1]``."""
        return torch.clamp(x, 0, 1)

    def _clamp(self, x, min, max):
        """Clamp ``x`` element-wise into ``[min, max]`` (tensor or scalar bounds)."""
        # The parameter names shadow builtins; they are kept so keyword callers keep working.
        return _clamp_between(x, min, max)

    def forward(self, *args, **kwargs):
        raise NotImplementedError
/MegaAdversarial/src/attacks/base.py:
--------------------------------------------------------------------------------
1 | from .attack import Attack, Attack2Ensemble
2 |
3 |
class Clean(Attack):
    """Identity "attack": hands the batch back unchanged (clean evaluation)."""

    def forward(self, x, y, kwargs):
        # ``kwargs`` is accepted only to match the common attack call signature.
        unchanged = (x, y)
        return unchanged


class Clean2Ensemble(Attack2Ensemble):
    """Ensemble counterpart of ``Clean``: returns the inputs untouched."""

    def forward(self, x, y, kwargs_arr):
        # ``kwargs_arr`` is ignored; it exists for signature parity with other attacks.
        unchanged = (x, y)
        return unchanged
11 |
--------------------------------------------------------------------------------
/MegaAdversarial/src/attacks/fgsm.py:
--------------------------------------------------------------------------------
1 | from .attack import Attack, Attack2Ensemble
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.transforms as transforms
5 |
6 |
7 |
class FGSM(Attack):
    """Single-step Fast Gradient Sign Method (FGSM) attack.

    Inputs ``x`` are expected to be normalized with ``mean`` / ``std``
    (defaults: mean 0, std 1, i.e. the inputs are already in ``[0, 1]``).
    The ``eps`` step is taken in un-normalized ``[0, 1]`` pixel space, the
    result is clipped to ``[0, 1]``, and it is normalized again before being
    returned.
    """

    def __init__(self, model, eps=None, mean=None, std=None):
        super(FGSM, self).__init__(model)
        self.eps = eps
        self.loss_fn = nn.CrossEntropyLoss().to(self.device)
        self.mean = mean if mean is not None else (0., 0., 0.)
        self.std = std if std is not None else (1., 1., 1.)

    def forward(self, x, y, kwargs):
        """Return ``(x_adv, y)``, where ``x_adv`` is the normalized FGSM example.

        Args:
            x: normalized input batch.
            y: labels used in the cross-entropy loss.
            kwargs: extra keyword arguments forwarded to ``self.model``.
        """
        training = self.model.training
        if training:
            self.model.eval()
        try:
            inv_normalize = transforms.Normalize(
                mean=[-m / s for m, s in zip(self.mean, self.std)],
                std=[1 / s for s in self.std],
            )
            normalize = transforms.Normalize(mean=self.mean, std=self.std)
            x = inv_normalize(x)  # x in [0, 1]

            x_attacked = x.clone().detach()
            x_attacked.requires_grad_(True)
            loss = self.loss_fn(self.model(normalize(x_attacked), **kwargs), y)
            # autograd.grad (not backward) so no .grad is accumulated on model parameters.
            grad = torch.autograd.grad(
                [loss], [x_attacked], create_graph=False, retain_graph=False
            )[0]
            x_attacked = x_attacked + self.eps * grad.sign()
            x_attacked = self._project(x_attacked)
            x_attacked = normalize(x_attacked).detach()
        finally:
            # Restore train mode even if the forward/grad pass raised.
            if training:
                self.model.train()
        return x_attacked, y
44 |
45 |
def clamp(X, lower_limit, upper_limit):
    """Element-wise clamp of ``X`` into ``[lower_limit, upper_limit]``.

    Non-tensor bounds (e.g. Python floats) are promoted to tensors on ``X``'s
    device with ``X``'s dtype. The upper bound is applied before the lower
    bound, so ``lower_limit`` wins wherever the two bounds cross.
    """
    def _to_tensor(bound):
        if isinstance(bound, torch.Tensor):
            return bound
        return torch.tensor(bound, device=X.device, dtype=X.dtype)

    hi = _to_tensor(upper_limit)
    lo = _to_tensor(lower_limit)
    capped = torch.min(X, hi)
    return torch.max(capped, lo)
52 |
53 |
class FGSMRandom(Attack):
    """FGSM with a uniformly random start ("fast" adversarial training attack).

    Unlike ``FGSM``, this attack works directly in the normalized input space.
    When ``mu`` and ``std`` are given, the valid pixel range ``[0, 1]`` and the
    step sizes are rescaled into normalized coordinates. Inspired by
    https://github.com/locuslab/fast_adversarial/blob/master/CIFAR10/train_fgsm.py
    """

    def __init__(self, model, alpha, epsilon=None, mu=None, std=None):
        '''
        Args:
            model: the neural network model
            alpha: the step size (in [0, 1] pixel units)
            epsilon: the radius of the random noise and of the perturbation ball
            mu: per-channel mean value of the dataset samples
            std: per-channel std value of the dataset samples
        '''
        super(FGSMRandom, self).__init__(model)
        self.epsilon = epsilon
        self.alpha = alpha
        if (mu is not None) and (std is not None):
            # Shape (1, C, 1, 1) so the statistics broadcast over NCHW batches;
            # -1 lets C follow len(mu) instead of assuming 3 channels.
            mu = torch.tensor(mu, device=self.device).view(1, -1, 1, 1)
            std = torch.tensor(std, device=self.device).view(1, -1, 1, 1)

            # The valid range [0, 1] expressed in normalized coordinates.
            self.lower_limit = (0. - mu) / std
            self.upper_limit = (1. - mu) / std

            self.epsilon = self.epsilon / std
            self.alpha = self.alpha / std
        else:
            self.lower_limit = 0.
            self.upper_limit = 1.

        self.loss_fn = nn.CrossEntropyLoss().to(self.device)

    def forward(self, x, y, kwargs):
        """Return ``(x + delta, y)`` for a random-start FGSM perturbation ``delta``.

        The gradient is taken with ``torch.autograd.grad`` with respect to
        ``delta`` only. As a result, the attack does not accumulate ``.grad``
        into the model's parameters, which would otherwise leak into the
        caller's optimizer step.
        """
        training = self.model.training
        if training:
            self.model.eval()
        try:
            # Uniform noise in [-eps, eps], kept inside the valid input range.
            delta = self.epsilon - (2 * self.epsilon) * torch.rand_like(x)
            delta = clamp(delta, self.lower_limit - x, self.upper_limit - x).detach()
            delta.requires_grad_(True)

            loss = self.loss_fn(self.model(x + delta, **kwargs), y)
            grad = torch.autograd.grad([loss], [delta])[0].detach()

            with torch.no_grad():
                # One signed step, projected back onto the eps-ball and the valid range.
                delta = clamp(delta + self.alpha * torch.sign(grad), -self.epsilon, self.epsilon)
                delta = clamp(delta, self.lower_limit - x, self.upper_limit - x)
            delta = delta.detach()
        finally:
            # Restore train mode even if the forward/grad pass raised.
            if training:
                self.model.train()
        return x + delta, y
107 |
108 |
class FGSM2Ensemble(Attack2Ensemble):
    """FGSM against an ensemble whose prediction is the mean of member softmaxes.

    Inputs are expected to be normalized with ``mean`` / ``std`` (defaults:
    identity normalization). The ``eps`` step is taken in ``[0, 1]`` pixel
    space, the result is clipped to ``[0, 1]``, and it is normalized again.
    """

    def __init__(self, models, eps=None, mean=None, std=None):
        super(FGSM2Ensemble, self).__init__(models)
        self.eps = eps
        self.loss_fn = nn.NLLLoss().to(self.device)
        self.mean = mean if mean is not None else (0., 0., 0.)
        self.std = std if std is not None else (1., 1., 1.)

    def forward(self, x, y, kwargs_arr):
        """Return ``(x_adv, y)`` for one FGSM step on the ensemble NLL loss.

        Args:
            x: normalized input batch.
            y: labels used in the NLL loss.
            kwargs_arr: one kwargs dict per model, in the same order as
                ``self.models``.

        Raises:
            ValueError: if ``kwargs_arr`` does not have one entry per model.
        """
        kwargs_arr = list(kwargs_arr)
        if len(kwargs_arr) != len(self.models):
            # zip() would silently drop models while we still average over all of them.
            raise ValueError(
                "kwargs_arr must contain one kwargs dict per model "
                f"(got {len(kwargs_arr)} for {len(self.models)} models)"
            )

        # Remember each model's own mode so it can be restored exactly afterwards.
        was_training = [model.training for model in self.models]
        for model in self.models:
            model.eval()
        try:
            inv_normalize = transforms.Normalize(
                mean=[-m / s for m, s in zip(self.mean, self.std)],
                std=[1 / s for s in self.std],
            )
            normalize = transforms.Normalize(mean=self.mean, std=self.std)
            x = inv_normalize(x)  # x in [0, 1]

            x_attacked = x.clone().detach()
            x_attacked.requires_grad_(True)

            # log(mean_i softmax_i) computed stably as logsumexp_i(log_softmax_i) - log(M),
            # which avoids log(0) = -inf (and NaN gradients) for very confident models.
            member_log_probs = torch.stack([
                torch.log_softmax(model(normalize(x_attacked), **kwargs), dim=1)
                for model, kwargs in zip(self.models, kwargs_arr)
            ])
            n_models = member_log_probs.new_tensor(float(len(self.models)))
            log_mean_probs = torch.logsumexp(member_log_probs, dim=0) - torch.log(n_models)

            loss = self.loss_fn(log_mean_probs, y)
            grad = torch.autograd.grad(
                [loss], [x_attacked], create_graph=False, retain_graph=False
            )[0]
            x_attacked = x_attacked + self.eps * grad.sign()
            x_attacked = self._project(x_attacked)
            x_attacked = normalize(x_attacked).detach()
        finally:
            for model, mode in zip(self.models, was_training):
                model.train(mode)
        return x_attacked, y
156 |
--------------------------------------------------------------------------------
/MegaAdversarial/src/attacks/pgd.py:
--------------------------------------------------------------------------------
1 | from .attack import Attack
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.transforms as transforms
5 |
6 |
7 |
class PGD(Attack):
    """Projected Gradient Descent (L-infinity) attack.

    Inputs are expected to be normalized with ``mean`` / ``std`` (defaults:
    identity normalization). All steps are taken in ``[0, 1]`` pixel space:
    each iteration moves by ``lr`` along the gradient sign, then projects onto
    the ``eps``-ball around the clean input and onto ``[0, 1]``. The result is
    normalized again before being returned.
    """

    def __init__(self, model, eps=None, lr=None, n_iter=None, randomized_start=True, mean=None, std=None):
        super(PGD, self).__init__(model)
        self.eps = eps
        self.lr = lr
        self.n_iter = n_iter
        self.randomized_start = randomized_start
        self.loss_fn = nn.CrossEntropyLoss().to(self.device)
        self.mean = mean if mean is not None else (0., 0., 0.)
        self.std = std if std is not None else (1., 1., 1.)

    def forward(self, x, y, kwargs):
        """Return ``(x_adv, y)``, where ``x_adv`` is the normalized PGD example.

        Args:
            x: normalized input batch.
            y: labels used in the cross-entropy loss.
            kwargs: extra keyword arguments forwarded to ``self.model``.
        """
        training = self.model.training
        if training:
            self.model.eval()
        try:
            inv_normalize = transforms.Normalize(
                mean=[-m / s for m, s in zip(self.mean, self.std)],
                std=[1 / s for s in self.std],
            )
            normalize = transforms.Normalize(mean=self.mean, std=self.std)
            x = inv_normalize(x)  # x in [0, 1]

            if self.randomized_start:
                noise = torch.zeros_like(x).uniform_(-self.eps, self.eps)
                x_attacked = self._project(x + noise).clone().detach()
            else:
                x_attacked = x.clone().detach()

            for _ in range(self.n_iter):
                x_attacked.requires_grad_(True)
                loss = self.loss_fn(self.model(normalize(x_attacked), **kwargs), y)
                grad = torch.autograd.grad(
                    [loss], [x_attacked], create_graph=False, retain_graph=False
                )[0]
                x_attacked = self._clamp(
                    x_attacked + self.lr * grad.sign(), x - self.eps, x + self.eps
                )
                x_attacked = self._project(x_attacked).detach()

            # Normalize once after the loop so the output is always in the model's
            # input space, including when n_iter == 0.
            x_attacked = normalize(x_attacked).detach()
        finally:
            # Restore train mode even if the forward/grad pass raised.
            if training:
                self.model.train()
        return x_attacked, y
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Solver for Neural Ordinary Differential Equations
2 | Towards robust neural ODEs using parametrized solvers.
3 |
4 | # Main idea
5 | Each *Runge-Kutta (RK)* solver with `s` stages and of order `p` is defined by a table of coefficients (the *Butcher tableau*). For `s=p=2`, `s=p=3` and `s=p=4`, all coefficients in the table can be parametrized by no more than two variables [1].
6 |
7 | Usually, neural ODEs are trained with an RK solver whose Butcher tableau is fixed, and only the *right-hand side (RHS)* function is learned. We propose to use the whole parametric family of RK solvers to improve the robustness of neural ODEs.
8 |
9 | # Requirements
10 | - pytorch==1.7
11 | - apex==0.1 (for training)
12 |
13 | # Examples
14 | For CIFAR-10 and MNIST demos, please see the `examples` folder.
15 |
16 | # Meta Solver Regimes
17 | In the notebook `examples/cifar10/Evaluate the model.ipynb` we show how to perform the forward pass through the neural ODE using different Meta Solver regimes, namely
18 | - Standalone
19 | - Solver switching/smoothing
20 | - Solver ensembling
21 | - Model ensembling
22 |
23 | In more detail, the regimes work as follows:
24 | - **Standalone**
25 | - Use one solver during inference.
26 | - This regime is applied in the training and testing stages.
27 |
28 |
29 |
30 | - **Solver switching / smoothing**
31 | - For each batch, one solver is chosen from a group of solvers with a finite (switching regime) or infinite (smoothing regime) number of candidates.
32 | - This regime is applied in the training stage.
33 |
34 |
35 | - **Solver ensembling**
36 | - Use several solvers during inference.
37 | - Outputs of the ODE block (obtained with different solvers) are averaged before being propagated through the next layer.
38 | - This regime is applied in the training and testing stages.
39 |
40 |
41 | - **Model ensembling**
42 | - Use several solvers during inference.
43 | - Model probabilities obtained via propagation with different solvers are averaged to get the final result.
44 | - This regime is applied in the training and testing stages.
45 |
46 | # Selected results
47 | ## Different solver parameterizations yield different robustness
48 | We have trained a neural ODE model several times, using different ``u`` values in the parametrization of the 2nd-order Runge-Kutta solver. The image below depicts robust accuracies for the MNIST classification task. We use the PGD attack (eps=0.3, lr=2/255 and iters=7). The image shows the mean robust accuracy (bold lines) and ± one standard error of the mean (shaded region), computed across 9 random seeds.
49 |
50 |
51 |
52 | ## Solver smoothing improves robustness
53 | We compare the results of adversarial training of neural ODEs on the CIFAR-10 dataset with the meta-solver in the standalone, switching, or smoothing regimes. We use 8-step RK2 solvers for this experiment.
54 | - We perform training using FGSM random technique described in https://arxiv.org/abs/2001.03994 (with eps=8/255, alpha=10/255).
55 | - We use a cyclic learning rate schedule with one cycle (36 epochs, max_lr=0.1, base_lr=1e-7).
56 | - We measure the robust accuracy of the resulting models under FGSM (eps=8/255) and PGD (eps=8/255, lr=2/255, iters=7) attacks.
57 | - We use `premetanode10` architecture from `sopa/src/models/odenet_cifar10/layers.py` that has the following form
58 | `Conv -> PreResNet block -> ODE block -> PreResNet block -> ODE block -> GeLU -> Average Pooling -> Fully Connected`
59 | - We compute the mean and standard error across 3 random seeds.
60 |
61 |
65 |
66 | 
67 |
68 | # References
69 | [1] [Wanner, G., & Hairer, E. (1993). Solving ordinary differential equations I. Springer Berlin Heidelberg](https://www.springer.com/gp/book/9783540566700)
70 |
--------------------------------------------------------------------------------
/examples/cifar10/Build the model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 11,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import argparse\n",
10 | "import copy\n",
11 | "import sys\n",
12 | "\n",
13 | "sys.path.append('../../')\n",
14 | "import sopa.src.models.odenet_cifar10.layers as cifar10_models\n",
15 | "from sopa.src.models.odenet_cifar10.utils import *"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 12,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "parser = argparse.ArgumentParser()\n",
25 | "# Architecture params\n",
26 | "parser.add_argument('--is_odenet', type=eval, default=True, choices=[True, False])\n",
27 | "parser.add_argument('--network', type=str, choices=['metanode34', 'metanode18', 'metanode10', 'metanode6', 'metanode4',\n",
28 | " 'premetanode34', 'premetanode18', 'premetanode10', 'premetanode6',\n",
29 | " 'premetanode4'],\n",
30 | " default='premetanode10')\n",
31 | "parser.add_argument('--in_planes', type=int, default=64)\n",
32 | "\n",
33 | "# Type of layer's output normalization\n",
34 | "parser.add_argument('--normalization_resblock', type=str, default='NF',\n",
35 | " choices=['BN', 'GN', 'LN', 'IN', 'NF'])\n",
36 | "parser.add_argument('--normalization_odeblock', type=str, default='NF',\n",
37 | " choices=['BN', 'GN', 'LN', 'IN', 'NF'])\n",
38 | "parser.add_argument('--normalization_bn1', type=str, default='NF',\n",
39 | " choices=['BN', 'GN', 'LN', 'IN', 'NF'])\n",
40 | "parser.add_argument('--num_gn_groups', type=int, default=32, help='Number of groups for GN normalization')\n",
41 | "\n",
42 | "# Type of layer's weights normalization\n",
43 | "parser.add_argument('--param_normalization_resblock', type=str, default='PNF',\n",
44 | " choices=['WN', 'SN', 'PNF'])\n",
45 | "parser.add_argument('--param_normalization_odeblock', type=str, default='PNF',\n",
46 | " choices=['WN', 'SN', 'PNF'])\n",
47 | "parser.add_argument('--param_normalization_bn1', type=str, default='PNF',\n",
48 | " choices=['WN', 'SN', 'PNF'])\n",
49 | "# Type of activation\n",
50 | "parser.add_argument('--activation_resblock', type=str, default='ReLU',\n",
51 | " choices=['ReLU', 'GeLU', 'Softsign', 'Tanh', 'AF'])\n",
52 | "parser.add_argument('--activation_odeblock', type=str, default='ReLU',\n",
53 | " choices=['ReLU', 'GeLU', 'Softsign', 'Tanh', 'AF'])\n",
54 | "parser.add_argument('--activation_bn1', type=str, default='ReLU',\n",
55 | " choices=['ReLU', 'GeLU', 'Softsign', 'Tanh', 'AF'])\n",
56 | "\n",
57 | "args, unknown_args = parser.parse_known_args()"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 13,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "# Initialize Neural ODE model\n",
67 | "config = copy.deepcopy(args)\n",
68 | "\n",
69 | "norm_layers = (get_normalization(config.normalization_resblock),\n",
70 | " get_normalization(config.normalization_odeblock),\n",
71 | " get_normalization(config.normalization_bn1))\n",
72 | "param_norm_layers = (get_param_normalization(config.param_normalization_resblock),\n",
73 | " get_param_normalization(config.param_normalization_odeblock),\n",
74 | " get_param_normalization(config.param_normalization_bn1))\n",
75 | "act_layers = (get_activation(config.activation_resblock),\n",
76 | " get_activation(config.activation_odeblock),\n",
77 | " get_activation(config.activation_bn1))\n",
78 | "\n",
79 | "model = getattr(cifar10_models, config.network)(norm_layers, param_norm_layers, act_layers,\n",
80 | " config.in_planes, is_odenet=config.is_odenet)"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 14,
86 | "metadata": {},
87 | "outputs": [
88 | {
89 | "data": {
90 | "text/plain": [
91 | "MetaNODE(\n",
92 | " (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
93 | " (bn1): Identity()\n",
94 | " (layer1): MetaLayer(\n",
95 | " (blocks_res): Sequential(\n",
96 | " (0): PreBasicBlock(\n",
97 | " (bn1): Identity()\n",
98 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
99 | " (bn2): Identity()\n",
100 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
101 | " (shortcut): Sequential()\n",
102 | " )\n",
103 | " )\n",
104 | " (blocks_ode): ModuleList(\n",
105 | " (0): MetaODEBlock(\n",
106 | " (rhs_func): PreBasicBlock2(\n",
107 | " (bn1): Identity()\n",
108 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
109 | " (bn2): Identity()\n",
110 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
111 | " (shortcut): Sequential()\n",
112 | " )\n",
113 | " )\n",
114 | " )\n",
115 | " )\n",
116 | " (layer2): MetaLayer(\n",
117 | " (blocks_res): Sequential(\n",
118 | " (0): PreBasicBlock(\n",
119 | " (bn1): Identity()\n",
120 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
121 | " (bn2): Identity()\n",
122 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
123 | " (shortcut): Sequential(\n",
124 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
125 | " )\n",
126 | " )\n",
127 | " )\n",
128 | " (blocks_ode): ModuleList(\n",
129 | " (0): MetaODEBlock(\n",
130 | " (rhs_func): PreBasicBlock2(\n",
131 | " (bn1): Identity()\n",
132 | " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
133 | " (bn2): Identity()\n",
134 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
135 | " (shortcut): Sequential()\n",
136 | " )\n",
137 | " )\n",
138 | " )\n",
139 | " )\n",
140 | " (fc_layers): Sequential(\n",
141 | " (0): AdaptiveAvgPool2d(output_size=(1, 1))\n",
142 | " (1): Flatten()\n",
143 | " (2): Linear(in_features=128, out_features=10, bias=True)\n",
144 | " )\n",
145 | ")"
146 | ]
147 | },
148 | "execution_count": 14,
149 | "metadata": {},
150 | "output_type": "execute_result"
151 | }
152 | ],
153 | "source": [
154 | "model"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": []
163 | }
164 | ],
165 | "metadata": {
166 | "kernelspec": {
167 | "display_name": "Python 3",
168 | "language": "python",
169 | "name": "python3"
170 | },
171 | "language_info": {
172 | "codemirror_mode": {
173 | "name": "ipython",
174 | "version": 3
175 | },
176 | "file_extension": ".py",
177 | "mimetype": "text/x-python",
178 | "name": "python",
179 | "nbconvert_exporter": "python",
180 | "pygments_lexer": "ipython3",
181 | "version": "3.6.8"
182 | }
183 | },
184 | "nbformat": 4,
185 | "nbformat_minor": 2
186 | }
187 |
--------------------------------------------------------------------------------
/examples/cifar10/assets/fgsm_random_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/assets/fgsm_random_train.png
--------------------------------------------------------------------------------
/examples/cifar10/assets/fgsm_random_train_clean_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/assets/fgsm_random_train_clean_test.png
--------------------------------------------------------------------------------
/examples/cifar10/assets/fgsm_random_train_fgsm_eps_8_255_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/assets/fgsm_random_train_fgsm_eps_8_255_test.png
--------------------------------------------------------------------------------
/examples/cifar10/assets/fgsm_random_train_pgd_eps_8_255_lr_2_255_iters_7_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/assets/fgsm_random_train_pgd_eps_8_255_lr_2_255_iters_7_test.png
--------------------------------------------------------------------------------
/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_clean.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_clean.pkl
--------------------------------------------------------------------------------
/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_fgsm_8_255.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_fgsm_8_255.pkl
--------------------------------------------------------------------------------
/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_pgd_8_255_2_255_7.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_pgd_8_255_2_255_7.pkl
--------------------------------------------------------------------------------
/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_clean.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_clean.pkl
--------------------------------------------------------------------------------
/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_fgsm_8_255.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_fgsm_8_255.pkl
--------------------------------------------------------------------------------
/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_pgd_8_255_2_255_7.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/accuracy/fgsm_random_8_255_smoothing_00125_pgd_8_255_2_255_7.pkl
--------------------------------------------------------------------------------
/examples/cifar10/checkpoints/fgsm_random_8_255.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/fgsm_random_8_255.pth
--------------------------------------------------------------------------------
/examples/cifar10/checkpoints/fgsm_random_8_255_smoothing_00125.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/fgsm_random_8_255_smoothing_00125.pth
--------------------------------------------------------------------------------
/examples/cifar10/checkpoints/fgsm_random_8_255_switching_05_05215_04875.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/cifar10/checkpoints/fgsm_random_8_255_switching_05_05215_04875.pth
--------------------------------------------------------------------------------
/examples/mnist/Build the model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import argparse\n",
11 | "from argparse import Namespace\n",
12 | "import torch\n",
13 | "import sys\n",
14 | "\n",
15 | "sys.path.append('../../')\n",
16 | "from sopa.src.models.odenet_mnist.layers import MetaNODE"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "parser = argparse.ArgumentParser()\n",
26 | "parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')\n",
27 | "parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])\n",
28 | "parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu')\n",
29 | "parser.add_argument('--in_channels', type=int, default=1)\n",
30 | "args, unknown_args = parser.parse_known_args()"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 3,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "is_odenet = args.network == 'odenet'\n",
40 | "\n",
41 | "model = MetaNODE(downsampling_method=args.downsampling_method,\n",
42 | " is_odenet=is_odenet,\n",
43 | " activation_type=args.activation,\n",
44 | " in_channels=args.in_channels)"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 4,
50 | "metadata": {},
51 | "outputs": [
52 | {
53 | "data": {
54 | "text/plain": [
55 | "MetaNODE(\n",
56 | " (downsampling_layers): Sequential(\n",
57 | " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))\n",
58 | " (1): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
59 | " (2): ReLU(inplace=True)\n",
60 | " (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
61 | " (4): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
62 | " (5): ReLU(inplace=True)\n",
63 | " (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
64 | " )\n",
65 | " (fc_layers): Sequential(\n",
66 | " (0): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
67 | " (1): ReLU(inplace=True)\n",
68 | " (2): AdaptiveAvgPool2d(output_size=(1, 1))\n",
69 | " (3): Flatten()\n",
70 | " (4): Linear(in_features=64, out_features=10, bias=True)\n",
71 | " )\n",
72 | " (blocks): ModuleList(\n",
73 | " (0): MetaODEBlock(\n",
74 | " (rhs_func): ODEfunc(\n",
75 | " (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
76 | " (relu): ReLU(inplace=True)\n",
77 | " (conv1): ConcatConv2d(\n",
78 | " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
79 | " )\n",
80 | " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
81 | " (conv2): ConcatConv2d(\n",
82 | " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
83 | " )\n",
84 | " (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
85 | " )\n",
86 | " )\n",
87 | " )\n",
88 | ")"
89 | ]
90 | },
91 | "execution_count": 4,
92 | "metadata": {},
93 | "output_type": "execute_result"
94 | }
95 | ],
96 | "source": [
97 | "model"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": []
106 | }
107 | ],
108 | "metadata": {
109 | "kernelspec": {
110 | "display_name": "Python 3",
111 | "language": "python",
112 | "name": "python3"
113 | },
114 | "language_info": {
115 | "codemirror_mode": {
116 | "name": "ipython",
117 | "version": 3
118 | },
119 | "file_extension": ".py",
120 | "mimetype": "text/x-python",
121 | "name": "python",
122 | "nbconvert_exporter": "python",
123 | "pygments_lexer": "ipython3",
124 | "version": "3.8.3"
125 | }
126 | },
127 | "nbformat": 4,
128 | "nbformat_minor": 4
129 | }
130 |
--------------------------------------------------------------------------------
/examples/mnist/Evaluate the model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "! python3 -m pip install wandb -q"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import os\n",
19 | "import argparse\n",
20 | "from argparse import Namespace\n",
21 | "import time\n",
22 | "import numpy as np\n",
23 | "import torch\n",
24 | "import torch.nn as nn\n",
25 | "import torch.optim as optim\n",
26 | "\n",
27 | "from decimal import Decimal\n",
28 | "import wandb\n",
29 | "import sys\n",
30 | "\n",
31 | "sys.path.append('../../')\n",
32 | "\n",
33 | "from sopa.src.solvers.utils import create_solver\n",
34 | "from sopa.src.models.utils import fix_seeds, RunningAverageMeter\n",
35 | "from sopa.src.models.odenet_mnist.layers import MetaNODE\n",
36 | "from sopa.src.models.odenet_mnist.utils import makedirs, learning_rate_with_decay\n",
37 | "from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator\n",
38 | "from MegaAdversarial.src.attacks import (\n",
39 | " Clean,\n",
40 | " PGD,\n",
41 | " FGSM\n",
42 | ")"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 3,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "parser = argparse.ArgumentParser()\n",
52 | "parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')\n",
53 | "parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])\n",
54 | "parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu')\n",
55 | "parser.add_argument('--in_channels', type=int, default=1)\n",
56 | "\n",
57 | "parser.add_argument('--solvers',\n",
58 | " type=lambda s: [tuple(map(lambda iparam: str(iparam[1]) if iparam[0] <= 1 else (\n",
59 | " int(iparam[1]) if iparam[0] == 2 else (\n",
60 | " float(iparam[1]) if iparam[0] == 3 else Decimal(iparam[1]))),\n",
61 | " enumerate(item.split(',')))) for item in s.strip().split(';')],\n",
62 | " default=None,\n",
63 | " help='Each solver is represented with (method,parameterization,n_steps,step_size,u0,v0) \\n' +\n",
64 | " 'If the solver has only one parameter u0, set v0 to -1; \\n' +\n",
65 | " 'n_steps and step_size are exclusive parameters, only one of them can be != -1, \\n'\n",
66 | " 'If n_steps = step_size = -1, automatic time grid_constructor is used \\n;'\n",
67 | " 'For example, --solvers rk4,uv,2,-1,0.3,0.6;rk3,uv,-1,0.1,0.4,0.6;rk2,u,4,-1,0.3,-1')\n",
68 | "\n",
69 | "parser.add_argument('--solver_mode', type=str, choices=['switch', 'ensemble', 'standalone'], default='standalone')\n",
70 | "parser.add_argument('--val_solver_modes',\n",
71 | " type=lambda s: s.strip().split(','),\n",
72 | " default=['standalone'],\n",
73 | " help='Solver modes to use for validation step')\n",
74 | "\n",
75 | "parser.add_argument('--switch_probs', type=lambda s: [float(item) for item in s.split(',')], default=None,\n",
76 | " help=\"--switch_probs 0.8,0.1,0.1\")\n",
77 | "parser.add_argument('--ensemble_weights', type=lambda s: [float(item) for item in s.split(',')], default=None,\n",
78 | " help=\"ensemble_weights 0.6,0.2,0.2\")\n",
79 | "parser.add_argument('--ensemble_prob', type=float, default=1.)\n",
80 | "\n",
81 | "parser.add_argument('--noise_type', type=str, choices=['cauchy', 'normal'], default=None)\n",
82 | "parser.add_argument('--noise_sigma', type=float, default=0.001)\n",
83 | "parser.add_argument('--noise_prob', type=float, default=0.)\n",
84 | "parser.add_argument('--minimize_rk2_error', type=eval, default=False, choices=[True, False])\n",
85 | "\n",
86 | "parser.add_argument('--nepochs_nn', type=int, default=50)\n",
87 | "parser.add_argument('--nepochs_solver', type=int, default=0)\n",
88 | "parser.add_argument('--nstages', type=int, default=1)\n",
89 | "\n",
90 | "parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])\n",
91 | "parser.add_argument('--lr', type=float, default=0.01)\n",
92 | "parser.add_argument('--weight_decay', type=float, default=0.0005)\n",
93 | "parser.add_argument('--batch_size', type=int, default=128)\n",
94 | "parser.add_argument('--test_batch_size', type=int, default=1000)\n",
95 | "parser.add_argument('--base_lr', type=float, default=1e-5, help='base_lr for CyclicLR scheduler')\n",
96 | "parser.add_argument('--max_lr', type=float, default=1e-3, help='max_lr for CyclicLR scheduler')\n",
97 | "parser.add_argument('--step_size_up', type=int, default=2000, help='step_size_up for CyclicLR scheduler')\n",
98 | "parser.add_argument('--cyclic_lr_mode', type=str, default='triangular2', help='mode for CyclicLR scheduler')\n",
99 | "parser.add_argument('--lr_uv', type=float, default=1e-3)\n",
100 | "parser.add_argument('--torch_dtype', type=str, default='float32')\n",
101 | "parser.add_argument('--wandb_name', type=str, default='find_best_solver')\n",
102 | "\n",
103 | "parser.add_argument('--data_root', type=str, default='./')\n",
104 | "parser.add_argument('--save_dir', type=str, default='./')\n",
105 | "parser.add_argument('--debug', action='store_true')\n",
106 | "parser.add_argument('--gpu', type=int, default=0)\n",
107 | "\n",
108 | "parser.add_argument('--seed', type=int, default=502)\n",
109 | "# Noise and adversarial attacks parameters:\n",
110 | "parser.add_argument('--data_noise_std', type=float, default=0.,\n",
111 | " help='Applies Norm(0, std) gaussian noise to each training batch')\n",
112 | "parser.add_argument('--eps_adv_training', type=float, default=0.3,\n",
113 | " help='Epsilon for adversarial training')\n",
114 | "parser.add_argument(\n",
115 | " \"--adv_training_mode\",\n",
116 | " default=\"clean\",\n",
117 | " choices=[\"clean\", \"fgsm\", \"at\"], # , \"at_ls\", \"av\", \"fs\", \"nce\", \"nce_moco\", \"moco\", \"av_extra\", \"meta\"],\n",
118 | " help='''Adverarial training method/mode, by default there is no adversarial training (clean).\n",
119 | " For further details see MegaAdversarial/train in this repository.\n",
120 | " '''\n",
121 | ")\n",
122 | "parser.add_argument('--use_wandb', type=eval, default=True, choices=[True, False])\n",
123 | "parser.add_argument('--ss_loss', type=eval, default=False, choices=[True, False])\n",
124 | "parser.add_argument('--ss_loss_reg', type=float, default=0.1)\n",
125 | "parser.add_argument('--timestamp', type=int, default=int(1e6 * time.time()))\n",
126 | "\n",
127 | "parser.add_argument('--eps_adv_testing', type=float, default=0.3,\n",
128 | " help='Epsilon for adversarial testing')\n",
129 | "parser.add_argument('--adv_testing_mode',\n",
130 | " default=\"clean\",\n",
131 | " choices=[\"clean\", \"fgsm\", \"at\"],\n",
132 | " help='''Adversarsarial testing mode''')\n",
133 | "\n",
134 | "args = parser.parse_args(['--solvers', 'rk4,u3,4,-1,0.3,-1', '--seed', '902', '--adv_testing_mode', 'at', \n",
135 | " '--max_lr', '0.001', '--base_lr', '1e-05'])"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 4,
141 | "metadata": {},
142 | "outputs": [
143 | {
144 | "name": "stderr",
145 | "output_type": "stream",
146 | "text": [
147 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtalgat\u001b[0m (use `wandb login --relogin` to force relogin)\n",
148 | "\u001b[34m\u001b[1mwandb\u001b[0m: wandb version 0.10.19 is available! To upgrade, please run:\n",
149 | "\u001b[34m\u001b[1mwandb\u001b[0m: $ pip install wandb --upgrade\n"
150 | ]
151 | },
152 | {
153 | "data": {
154 | "text/html": [
155 | "\n",
156 | " Tracking run with wandb version 0.10.7
\n",
157 | " Syncing run breezy-violet-42 to Weights & Biases (Documentation).
\n",
158 | " Project page: https://wandb.ai/sopa_node/find_best_solver
\n",
159 | " Run page: https://wandb.ai/sopa_node/find_best_solver/runs/16j2vx1d
\n",
160 | " Run data is saved locally in wandb/run-20210216_151114-16j2vx1d
\n",
161 | " "
162 | ],
163 | "text/plain": [
164 | ""
165 | ]
166 | },
167 | "metadata": {},
168 | "output_type": "display_data"
169 | }
170 | ],
171 | "source": [
172 | "makedirs(args.save_dir)\n",
173 | "if args.use_wandb:\n",
174 | " wandb.init(project=args.wandb_name, anonymous=\"allow\", entity=\"sopa_node\")\n",
175 | " wandb.config.update(args)\n",
176 | " wandb.config.update({'u': float(args.solvers[0][-2])})\n",
177 | " makedirs(wandb.config.save_dir)\n",
178 | " makedirs(os.path.join(wandb.config.save_dir, wandb.run.path))"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {},
184 | "source": [
185 | "### Load the model"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 5,
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "# Load a checkpoint\n",
195 | "checkpoint_name = './checkpoints/checkpoint_15444.pth'\n",
196 | "device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'\n",
197 | "model = torch.load(checkpoint_name, map_location=device)"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": 6,
203 | "metadata": {},
204 | "outputs": [
205 | {
206 | "data": {
207 | "text/plain": [
208 | "MetaNODE(\n",
209 | " (downsampling_layers): Sequential(\n",
210 | " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))\n",
211 | " (1): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
212 | " (2): ReLU(inplace=True)\n",
213 | " (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
214 | " (4): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
215 | " (5): ReLU(inplace=True)\n",
216 | " (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
217 | " )\n",
218 | " (fc_layers): Sequential(\n",
219 | " (0): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
220 | " (1): ReLU(inplace=True)\n",
221 | " (2): AdaptiveAvgPool2d(output_size=(1, 1))\n",
222 | " (3): Flatten()\n",
223 | " (4): Linear(in_features=64, out_features=10, bias=True)\n",
224 | " )\n",
225 | " (blocks): ModuleList(\n",
226 | " (0): MetaODEBlock(\n",
227 | " (rhs_func): ODEfunc(\n",
228 | " (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
229 | " (relu): ReLU(inplace=True)\n",
230 | " (conv1): ConcatConv2d(\n",
231 | " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
232 | " )\n",
233 | " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
234 | " (conv2): ConcatConv2d(\n",
235 | " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
236 | " )\n",
237 | " (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)\n",
238 | " )\n",
239 | " )\n",
240 | " )\n",
241 | ")"
242 | ]
243 | },
244 | "execution_count": 6,
245 | "metadata": {},
246 | "output_type": "execute_result"
247 | }
248 | ],
249 | "source": [
250 | "model"
251 | ]
252 | },
253 | {
254 | "cell_type": "markdown",
255 | "metadata": {},
256 | "source": [
257 | "### Build a data loader"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 7,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "train_loader, test_loader, train_eval_loader = get_mnist_loaders(args.data_aug,\n",
267 | " args.batch_size,\n",
268 | " args.test_batch_size,\n",
269 | " data_root=args.data_root)\n",
270 | "data_gen = inf_generator(train_loader)\n",
271 | "batches_per_epoch = len(train_loader)"
272 | ]
273 | },
274 | {
275 | "cell_type": "markdown",
276 | "metadata": {},
277 | "source": [
278 | "### Evaluate the model"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": 8,
284 | "metadata": {},
285 | "outputs": [],
286 | "source": [
287 | "if args.torch_dtype == 'float64':\n",
288 | " dtype = torch.float64\n",
289 | "elif args.torch_dtype == 'float32':\n",
290 | " dtype = torch.float32\n",
291 | " \n",
292 | "solvers = [create_solver(*solver_params, dtype=dtype, device=device) for solver_params in args.solvers]\n",
293 | "for solver in solvers:\n",
294 | " solver.freeze_params()\n",
295 | " \n",
296 | "solver_options = Namespace(**{key: vars(args)[key] for key in ['solver_mode', 'switch_probs',\n",
297 | " 'ensemble_prob', 'ensemble_weights']})"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {
304 | "scrolled": true
305 | },
306 | "outputs": [],
307 | "source": [
308 | "def one_hot(x, K):\n",
309 | " return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)\n",
310 | "\n",
311 | "def accuracy(model, dataset_loader, device, solvers=None, solver_options=None):\n",
312 | " model.eval()\n",
313 | " total_correct = 0\n",
314 | " for x, y in dataset_loader:\n",
315 | " x = x.to(device)\n",
316 | " y = one_hot(np.array(y.numpy()), 10)\n",
317 | " target_class = np.argmax(y, axis=1)\n",
318 | " with torch.no_grad():\n",
319 | " if solver is not None:\n",
320 | " out = model(x, solvers, solver_options).cpu().detach().numpy()\n",
321 | " else:\n",
322 | " out = model(x).cpu().detach().numpy()\n",
323 | " predicted_class = np.argmax(out, axis=1)\n",
324 | " total_correct += np.sum(predicted_class == target_class)\n",
325 | " return total_correct / len(dataset_loader.dataset)\n",
326 | "\n",
327 | "\n",
328 | "def adversarial_accuracy(model, dataset_loader, device, solvers=None, solver_options=None, args=None):\n",
329 | " model.eval()\n",
330 | " total_correct = 0\n",
331 | " if args.adv_testing_mode == \"clean\":\n",
332 | " test_attack = Clean(model)\n",
333 | " elif args.adv_testing_mode == \"fgsm\":\n",
334 | " test_attack = FGSM(model, mean=[0.], std=[1.], **CONFIG_FGSM_TEST)\n",
335 | " elif args.adv_testing_mode == \"at\":\n",
336 | " test_attack = PGD(model, mean=[0.], std=[1.], **CONFIG_PGD_TEST)\n",
337 | " else:\n",
338 | " raise ValueError(\"Attack type not understood.\")\n",
339 | " for x, y in dataset_loader:\n",
340 | " x, y = x.to(device), y.to(device)\n",
341 | " x, y = test_attack(x, y, {\"solvers\": solvers, \"solver_options\": solver_options})\n",
342 | " y = one_hot(np.array(y.cpu().numpy()), 10)\n",
343 | " target_class = np.argmax(y, axis=1)\n",
344 | " with torch.no_grad():\n",
345 | " if solver is not None:\n",
346 | " out = model(x, solvers, solver_options).cpu().detach().numpy()\n",
347 | " else:\n",
348 | " out = model(x).cpu().detach().numpy()\n",
349 | " predicted_class = np.argmax(out, axis=1)\n",
350 | " total_correct += np.sum(predicted_class == target_class)\n",
351 | " return total_correct / len(dataset_loader.dataset)"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": null,
357 | "metadata": {},
358 | "outputs": [],
359 | "source": [
360 | "accuracy_test = accuracy(model, test_loader, device=device,\n",
361 | " solvers=solvers, solver_options=solver_options)\n",
362 | "accuracy_test"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": null,
368 | "metadata": {},
369 | "outputs": [],
370 | "source": [
371 | "CONFIG_PGD_TEST = {\"eps\": 0.3, \"lr\": 2.0 / 255, \"n_iter\": 7}\n",
372 | "adv_accuracy_test = adversarial_accuracy(model, test_loader, device,\n",
373 | " solvers=solvers, solver_options=solver_options, args=args\n",
374 | " )\n",
375 | "adv_accuracy_test"
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "execution_count": null,
381 | "metadata": {},
382 | "outputs": [],
383 | "source": []
384 | }
385 | ],
386 | "metadata": {
387 | "kernelspec": {
388 | "display_name": "Python 3",
389 | "language": "python",
390 | "name": "python3"
391 | },
392 | "language_info": {
393 | "codemirror_mode": {
394 | "name": "ipython",
395 | "version": 3
396 | },
397 | "file_extension": ".py",
398 | "mimetype": "text/x-python",
399 | "name": "python",
400 | "nbconvert_exporter": "python",
401 | "pygments_lexer": "ipython3",
402 | "version": "3.8.3"
403 | }
404 | },
405 | "nbformat": 4,
406 | "nbformat_minor": 4
407 | }
408 |
--------------------------------------------------------------------------------
/examples/mnist/assets/mnist_adv.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/mnist/assets/mnist_adv.pdf
--------------------------------------------------------------------------------
/examples/mnist/assets/mnist_adv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/mnist/assets/mnist_adv.png
--------------------------------------------------------------------------------
/examples/mnist/checkpoints/checkpoint_15444.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/examples/mnist/checkpoints/checkpoint_15444.pth
--------------------------------------------------------------------------------
/examples/mnist/train_and_attack.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from argparse import Namespace
4 | import time
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import copy
10 |
11 | from decimal import Decimal
12 | import wandb
13 | import sys
14 |
15 | sys.path.append('../../')
16 |
17 | from sopa.src.solvers.utils import create_solver
18 | from sopa.src.models.utils import fix_seeds, RunningAverageMeter
19 | from sopa.src.models.odenet_mnist.layers import MetaNODE
20 | from sopa.src.models.odenet_mnist.utils import makedirs, learning_rate_with_decay
21 | from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator
22 | from MegaAdversarial.src.attacks import (
23 | Clean,
24 | PGD,
25 | FGSM
26 | )
27 |
def _parse_solvers(spec):
    """Parse a ``--solvers`` value into a list of solver-parameter tuples.

    Solvers are separated by ';' and the fields of one solver by ','.
    Fields are converted by position: 0-1 -> str (method, parameterization),
    2 -> int (n_steps), 3 -> float (step_size), 4 and later -> Decimal (u0, v0).
    """
    parsed = []
    for item in spec.strip().split(';'):
        fields = []
        for position, value in enumerate(item.split(',')):
            if position <= 1:
                fields.append(str(value))
            elif position == 2:
                fields.append(int(value))
            elif position == 3:
                fields.append(float(value))
            else:
                fields.append(Decimal(value))
        parsed.append(tuple(fields))
    return parsed


parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu')
parser.add_argument('--in_channels', type=int, default=1)

parser.add_argument('--solvers',
                    type=_parse_solvers,
                    default=None,
                    help='Each solver is represented with (method,parameterization,n_steps,step_size,u0,v0) \n'
                         'If the solver has only one parameter u0, set v0 to -1; \n'
                         'n_steps and step_size are exclusive parameters, only one of them can be != -1, \n'
                         'If n_steps = step_size = -1, automatic time grid_constructor is used; \n'
                         'For example, --solvers rk4,uv,2,-1,0.3,0.6;rk3,uv,-1,0.1,0.4,0.6;rk2,u,4,-1,0.3,-1')

parser.add_argument('--solver_mode', type=str, choices=['switch', 'ensemble', 'standalone'], default='standalone')
parser.add_argument('--val_solver_modes',
                    type=lambda s: s.strip().split(','),
                    default=['standalone'],
                    help='Solver modes to use for validation step')

parser.add_argument('--switch_probs', type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="--switch_probs 0.8,0.1,0.1")
parser.add_argument('--ensemble_weights', type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="ensemble_weights 0.6,0.2,0.2")
parser.add_argument('--ensemble_prob', type=float, default=1.)

# NOTE(review): the boolean flags below use ``type=eval``, which evaluates arbitrary
# command-line text. Acceptable for a local research script, unsafe for untrusted input.
parser.add_argument('--noise_type', type=str, choices=['cauchy', 'normal'], default=None)
parser.add_argument('--noise_sigma', type=float, default=0.001)
parser.add_argument('--noise_prob', type=float, default=0.)
parser.add_argument('--minimize_rk2_error', type=eval, default=False, choices=[True, False])

parser.add_argument('--nepochs_nn', type=int, default=50)
parser.add_argument('--nepochs_solver', type=int, default=0)
parser.add_argument('--nstages', type=int, default=1)

parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--base_lr', type=float, default=1e-5, help='base_lr for CyclicLR scheduler')
parser.add_argument('--max_lr', type=float, default=1e-3, help='max_lr for CyclicLR scheduler')
parser.add_argument('--step_size_up', type=int, default=2000, help='step_size_up for CyclicLR scheduler')
parser.add_argument('--cyclic_lr_mode', type=str, default='triangular2', help='mode for CyclicLR scheduler')
parser.add_argument('--lr_uv', type=float, default=1e-3)
parser.add_argument('--torch_dtype', type=str, default='float32')
parser.add_argument('--wandb_name', type=str, default='find_best_solver')

parser.add_argument('--data_root', type=str, default='./')
parser.add_argument('--save_dir', type=str, default='./')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)

parser.add_argument('--seed', type=int, default=502)
# Noise and adversarial attacks parameters:
parser.add_argument('--data_noise_std', type=float, default=0.,
                    help='Applies Norm(0, std) gaussian noise to each training batch')
parser.add_argument('--eps_adv_training', type=float, default=0.3,
                    help='Epsilon for adversarial training')
parser.add_argument(
    "--adv_training_mode",
    default="clean",
    choices=["clean", "fgsm", "at"],  # , "at_ls", "av", "fs", "nce", "nce_moco", "moco", "av_extra", "meta"],
    help='''Adversarial training method/mode, by default there is no adversarial training (clean).
    For further details see MegaAdversarial/train in this repository.
    '''
)
parser.add_argument('--use_wandb', type=eval, default=True, choices=[True, False])
parser.add_argument('--ss_loss', type=eval, default=False, choices=[True, False])
parser.add_argument('--ss_loss_reg', type=float, default=0.1)
parser.add_argument('--timestamp', type=int, default=int(1e6 * time.time()))

parser.add_argument('--eps_adv_testing', type=float, default=0.3,
                    help='Epsilon for adversarial testing')
parser.add_argument('--adv_testing_mode',
                    default="clean",
                    choices=["clean", "fgsm", "at"],
                    help='''Adversarial testing mode''')

args = parser.parse_args()
if args.solvers is None:
    # The rest of the script indexes and iterates args.solvers, so fail early with a
    # usage message instead of a TypeError further down.
    parser.error('--solvers is required, e.g. --solvers rk4,u3,4,-1,0.3,-1')
112 |
makedirs(args.save_dir)
if args.use_wandb:
    wandb.init(project=args.wandb_name, entity="sopa_node", anonymous="allow")
    wandb.config.update(args)
    # A dirty way to extract u: the second-to-last field of the first solver spec
    # (method,parameterization,n_steps,step_size,u0,v0) is u0.
    wandb.config.update({'u': float(args.solvers[0][-2])})
    # Checkpoints are saved locally under <save_dir>/<wandb.run.path>.
    makedirs(wandb.config.save_dir)
    makedirs(os.path.join(wandb.config.save_dir, wandb.run.path))

if args.torch_dtype == 'float64':
    dtype = torch.float64
elif args.torch_dtype == 'float32':
    dtype = torch.float32
else:
    raise ValueError('torch_dtype should be either float64 or float32')

# I've decided to copy and modify functions from src/models/odenet_mnist/train_validate.py

# Default attack hyper-parameters; the keys are passed verbatim as keyword arguments
# to the MegaAdversarial PGD / FGSM constructors.
CONFIG_PGD_TRAIN = {"eps": 0.3, "lr": 2.0 / 255, "n_iter": 7}
CONFIG_FGSM_TRAIN = {"alpha": 0.3, "epsilon": 0.05}
135 |
136 |
def one_hot(x, K):
    """Encode the 1-D integer label array ``x`` as a ``(len(x), K)`` matrix of 0/1 ints.

    Labels outside ``range(K)`` produce an all-zero row.
    """
    class_ids = np.arange(K)[None, :]
    matches = x[:, None] == class_ids
    return matches.astype(int)
139 |
140 |
def accuracy(model, dataset_loader, device, solvers=None, solver_options=None):
    """Return the fraction of samples in ``dataset_loader`` that ``model`` classifies correctly.

    Args:
        model: classifier; it is switched to eval mode. When ``solvers`` is given it is
            called as ``model(x, solvers, solver_options)``, otherwise as ``model(x)``.
        dataset_loader: iterable of ``(x, y)`` batches that also exposes ``.dataset``;
            ``len(dataset_loader.dataset)`` is used as the number of samples.
        device: device the inputs are moved to before the forward pass.
        solvers: ODE solvers forwarded to the model, or None.
        solver_options: options forwarded to the model together with ``solvers``.

    Returns:
        float: number of correct arg-max predictions divided by ``len(dataset_loader.dataset)``.
    """
    model.eval()
    total_correct = 0
    with torch.no_grad():
        for x, y in dataset_loader:
            x = x.to(device)
            # Branch on the ``solvers`` argument itself so that callers passing None
            # really get a plain ``model(x)`` forward pass.
            if solvers is not None:
                logits = model(x, solvers, solver_options)
            else:
                logits = model(x)
            predicted_class = logits.argmax(dim=1).cpu()
            # Integer labels are compared directly with the predicted class indices.
            total_correct += (predicted_class == y.cpu()).sum().item()
    return total_correct / len(dataset_loader.dataset)
156 |
157 |
def adversarial_accuracy(model, dataset_loader, device, solvers=None, solver_options=None, args=None):
    """Return the accuracy of ``model`` on attacked batches from ``dataset_loader``.

    The attack is selected by ``args.adv_testing_mode``:
    "clean" -> ``Clean`` (no perturbation),
    "fgsm"  -> ``FGSM`` with ``CONFIG_FGSM_TRAIN``,
    "at"    -> ``PGD`` with ``CONFIG_PGD_TRAIN`` whose ``eps`` is ``args.eps_adv_testing``.
    When ``args`` is None the clean mode is used.

    Args:
        model: classifier; switched to eval mode. Called as
            ``model(x, solvers, solver_options)`` when ``solvers`` is given, else ``model(x)``.
        dataset_loader: iterable of ``(x, y)`` batches exposing ``.dataset``.
        device: device inputs and labels are moved to.
        solvers: ODE solvers forwarded to the model and the attack, or None.
        solver_options: solver options forwarded together with ``solvers``.
        args: parsed command-line namespace providing ``adv_testing_mode`` and
            ``eps_adv_testing``.

    Returns:
        float: correct predictions on attacked inputs divided by ``len(dataset_loader.dataset)``.

    Raises:
        ValueError: if ``args.adv_testing_mode`` is not one of the modes above.
    """
    model.eval()
    mode = "clean" if args is None else args.adv_testing_mode
    if mode == "clean":
        test_attack = Clean(model)
    elif mode == "fgsm":
        test_attack = FGSM(model, mean=[0.], std=[1.], **CONFIG_FGSM_TRAIN)
    elif mode == "at":
        # Use --eps_adv_testing as the PGD radius; its default (0.3) matches CONFIG_PGD_TRAIN.
        pgd_config = dict(CONFIG_PGD_TRAIN, eps=args.eps_adv_testing)
        test_attack = PGD(model, mean=[0.], std=[1.], **pgd_config)
    else:
        raise ValueError("Attack type not understood.")

    total_correct = 0
    for x, y in dataset_loader:
        x, y = x.to(device), y.to(device)
        # The attack runs outside no_grad: gradient-based attacks need input gradients.
        x, y = test_attack(x, y, {"solvers": solvers, "solver_options": solver_options})
        with torch.no_grad():
            if solvers is not None:
                logits = model(x, solvers, solver_options)
            else:
                logits = model(x)
        predicted_class = logits.argmax(dim=1).cpu()
        total_correct += (predicted_class == y.cpu()).sum().item()
    return total_correct / len(dataset_loader.dataset)
182 |
183 |
def train(model,
          data_gen,
          solvers,
          solver_options,
          criterion,
          optimizer,
          device,
          is_odenet=True,
          args=None):
    """Perform a single optimisation step on the next batch drawn from ``data_gen``.

    The batch is first replaced by its attacked version according to
    ``args.adv_training_mode`` ("clean", "fgsm" or "at" = PGD). It is then optionally
    perturbed with Gaussian noise of std ``args.data_noise_std`` and passed through
    the model. ``criterion(logits, y)`` is back-propagated (plus
    ``args.ss_loss_reg * model.get_ss_loss()`` when ``args.ss_loss`` is set), and
    ``optimizer`` takes one step.

    Args:
        model: network being trained; switched to train mode.
        data_gen: iterator yielding ``(x, y)`` batches.
        solvers: ODE solvers forwarded to the model and to the attack.
        solver_options: solver options forwarded together with ``solvers``.
        criterion: loss function applied to ``(logits, y)``.
        optimizer: optimizer that is zeroed and stepped once.
        device: device the batch is moved to.
        is_odenet: when True the model is called with solvers, otherwise as ``model(x)``.
        args: parsed command-line namespace (required in practice despite the None default).

    Returns:
        dict: ``{'xentropy': float}``, plus ``'ss_loss'`` when ``args.ss_loss`` is set.

    Raises:
        ValueError: if ``args.adv_training_mode`` is not a known mode.
    """
    model.train()
    optimizer.zero_grad()
    x, y = next(data_gen)
    x = x.to(device)
    y = y.to(device)

    # mean/std are passed explicitly, matching the attacks built in adversarial_accuracy,
    # so training-time and evaluation-time attacks treat inputs the same way.
    if args.adv_training_mode == "clean":
        train_attack = Clean(model)
    elif args.adv_training_mode == "fgsm":
        train_attack = FGSM(model, mean=[0.], std=[1.], **CONFIG_FGSM_TRAIN)
    elif args.adv_training_mode == "at":
        # Use --eps_adv_training as the PGD radius; its default (0.3) matches CONFIG_PGD_TRAIN.
        pgd_config = dict(CONFIG_PGD_TRAIN, eps=args.eps_adv_training)
        train_attack = PGD(model, mean=[0.], std=[1.], **pgd_config)
    else:
        raise ValueError("Attack type not understood.")
    x, y = train_attack(x, y, {"solvers": solvers, "solver_options": solver_options})

    # Add Gaussian noise (skipped when the requested std is effectively zero).
    if args.data_noise_std > 1e-12:
        with torch.no_grad():
            x = x + args.data_noise_std * torch.randn_like(x)

    ##### Forward pass
    if is_odenet:
        logits = model(x, solvers, solver_options, Namespace(ss_loss=args.ss_loss))
    else:
        logits = model(x)

    xentropy = criterion(logits, y)
    loss = xentropy
    if args.ss_loss:
        ss_loss = model.get_ss_loss()
        loss = xentropy + args.ss_loss_reg * ss_loss

    loss.backward()
    optimizer.step()
    if args.ss_loss:
        return {'xentropy': xentropy.item(), 'ss_loss': ss_loss.item()}
    return {'xentropy': xentropy.item()}
232 |
233 |
if __name__ == "__main__":
    print(f'CUDA is available: {torch.cuda.is_available()}')
    print(args.solvers)
    fix_seeds(args.seed)

    # ``dtype`` has already been resolved and validated from args.torch_dtype at module level.
    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    ########### Create train solvers
    train_solvers = [create_solver(*solver_params, dtype=dtype, device=device) for solver_params in args.solvers]
    for solver in train_solvers:
        solver.freeze_params()

    train_solver_options = Namespace(**{key: vars(args)[key] for key in ['solver_mode', 'switch_probs',
                                                                         'ensemble_prob', 'ensemble_weights']})

    ########## Build the model
    is_odenet = args.network == 'odenet'

    model = MetaNODE(downsampling_method=args.downsampling_method,
                     is_odenet=is_odenet,
                     activation_type=args.activation,
                     in_channels=args.in_channels)
    model.to(device)
    if args.use_wandb:
        wandb.watch(model)

    ########### Create data loaders
    train_loader, test_loader, train_eval_loader = get_mnist_loaders(args.data_aug,
                                                                     args.batch_size,
                                                                     args.test_batch_size,
                                                                     data_root=args.data_root)
    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    ########### Create criterion and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    # The learning rate follows a cyclic schedule (CyclicLR) instead of a fixed value.
    optimizer = optim.RMSprop([{"params": model.parameters(), 'lr': args.lr}, ], lr=args.lr,
                              weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=args.base_lr,
                                            max_lr=args.max_lr, step_size_up=args.step_size_up,
                                            mode=args.cyclic_lr_mode)

    # Checkpoints go to <save_dir>/<wandb.run.path> with wandb, and to save_dir without it,
    # so a run with --use_wandb False no longer touches the (uninitialised) wandb state.
    if args.use_wandb:
        checkpoint_dir = os.path.join(wandb.config.save_dir, wandb.run.path)
    else:
        checkpoint_dir = args.save_dir
    makedirs(checkpoint_dir)

    ########### Train the model
    for itr in range(args.nepochs_nn * batches_per_epoch):
        train_loss = train(model,
                           data_gen,
                           solvers=train_solvers,
                           solver_options=train_solver_options,
                           criterion=criterion,
                           optimizer=optimizer,
                           device=device,
                           is_odenet=is_odenet,
                           args=args)
        scheduler.step()

        # Evaluate, checkpoint and log once per epoch (starting at the first iteration).
        if itr % batches_per_epoch == 0:
            # NOTE(review): train accuracy is measured on train_loader; get_mnist_loaders also
            # returns train_eval_loader, which may be the intended evaluation loader -- confirm.
            train_acc = accuracy(model, train_loader, device, solvers=train_solvers, solver_options=train_solver_options)
            test_acc = accuracy(model, test_loader, device, solvers=train_solvers, solver_options=train_solver_options)
            adv_test_acc = adversarial_accuracy(model, test_loader, device, solvers=train_solvers,
                                                solver_options=train_solver_options, args=args)
            adv_train_acc = adversarial_accuracy(model, train_loader, device, solvers=train_solvers,
                                                 solver_options=train_solver_options, args=args)

            save_path = os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(itr))
            print(save_path)
            torch.save(model, save_path)

            metrics = {'train_acc': train_acc,
                       'test_acc': test_acc,
                       'adv_test_acc': adv_test_acc,
                       'adv_train_acc': adv_train_acc,
                       'train_loss': train_loss['xentropy']}
            if args.use_wandb:
                wandb.save(save_path)
                wandb.log(metrics)
            else:
                print(metrics)
--------------------------------------------------------------------------------
/sopa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/sopa/__init__.py
--------------------------------------------------------------------------------
/sopa/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/sopa/src/__init__.py
--------------------------------------------------------------------------------
/sopa/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/sopa/src/models/__init__.py
--------------------------------------------------------------------------------
/sopa/src/models/odenet_cifar10/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juliagusak/neural-ode-metasolver/a5ca6ae0c00d2a8da3a5f4b77ee20fb151674d22/sopa/src/models/odenet_cifar10/__init__.py
--------------------------------------------------------------------------------
/sopa/src/models/odenet_cifar10/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import torchvision.datasets as datasets
4 | import torchvision.transforms as transforms
5 | from torch.utils.data.sampler import SubsetRandomSampler
6 | import numpy as np
7 |
def get_cifar10_train_val_loaders(data_aug=False,
                                  batch_size=128,
                                  val_perc=0.1,
                                  data_root=None,
                                  num_workers=1,
                                  pin_memory=True,
                                  shuffle=True,
                                  random_seed=None,
                                  download=False):
    ''' Returns iterators through train/val CIFAR10 datasets.

    Both splits are carved out of the CIFAR10 *training* set: the first ``floor(val_perc * N)``
    (optionally shuffled) indices form the validation split and the remaining ones the train split.
    Both loaders use ``drop_last=True``, so a trailing incomplete batch is skipped.

    :param data_aug: bool
        Whether to add a random 32x32 crop (padding 4) to the train transforms. The train split is
        always randomly flipped horizontally; the validation split is never augmented.
    :param batch_size: int
        How many samples per batch to load.
    :param val_perc: float
        Fraction of the training set used for the validation split. Must be in the range [0, 1].
    :param data_root: str
        Path to the directory with the dataset.
    :param num_workers: int
        Number of subprocesses to use when loading the dataset.
    :param pin_memory: bool
        Whether to copy tensors into CUDA pinned memory. Useful when batches are moved to a GPU.
    :param shuffle: bool
        Whether to shuffle the indices before splitting them into train/validation parts.
    :param random_seed: int or None
        Seed for that shuffle. When given, NumPy's global RNG is seeded with it. When None, the
        shuffle draws from NumPy's global RNG in its current state (so a seed set by the caller,
        if any, is respected rather than overwritten).
    :param download: bool
        Whether to download CIFAR10 into ``data_root`` if it is missing.
    :return: tuple (train_loader, val_loader) of ``torch.utils.data.DataLoader``
    :raises ValueError: if ``val_perc`` is outside [0, 1].
    '''
    # Validate before touching the dataset so a bad split fraction fails fast.
    if not 0 <= val_perc <= 1:
        raise ValueError('val_perc must be in [0, 1], got {}'.format(val_perc))

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_steps = [transforms.RandomCrop(32, padding=4)] if data_aug else []
    train_steps += [transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize]
    transform_train = transforms.Compose(train_steps)
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    # Same underlying training images, different transforms for the two splits.
    train_dataset = datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform_train)
    val_dataset = datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform_test)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(val_perc * num_train))

    if shuffle:
        if random_seed is not None:
            # NOTE(review): this reseeds NumPy's *global* RNG, which also affects later np.random
            # calls made by the caller; kept because downstream code may rely on that determinism.
            np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, val_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers,
                              drop_last=True, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, num_workers=num_workers,
                            drop_last=True, pin_memory=pin_memory)

    return train_loader, val_loader
80 |
81 |
def get_cifar10_test_loader(batch_size=128, data_root=None, num_workers=1, pin_memory=True, shuffle=False,
                            download=False, drop_last=True):
    ''' Returns iterator through CIFAR10 test dataset

    :param batch_size: int
        How many samples per batch to load.
    :param data_root: str
        Path to the directory with the dataset.
    :param num_workers: int
        Number of subprocesses to use when loading the dataset.
    :param pin_memory: bool
        Whether to copy tensors into CUDA pinned memory. Useful when batches are moved to a GPU.
    :param shuffle: bool
        Whether to shuffle the dataset after every epoch.
    :param download: bool
        Whether to download CIFAR10 into ``data_root`` if it is missing.
    :param drop_last: bool
        Whether to skip the final incomplete batch. Defaults to True for backward compatibility.
        With the default, the last ``len(dataset) % batch_size`` test images are never evaluated.
        Pass False to evaluate on the full test split.
    :return: ``torch.utils.data.DataLoader`` over the CIFAR10 test split (no augmentation, only normalization)
    '''
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    test_dataset = datasets.CIFAR10(root=data_root, train=False, download=download, transform=transform_test)

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                             drop_last=drop_last, pin_memory=pin_memory)
    return test_loader
108 |
109 |
def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):

    Yields the items of ``iterable`` pass after pass, re-iterating it every time it is exhausted.

    :raises ValueError: if a complete pass over ``iterable`` yields nothing (an empty iterable, or a
        one-shot iterator that is already exhausted), which would otherwise spin forever without
        ever yielding.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            raise ValueError('inf_generator: iterable yielded no items, cannot cycle over it')
120 |
121 |
--------------------------------------------------------------------------------
/sopa/src/models/odenet_cifar10/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import math
6 |
7 | import numpy as np
8 | from copy import deepcopy
9 |
10 | __all__ = ['MetaNODE', 'metanode4', 'metanode6', 'metanode10', 'metanode18', 'metanode34',
11 | 'premetanode4', 'premetanode6', 'premetanode10', 'premetanode18', 'premetanode34']
12 |
class Flatten(nn.Module):
    '''Collapse every dimension except the first: (N, d1, d2, ...) -> (N, d1 * d2 * ...).

    A 1-D input of shape (N,) becomes (N, 1). Implemented with ``view``, so the input must be
    view-compatible (e.g. contiguous).
    '''
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Multiply the trailing sizes as Python ints. The product of an empty shape is 1, which
        # torch.prod(torch.tensor(())) would return as the float 1.0 and view() would reject.
        flat_size = 1
        for dim_size in x.shape[1:]:
            flat_size *= dim_size
        return x.view(-1, flat_size)
20 |
21 |
class BasicBlock(nn.Module):
    '''Standard (post-activation) ResNet block.

    out = act(norm(conv3x3(act(norm(conv3x3(x))))) + shortcut(x))

    :param in_planes: number of input channels
    :param planes: number of channels produced by both 3x3 convolutions
    :param stride: stride of the first convolution and of the projection shortcut
    :param norm_layer: callable taking a channel count and returning a normalization module
    :param act_layer: activation callable applied to tensors
    :param param_norm: wrapper applied to every convolution (identity by default)
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1,
                 norm_layer=None, act_layer=None, param_norm=lambda x: x):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes

        # Submodules are created in a fixed order: conv1, bn1, conv2, bn2, shortcut.
        self.conv1 = param_norm(nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False))
        self.bn1 = norm_layer(planes)

        self.conv2 = param_norm(nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False))
        self.bn2 = norm_layer(planes)

        self.act = act_layer

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            # 1x1 convolution + normalization so the shortcut matches the main branch's shape.
            self.shortcut = nn.Sequential(
                param_norm(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)),
                norm_layer(out_planes)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = self.act(self.bn1(self.conv1(x)))
        merged = self.bn2(self.conv2(hidden))
        merged += self.shortcut(x)
        return self.act(merged)
52 |
53 |
class PreBasicBlock(nn.Module):
    '''Standard pre-activation ResNet block.

    out = conv3x3(act(norm(conv3x3(act(norm(x)))))) + shortcut(x)   (no activation after the sum)

    :param in_planes: number of input channels
    :param planes: number of channels produced by both 3x3 convolutions
    :param stride: stride of the first convolution and of the projection shortcut
    :param norm_layer: callable taking a channel count and returning a normalization module
    :param act_layer: activation callable applied to tensors
    :param param_norm: wrapper applied to every convolution (identity by default)
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1,
                 norm_layer=None, act_layer=None, param_norm=lambda x: x):
        super(PreBasicBlock, self).__init__()
        out_planes = self.expansion * planes

        # Each normalization precedes the convolution it feeds.
        self.bn1 = norm_layer(in_planes)
        self.conv1 = param_norm(nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False))

        self.bn2 = norm_layer(planes)
        self.conv2 = param_norm(nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False))

        self.act = act_layer

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            # Projection shortcut: a bare 1x1 convolution, without normalization.
            self.shortcut = nn.Sequential(
                param_norm(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)),
            )

    def forward(self, x):
        branch = self.conv1(self.act(self.bn1(x)))
        branch = self.conv2(self.act(self.bn2(branch)))
        branch += self.shortcut(x)
        return branch
82 |
83 |
class BasicBlock2(nn.Module):
    '''Odefunc (right-hand side f(t, x)) to use inside MetaODEBlock.

    Computes act(norm(conv3x3(act(norm(conv3x3(x)))))). Both convolutions keep ``dim`` channels
    and the spatial size (stride 1, padding 1). ``t`` is ignored. ``self.nfe`` counts forward
    calls. With ``ss_loss=True`` the absolute value of the output is returned.
    '''
    expansion = 1

    def __init__(self, dim,
                 norm_layer=None, act_layer=None, param_norm=lambda x: x):
        super(BasicBlock2, self).__init__()
        self.nfe = 0

        self.conv1 = param_norm(nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False))
        # The normalization type is configurable. The original authors note that BN was replaced
        # with GN because BN did not work well with this method.
        self.bn1 = norm_layer(dim)

        self.conv2 = param_norm(nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False))
        self.bn2 = norm_layer(dim)

        self.act = act_layer

        # Not used by forward(); kept so the module structure stays the same.
        self.shortcut = nn.Sequential()

    def forward(self, t, x, ss_loss = False):
        self.nfe += 1
        # The state may arrive wrapped in a tuple; only its first element is used.
        state = x[0] if isinstance(x, tuple) else x
        hidden = self.act(self.bn1(self.conv1(state)))
        result = self.act(self.bn2(self.conv2(hidden)))
        return torch.abs(result) if ss_loss else result
122 |
123 |
class PreBasicBlock2(nn.Module):
    '''Pre-activation odefunc (right-hand side f(t, x)) to use inside MetaODEBlock.

    Computes conv3x3(act(norm(conv3x3(act(norm(x)))))). Both convolutions keep ``dim`` channels
    and the spatial size. ``t`` is ignored. ``self.nfe`` counts forward calls. With
    ``ss_loss=True`` the absolute value of the output is returned.
    '''
    expansion = 1

    def __init__(self, dim,
                 norm_layer=None, act_layer=None, param_norm=lambda x: x):
        super(PreBasicBlock2, self).__init__()
        self.nfe = 0

        # The normalization type is configurable. The original authors note that BN was replaced
        # with GN because BN did not work well with this method.
        self.bn1 = norm_layer(dim)
        self.conv1 = param_norm(nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False))

        self.bn2 = norm_layer(dim)
        self.conv2 = param_norm(nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False))

        self.act = act_layer

        # Not used by forward(); kept so the module structure stays the same.
        self.shortcut = nn.Sequential()

    def forward(self, t, x, ss_loss=False):
        self.nfe += 1
        # The state may arrive wrapped in a tuple; only its first element is used.
        state = x[0] if isinstance(x, tuple) else x
        result = self.conv1(self.act(self.bn1(state)))
        result = self.conv2(self.act(self.bn2(result)))
        return torch.abs(result) if ss_loss else result
162 |
163 |
class MetaODEBlock(nn.Module):
    '''ODE block: integrates ``odefunc`` from t=0 to t=1 with the solver(s) chosen by ``solver_options``.

    The same as MetaODEBlock for MNIST. The only difference is that odefunc is passed as a keyword argument.

    ``solver_options.solver_mode`` selects how the ``solvers`` list is used:

    * ``'standalone'``: always ``solvers[0]``.
    * ``'switch'``: one solver is sampled per forward pass with ``np.random.choice``, using
      ``solver_options.switch_probs`` (uniform if None). The chosen index is stored in
      ``solver_options.switch_solver_id``.
    * ``'ensemble'``: a Bernoulli draw with probability ``solver_options.ensemble_prob`` is stored in
      ``solver_options.ensemble_coin_flip``. If it is 1, the trajectories of all solvers are summed
      with ``solver_options.ensemble_weights`` (uniform if None); otherwise ``solvers[0]`` is used.

    ``ss_loss`` reuses the switch/ensemble choice stored by the preceding ``forward`` call.
    '''
    def __init__(self, odefunc=None):
        super(MetaODEBlock, self).__init__()

        self.rhs_func = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def _integrate(self, func, x, t, solvers, solver_options, resample):
        '''Integrate ``func`` from ``x`` over the time points ``t`` according to ``solver_options``.

        :param resample: if True, draw the random switch/ensemble choice and store it on
            ``solver_options``; if False, reuse the stored choice.
        :return: whatever the selected solver's ``integrate`` returns (a weighted sum in ensemble mode)
        :raises ValueError: for an unknown ``solver_options.solver_mode``
        '''
        nsolvers = len(solvers)
        mode = solver_options.solver_mode

        if mode == 'standalone':
            return solvers[0].integrate(func, x=x, t=t)

        if mode == 'switch':
            if resample:
                switch_probs = solver_options.switch_probs
                if switch_probs is None:
                    switch_probs = [1. / nsolvers for _ in range(nsolvers)]
                solver_options.switch_solver_id = np.random.choice(range(nsolvers), p=switch_probs)
            return solvers[solver_options.switch_solver_id].integrate(func, x=x, t=t)

        if mode == 'ensemble':
            if resample:
                solver_options.ensemble_coin_flip = torch.bernoulli(torch.tensor((1,)),
                                                                    solver_options.ensemble_prob)
            if not solver_options.ensemble_coin_flip:
                return solvers[0].integrate(func, x=x, t=t)

            ensemble_weights = solver_options.ensemble_weights
            if ensemble_weights is None:
                ensemble_weights = [1. / nsolvers for _ in range(nsolvers)]

            y = None
            for wi, solver in zip(ensemble_weights, solvers):
                term = wi * solver.integrate(func, x=x, t=t)
                if y is None:
                    y = term
                else:
                    y += term
            return y

        raise ValueError('Unknown solver_mode: {!r}'.format(mode))

    def forward(self, x, solvers, solver_options):
        y = self._integrate(self.rhs_func, x, self.integration_time, solvers, solver_options, resample=True)
        # Keep only the state at the final time point.
        return y[-1,:,:,:,:]

    def ss_loss(self, y, solvers, solver_options):
        '''Mean over the batch of the L2 norm of how far the ODE moves ``y`` when integrated further over t in [1, 2].

        NOTE(review): the dynamics integrated here are ``self.rhs_func`` called *without*
        ``ss_loss=True`` (the original code built ``partial(self.rhs_func, ss_loss=True)`` but passed
        its ``.func``, which is ``self.rhs_func`` itself), so the odefunc's ``torch.abs`` branch is not
        used. If |f| was intended, pass ``functools.partial(self.rhs_func, ss_loss=True)`` instead,
        provided the solvers accept arbitrary callables.
        '''
        z0 = y
        integration_time_ss = self.integration_time + 1

        z = self._integrate(self.rhs_func, y, integration_time_ss, solvers, solver_options, resample=False)

        z = z[-1,:,:,:,:] - z0
        z = torch.norm(z.reshape((z.shape[0], -1)), dim=1)
        z = torch.mean(z)

        return z
250 |
251 |
class MetaLayer(nn.Module):
    '''
    One stage of MetaNODE: ``num_blocks[0]`` residual blocks followed by ``num_blocks[1]`` ODE blocks.

    planes: number of output channels of the residual blocks
    num_blocks: (num_resblocks, num_odeblocks)
    stride: stride of the first residual block (later residual blocks use stride 1)
    norm_layers_: tuple of normalization layers for (BasicBlock, BasicBlock2)
    param_norm_layers_: tuple of normalizations for weights in (BasicBlock, BasicBlock2)
    act_layers_: tuple of activation layers for (BasicBlock, BasicBlock2)
    in_planes: number of input channels
    resblock: BasicBlock or PreBasicBlock
    odefunc: BasicBlock2 or PreBasicBlock2

    After construction ``self.in_planes`` holds the stage's number of output channels.
    NOTE(review): with num_resblocks == 0 neither ``stride`` nor ``planes`` is applied (ODE blocks
    keep the input shape); confirm such layouts are only used where that is intended.
    '''
    def __init__(self, planes, num_blocks, stride, norm_layers_, param_norm_layers_, act_layers_,
                 in_planes, resblock=None, odefunc=None):
        super(MetaLayer, self).__init__()

        num_resblocks, num_odeblocks = num_blocks

        layers_res = []
        layers_ode = []

        self.in_planes = in_planes
        # Only the first residual block downsamples.
        for block_idx in range(num_resblocks):
            block_stride = stride if block_idx == 0 else 1
            layers_res.append(resblock(self.in_planes, planes, block_stride,
                                       norm_layer=norm_layers_[0],
                                       param_norm=param_norm_layers_[0],
                                       act_layer=act_layers_[0]))
            self.in_planes = planes * resblock.expansion

        # ODE blocks preserve the channel count reached by the residual blocks.
        for _ in range(num_odeblocks):
            layers_ode.append(MetaODEBlock(odefunc(self.in_planes,
                                                   norm_layer=norm_layers_[1],
                                                   param_norm=param_norm_layers_[1],
                                                   act_layer=act_layers_[1])))

        self.blocks_res = nn.Sequential(*layers_res)
        self.blocks_ode = nn.ModuleList(layers_ode)

    def forward(self, x, solvers=None, solver_options=None, loss_options = None):
        x = self.blocks_res(x)

        # Accumulated per call; read afterwards via get_ss_loss() or the ss_loss attribute.
        self.ss_loss = 0
        compute_ss = (loss_options is not None) and loss_options.ss_loss

        for block in self.blocks_ode:
            x = block(x, solvers, solver_options)

            if compute_ss:
                self.ss_loss += block.ss_loss(x, solvers, solver_options)

        return x

    def get_ss_loss(self):
        return self.ss_loss

    @property
    def nfe(self):
        '''Total number of function evaluations recorded by the odefuncs of this stage's ODE blocks.'''
        # The counter lives on each block's rhs_func (the odefunc), not on the MetaODEBlock wrapper.
        return sum(block.rhs_func.nfe for block in self.blocks_ode)

    @nfe.setter
    def nfe(self, value):
        for block in self.blocks_ode:
            block.rhs_func.nfe = value
315 |
316 |
class MetaNODE(nn.Module):
    def __init__(self, num_blocks, num_classes=10,
                 norm_layers_=(None, None, None),
                 param_norm_layers_=(lambda x: x, lambda x: x, lambda x: x),
                 act_layers_=(None, None, None),
                 in_planes_=64,
                 resblock=None,
                 odefunc=None):
        '''
        ResNet-style classifier for 3-channel inputs whose stages mix residual and ODE blocks.

        num_blocks: sequence of (num_resblocks, num_odeblocks) pairs, one per stage. Stage k (1-based)
            has in_planes_ * 2**(k-1) planes; every stage after the first uses stride 2.
        num_classes: number of outputs of the final linear layer
        norm_layers_: tuple of normalization layers for (BasicBlock, BasicBlock2, bn1)
        param_norm_layers_: tuple of normalizations for weights in (BasicBlock, BasicBlock2, conv1)
        act_layers_: tuple of activation layers for (BasicBlock, BasicBlock2, activation after bn1)
        in_planes_: number of channels produced by the stem convolution conv1
        resblock: BasicBlock or PreBasicBlock (the class)
        odefunc: BasicBlock2 or PreBasicBlock2 (the class)

        If resblock subclasses PreBasicBlock or odefunc subclasses PreBasicBlock2, the network is
        pre-activation. In that case bn1/act are applied to the output of the last stage, and bn1 is
        sized for that stage, instead of directly after conv1.
        NOTE: state_dicts of pre-activation models saved while this detection was broken (it used
        isinstance on the classes, so it was always False) do not match this layout.

        :raises ValueError: if num_blocks is empty
        '''

        super(MetaNODE, self).__init__()
        if len(num_blocks) == 0:
            raise ValueError('num_blocks must contain at least one (num_resblocks, num_odeblocks) pair')

        self.in_planes = in_planes_
        self.n_layers = len(num_blocks)
        # Planes double at every stage after the first.
        self.n_features_linear = in_planes_ * 2 ** (self.n_layers - 1)

        def _derives_from(candidate, base):
            # resblock / odefunc are classes (MetaLayer instantiates them), hence issubclass, not isinstance.
            return isinstance(candidate, type) and issubclass(candidate, base)

        self.is_preactivation = _derives_from(resblock, PreBasicBlock) or _derives_from(odefunc, PreBasicBlock2)

        self.conv1 = param_norm_layers_[2](nn.Conv2d(3, in_planes_, kernel_size=3, stride=1, padding=1, bias=False))
        # Post-activation: bn1 normalizes conv1's output (in_planes_ channels).
        # Pre-activation: bn1 normalizes the last stage's output.
        if self.is_preactivation:
            bn1_planes = self.n_features_linear * resblock.expansion
        else:
            bn1_planes = in_planes_
        self.bn1 = norm_layers_[2](bn1_planes)
        self.act = act_layers_[2]

        stage_in_planes = self.in_planes
        for idx, stage_blocks in enumerate(num_blocks):
            layer = MetaLayer(in_planes_ * 2 ** idx, stage_blocks, stride=1 if idx == 0 else 2,
                              norm_layers_=norm_layers_[:2],
                              param_norm_layers_=param_norm_layers_[:2],
                              act_layers_=act_layers_[:2],
                              in_planes=stage_in_planes,
                              resblock=resblock,
                              odefunc=odefunc
                              )
            # Registered as layer1, layer2, ... (names are part of the state_dict keys).
            setattr(self, 'layer{}'.format(idx + 1), layer)
            stage_in_planes = layer.in_planes

        self.fc_layers = nn.Sequential(*[nn.AdaptiveAvgPool2d((1, 1)),
                                         Flatten(),
                                         nn.Linear(self.n_features_linear * resblock.expansion, num_classes)])
        self.nfe = 0

    def _stages(self):
        '''Yield layer1 .. layer{n_layers} in order.'''
        for idx in range(1, self.n_layers + 1):
            yield getattr(self, 'layer{}'.format(idx))

    @property
    def nfe(self):
        return sum(layer.nfe for layer in self._stages())

    @nfe.setter
    def nfe(self, value):
        for layer in self._stages():
            layer.nfe = value

    def forward(self, x, solvers=None, solver_options=None, loss_options = None):
        self.ss_loss = 0

        out = self.conv1(x)
        if not self.is_preactivation:
            out = self.act(self.bn1(out))

        for layer in self._stages():
            out = layer(out,
                        solvers=solvers, solver_options=solver_options,
                        loss_options=loss_options)

            self.ss_loss += layer.ss_loss

        if self.is_preactivation:
            out = self.act(self.bn1(out))

        out = self.fc_layers(out)
        return out
427 |
428 |
# Stage layouts keyed by the number in the factory name: (ODE variant, plain ResNet variant).
# Each layout lists one (num_resblocks, num_odeblocks) pair per stage.
_METANODE_LAYOUTS = {
    4: ([(0, 1)], [(1, 0)]),
    6: ([(1, 1)], [(2, 0)]),
    10: ([(1, 1), (1, 1)], [(2, 0), (2, 0)]),
    18: ([(1, 1), (1, 1), (1, 1), (1, 1)], [(2, 0), (2, 0), (2, 0), (2, 0)]),
    34: ([(1, 2), (1, 3), (1, 5), (1, 2)], [(3, 0), (4, 0), (6, 0), (3, 0)]),
}


def _make_metanode(variant, resblock, odefunc, norm_layers, param_norm_layers, act_layers, in_planes, is_odenet):
    '''Build a MetaNODE with the stage layout registered under ``variant``.

    The ODE layout is used when ``is_odenet`` is truthy, the plain residual layout otherwise.
    '''
    ode_layout, res_layout = _METANODE_LAYOUTS[variant]
    # Hand MetaNODE a fresh list on every call.
    num_blocks = list(ode_layout) if is_odenet else list(res_layout)
    return MetaNODE(num_blocks,
                    norm_layers_=norm_layers, param_norm_layers_=param_norm_layers, act_layers_=act_layers,
                    in_planes_=in_planes,
                    resblock=resblock,
                    odefunc=odefunc)


def metanode4(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    '''1 stage: one ODE block (is_odenet) or one residual block; BasicBlock / BasicBlock2.'''
    return _make_metanode(4, BasicBlock, BasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)


def metanode6(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    '''1 stage: residual + ODE block (is_odenet) or two residual blocks; BasicBlock / BasicBlock2.'''
    return _make_metanode(6, BasicBlock, BasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)


def metanode10(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    '''2 stages of residual + ODE block (is_odenet) or two residual blocks; BasicBlock / BasicBlock2.'''
    return _make_metanode(10, BasicBlock, BasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)


def metanode18(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    '''4 stages of residual + ODE block (is_odenet) or two residual blocks; BasicBlock / BasicBlock2.'''
    return _make_metanode(18, BasicBlock, BasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)


def metanode34(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet = True):
    '''4 stages: 1 residual + (2, 3, 5, 2) ODE blocks (is_odenet) or (3, 4, 6, 3) residual blocks.'''
    return _make_metanode(34, BasicBlock, BasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)


def premetanode4(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    '''Pre-activation metanode4 (PreBasicBlock / PreBasicBlock2).'''
    return _make_metanode(4, PreBasicBlock, PreBasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)


def premetanode6(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    '''Pre-activation metanode6 (PreBasicBlock / PreBasicBlock2).'''
    return _make_metanode(6, PreBasicBlock, PreBasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)


def premetanode10(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    '''Pre-activation metanode10 (PreBasicBlock / PreBasicBlock2).'''
    return _make_metanode(10, PreBasicBlock, PreBasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)


def premetanode18(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    '''Pre-activation metanode18 (PreBasicBlock / PreBasicBlock2).'''
    return _make_metanode(18, PreBasicBlock, PreBasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)


def premetanode34(norm_layers, param_norm_layers, act_layers, in_planes, is_odenet=True):
    '''Pre-activation metanode34 (PreBasicBlock / PreBasicBlock2).'''
    return _make_metanode(34, PreBasicBlock, PreBasicBlock2,
                          norm_layers, param_norm_layers, act_layers, in_planes, is_odenet)
557 |
--------------------------------------------------------------------------------
/sopa/src/models/odenet_cifar10/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch.nn.init as init
4 | from torch.nn.utils import spectral_norm, weight_norm
5 | from functools import partial
6 |
7 |
class Identity(nn.Module):
    '''No-op module that returns its input unchanged.

    Any constructor arguments are accepted and ignored, so the class can stand in for a
    normalization layer class, e.g. ``Identity(num_channels)``.
    '''
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, x):
        passthrough = x
        return passthrough
14 |
def get_normalization(key, num_groups=32):
    '''Return a normalization-layer constructor for the given short name.

    The returned callable takes a channel count and builds the layer.

    :param key: str
        One of 'BN' (BatchNorm2d), 'LN' (GroupNorm with a single group, used as layer norm),
        'GN' (GroupNorm with ``num_groups`` groups), 'IN' (InstanceNorm2d) or
        'NF' (no normalization, the Identity module).
    :param num_groups: int
        Number of groups, used only for 'GN'.
    :return: callable producing an nn.Module
    :raises NameError: for an unknown key
    '''
    # Factories are evaluated lazily, so only the selected entry is built.
    candidates = (
        ('BN', lambda: nn.BatchNorm2d),
        ('LN', lambda: partial(nn.GroupNorm, 1)),
        ('GN', lambda: partial(nn.GroupNorm, num_groups)),
        ('IN', lambda: nn.InstanceNorm2d),
        ('NF', lambda: Identity),
    )
    for name, make in candidates:
        if key == name:
            return make()
    raise NameError('Unknown layer normalization type')
39 |
def get_param_normalization(key):
    '''Return a function that wraps a layer with a weight (parameter) normalization.

    :param key: str
        'SN' for spectral normalization, 'WN' for weight normalization, or 'PNF' for none
        (the layer is returned unchanged).
    :return: function taking a module and returning the (possibly wrapped) module
    :raises NameError: for an unknown key
    '''
    if key == 'SN':
        return spectral_norm
    if key == 'WN':
        return weight_norm
    if key == 'PNF':
        return lambda x: x
    raise NameError('Unknown param normalization type')
56 |
def get_activation(key):
    '''Return an activation function for the given short name.

    :param key: str
        One of 'ReLU', 'GeLU', 'Softsign', 'Tanh' or 'AF'. 'AF' means no activation: leaky ReLU
        with negative slope 1, i.e. the identity map.
    :return: function applied to tensors
    :raises NameError: for an unknown key
    '''
    table = (
        ('ReLU', lambda: F.relu),
        ('GeLU', lambda: F.gelu),
        ('Softsign', lambda: F.softsign),
        ('Tanh', lambda: F.tanh),
        ('AF', lambda: partial(F.leaky_relu, negative_slope=1)),
    )
    for name, pick in table:
        if key == name:
            return pick()
    raise NameError('Unknown activation type')
77 |
def conv_init(m):
    '''Initialise a module in place; intended for ``model.apply(conv_init)``.

    * Modules whose class name contains 'Conv' and that have a bias: Xavier-uniform weights
      (gain sqrt(2)) and zero bias.
    * Modules whose class name contains 'BatchNorm': weight 1 and bias 0, when those parameters
      exist (they are None for ``affine=False``).
    * Everything else is left untouched.
    '''
    # Local import: this module does not import numpy or math at file level.
    import math

    class_name = m.__class__.__name__
    if 'Conv' in class_name and m.bias is not None:
        # NOTE(review): convolutions built with bias=False (as everywhere in layers.py) fail this
        # check and keep PyTorch's default weight initialisation; confirm whether the weight
        # should be initialised regardless of the bias.
        init.xavier_uniform_(m.weight, gain=math.sqrt(2))
        init.constant_(m.bias, 0)
    elif 'BatchNorm' in class_name:
        if m.weight is not None:
            init.constant_(m.weight, 1)
        if m.bias is not None:
            init.constant_(m.bias, 0)
86 |
def conv_init_orthogonal(m):
    '''Orthogonally initialise the weight of an ``nn.Conv2d``; any other module is left untouched.

    Intended for ``model.apply(conv_init_orthogonal)``. Biases are not modified.
    '''
    if not isinstance(m, nn.Conv2d):
        return
    init.orthogonal_(m.weight)
90 |
def fc_init_orthogonal(m):
    '''Orthogonally initialise an ``nn.Linear`` weight and set its bias (if any) to 1e-3.

    Intended for ``model.apply(fc_init_orthogonal)``; other modules are left untouched.
    '''
    if isinstance(m, nn.Linear):
        init.orthogonal_(m.weight)
        # Linear layers created with bias=False have m.bias set to None.
        if m.bias is not None:
            init.constant_(m.bias, 1e-3)
95 |
96 |
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/attacks_runner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms as transforms
4 | from torch.utils.data import DataLoader
5 | import torchvision.datasets as datasets
6 |
7 |
8 | import numpy as np
9 | import glob
10 | import pandas as pd
11 | from collections import defaultdict
12 | from argparse import Namespace
13 |
14 | import os
15 | import sys
16 | sys.path.append('/workspace/home/jgusak/neural-ode-sopa')
17 |
18 | import dataloaders
19 |
20 | from sopa.src.models.odenet_mnist.layers import MetaNODE
21 | from sopa.src.solvers.utils import create_solver
22 |
23 | from sopa.src.models.utils import load_model
24 | from sopa.src.models.odenet_mnist.attacks_utils import run_attack
25 |
26 | from MegaAdversarial.src.utils.runner import fix_seeds
27 |
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='',
                    help='dataset directory passed to dataloaders.get_loader')
parser.add_argument('--models_root', type=str, default='',
                    help='directory searched for .pth checkpoints')
parser.add_argument('--save_path', type=str, default='',
                    help='CSV file the robust-accuracy table is written to (empty: do not save)')
parser.add_argument('--key_path_word', type=str, default='',
                    help="only checkpoints whose path contains this substring are evaluated; "
                         "'u2' also searches one directory level deeper")

parser.add_argument('--min_eps', type=float, default=0.)
parser.add_argument('--max_eps', type=float, default=0.3)
parser.add_argument('--num_eps', type=int, default=20,
                    help='number of linearly spaced epsilons in [min_eps, max_eps]')
parser.add_argument('--epsilons', type=lambda s: [float(eps) for eps in s.split(',')], default=None,
                    help='comma-separated epsilons; overrides min_eps/max_eps/num_eps')

# Parse the command line only when run as a script. Importing this module (e.g. to reuse
# run_set_of_attacks) must not consume or reject the importer's sys.argv.
args = parser.parse_args() if __name__ == '__main__' else None
41 |
42 |
def run_set_of_attacks(epsilons, attack_modes, loaders, models_root, save_path = None, device = 'cuda', key_path_word = ''):
    '''Evaluate the robust accuracy of every matching checkpoint under the given attacks.

    Checkpoints are found with ``glob``:

    * ``<models_root>/*/*/*.pth`` when ``key_path_word == 'u2'``;
    * ``<models_root>/*/*.pth`` otherwise.

    Only paths containing ``key_path_word`` are evaluated. Each checkpoint contributes one row: its
    training arguments, one column per entry returned by ``run_attack``, and the ``epsilons`` used.

    :param epsilons: attack strengths forwarded to ``run_attack`` (also stored in the table)
    :param attack_modes: attack names forwarded to ``run_attack``
    :param loaders: data loaders forwarded to ``run_attack``
    :param models_root: directory searched for ``.pth`` files
    :param save_path: if truthy, the table is rewritten to this CSV after every processed model,
        so partial results survive an interruption. None or '' disables saving.
    :param device: device the model is moved to and the solver is created on
    :param key_path_word: substring a checkpoint path must contain to be evaluated
    :return: pandas.DataFrame with one row per processed checkpoint (empty if none matched)
    '''
    fix_seeds()

    # 'u2' checkpoints are stored one directory level deeper than the others.
    pattern = "{}/*/*/*.pth" if key_path_word == 'u2' else "{}/*/*.pth"
    meta_model_path = pattern.format(models_root)

    df = None
    i = 0
    for state_dict_path in glob.glob(meta_model_path):
        if key_path_word not in state_dict_path:
            continue

        model, model_args = load_model(state_dict_path)
        model.eval()
        # Honour the requested device instead of assuming CUDA.
        model.to(device)

        val_solver = create_solver(*model_args.solvers[0], dtype = torch.float32, device = device)
        val_solver_options = Namespace(solver_mode = 'standalone')
        solvers_kwargs = {'solvers':[val_solver], 'solver_options':val_solver_options}

        robust_accuracies = run_attack(model, epsilons, attack_modes, loaders, device, solvers_kwargs=solvers_kwargs)
        robust_accuracies = {k : np.array(v) for k,v in robust_accuracies.items()}

        model_kwargs = dict(model_args._get_kwargs())
        data = [list(model_kwargs.values()) +
                list(robust_accuracies.values()) +
                [epsilons]
                ]
        columns = list(model_kwargs.keys()) +\
                  list(robust_accuracies.keys()) +\
                  ['epsilons']

        df_tmp = pd.DataFrame(data = data, columns = columns)
        # DataFrame.append was removed in pandas 2.0; pd.concat is its replacement.
        df = df_tmp if df is None else pd.concat([df, df_tmp])

        if save_path:
            df.to_csv(save_path, index = False)

        i += 1
        print('{} models have been processed'.format(i))

    return df if df is not None else pd.DataFrame()
88 |
89 |
90 |
if __name__ == "__main__":
    # Evaluation-only loaders (train=False, val=True) for MNIST.
    loaders = dataloaders.get_loader(
        batch_size=256,
        data_name='mnist',
        data_root=args.data_root,
        num_workers=4,
        train=False,
        val=True,
    )
    device = 'cuda'

    # An explicit --epsilons list wins over the linspace(min_eps, max_eps, num_eps) grid.
    epsilons = (
        args.epsilons
        if args.epsilons is not None
        else np.linspace(args.min_eps, args.max_eps, num=args.num_eps)
    )

    run_set_of_attacks(
        epsilons=epsilons,
        # Only FGSM is evaluated; "at", "at_ls", "av" and "fs" are currently disabled.
        attack_modes=["fgsm"],
        loaders=loaders,
        models_root=args.models_root,
        save_path=args.save_path,
        key_path_word=args.key_path_word,
        device=device,
    )
111 |
112 |
113 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=4 python3 attacks_runner.py --data_root "/workspace/raid/data/datasets" --models_root "/workspace/raid/data/jgusak/neural-ode-sopa/odenet_mnist_noise/" --save_path '/workspace/home/jgusak/neural-ode-sopa/experiments/odenet_mnist/results/robust_accuracies_fgsm/test_part1.csv' --epsilons 0.15,0.3,0.5
114 |
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/attacks_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms as transforms
4 | from torch.utils.data import DataLoader
5 | import torchvision.datasets as datasets
6 |
7 |
8 | import numpy as np
9 | from collections import defaultdict
10 |
11 | from MegaAdversarial.src.utils.runner import test, test_ensemble, fix_seeds
12 | from MegaAdversarial.src.attacks import (
13 | Clean,
14 | FGSM,
15 | PGD,
16 | LabelSmoothing,
17 | AdversarialVertex,
18 | AdversarialVertexExtra,
19 | FeatureScatter,
20 | NCEScatter,
21 | NCEScatterWithBuffer,
22 | MetaAttack,
23 | FGSM2Ensemble
24 | )
25 |
26 |
27 |
def run_attack(model, epsilons, attack_modes, loaders, device='cuda', solvers_kwargs = None,
               at_lr = None, at_n_iter = None):
    """Evaluate ``model`` on ``loaders["val"]`` under each attack mode and epsilon.

    Args:
        model: model to attack.
        epsilons: iterable of attack strengths.
        attack_modes: iterable of mode names. Supported: "clean", "fgsm" and
            "at" (PGD).
        loaders: mapping that provides the "val" loader.
        device: device forwarded to ``test``.
        solvers_kwargs: extra solver arguments forwarded to ``test``.
        at_lr: PGD step size (mode "at").
        at_n_iter: number of PGD iterations (mode "at").

    Returns:
        defaultdict mapping ``'accuracy_<mode>'`` to a list holding the
        ``accuracy_adv`` metric for each epsilon, in ``epsilons`` order.

    Raises:
        NotImplementedError: if a mode other than the supported ones is requested.
    """
    robust_accuracies = defaultdict(list)

    for mode in attack_modes:
        for epsilon in epsilons:
            if mode == "clean":
                test_attack = Clean(model)
            elif mode == "fgsm":
                test_attack = FGSM(model, eps=epsilon)
            elif mode == "at":
                test_attack = PGD(model, eps=epsilon, lr=at_lr, n_iter=at_n_iter)
            else:
                # Fail loudly: falling through would reuse the attack object from
                # the previous iteration and record its accuracy under this mode.
                # "at_ls", "av" and "fs" still need their dedicated attack classes.
                raise NotImplementedError('Attack mode {} is not supported'.format(mode))

            print("Attack {}".format(mode))
            test_metrics = test(loaders["val"], model, test_attack, device, show_progress=True, solvers_kwargs = solvers_kwargs)
            test_log = "Test: | " + " | ".join(
                f"{name}: {value:.6f}" for name, value in test_metrics.items()
            )
            print(test_log)

            robust_accuracies['accuracy_{}'.format(mode)].append(test_metrics['accuracy_adv'])

    return robust_accuracies
65 |
66 |
67 |
def run_attack2ensemble(models, epsilons, attack_modes, loaders, device='cuda', solvers_kwargs_arr = None,
                        at_lr = None, at_n_iter = None):
    """Evaluate an ensemble of ``models`` on ``loaders["val"]`` under FGSM.

    Only the "fgsm" mode is implemented. Any other mode raises
    ``NotImplementedError`` as soon as it is paired with an epsilon.
    ``at_lr`` and ``at_n_iter`` are accepted for signature parity with
    ``run_attack`` but are not used.

    Returns:
        defaultdict mapping ``'accuracy_<mode>'`` to the list of ``accuracy_adv``
        values, one per epsilon.
    """
    robust_accuracies = defaultdict(list)

    for mode in attack_modes:
        result_key = 'accuracy_{}'.format(mode)
        for epsilon in epsilons:
            if mode != "fgsm":
                raise NotImplementedError
            attack = FGSM2Ensemble(models, eps=epsilon)

            print("Attack {}".format(mode))
            metrics = test_ensemble(
                loaders["val"],
                models,
                attack,
                device,
                show_progress=True,
                solvers_kwargs_arr=solvers_kwargs_arr,
            )

            parts = ["{}: {:.6f}".format(name, value) for name, value in metrics.items()]
            print("Test: | " + " | ".join(parts))

            robust_accuracies[result_key].append(metrics['accuracy_adv'])

    return robust_accuracies
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/data.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | import torchvision.datasets as datasets
3 | import torchvision.transforms as transforms
4 |
5 |
def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0, data_root=None):
    """Build the MNIST data loaders.

    Args:
        data_aug: if True, training images are randomly cropped back to 28x28
            after 4-pixel padding.
        batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of both evaluation loaders.
        perc: unused; kept for interface compatibility.
        data_root: directory where torchvision stores (and downloads) MNIST.

    Returns:
        ``(train_loader, test_loader, train_eval_loader)``. Only the training
        loader drops its last incomplete batch; the evaluation loaders visit
        every sample.
    """
    train_ops = [transforms.RandomCrop(28, padding=4)] if data_aug else []
    transform_train = transforms.Compose(train_ops + [transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    train_loader = DataLoader(
        datasets.MNIST(root=data_root, train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    # Evaluation loaders keep the final partial batch. With drop_last=True the
    # last len(dataset) % test_batch_size samples would be silently skipped.
    train_eval_loader = DataLoader(
        datasets.MNIST(root=data_root, train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=False
    )

    test_loader = DataLoader(
        datasets.MNIST(root=data_root, train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=False
    )

    return train_loader, test_loader, train_eval_loader
37 |
38 |
def get_svhn_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0, data_root=None):
    """Build the SVHN data loaders.

    Args:
        data_aug: if True, training images are randomly cropped back to 32x32
            after 4-pixel padding.
        batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of both evaluation loaders.
        perc: unused; kept for interface compatibility.
        data_root: directory where torchvision stores (and downloads) SVHN.

    Returns:
        ``(train_loader, test_loader, train_eval_loader)``. Only the training
        loader drops its last incomplete batch; the evaluation loaders visit
        every sample.
    """
    train_ops = [transforms.RandomCrop(32, padding=4)] if data_aug else []
    transform_train = transforms.Compose(train_ops + [transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    train_loader = DataLoader(
        datasets.SVHN(root=data_root, split='train', download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    # Evaluation loaders keep the final partial batch. With drop_last=True the
    # last len(dataset) % test_batch_size samples would be silently skipped.
    train_eval_loader = DataLoader(
        datasets.SVHN(root=data_root, split='train', download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=False
    )

    test_loader = DataLoader(
        datasets.SVHN(root=data_root, split='test', download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=False
    )

    return train_loader, test_loader, train_eval_loader
70 |
71 |
def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):

    Each pass calls ``iter(iterable)`` again, so a DataLoader is re-shuffled per
    epoch. If a pass yields nothing (an empty iterable, or a one-shot iterator
    that is already exhausted), the generator stops. Previously this case
    spun forever without yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            return
82 |
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from functools import partial
6 |
7 |
class MetaODEBlock(nn.Module):
    """ODE block whose right-hand side is a 64-channel ``ODEfunc``.

    The block is integrated over ``integration_time`` = [0, 1] by externally
    supplied solver objects. ``solver_options.solver_mode`` selects how several
    solvers are combined:

    * 'standalone': only ``solvers[0]`` is used.
    * 'switch': one solver is sampled per call (``switch_probs`` or uniform).
    * 'ensemble': with probability ``ensemble_prob`` a weighted sum of all
      solvers' outputs is used (``ensemble_weights`` or uniform); otherwise
      ``solvers[0]``.
    """
    def __init__(self, activation_type = 'relu'):
        super(MetaODEBlock, self).__init__()

        self.rhs_func = ODEfunc(64, activation_type)
        self.integration_time = torch.tensor([0, 1]).float()


    def forward(self, x, solvers, solver_options):
        """Integrate ``x`` with the selected solver(s) and return the state at the final time.

        Side effect: records the random choice on ``solver_options``
        (``switch_solver_id`` or ``ensemble_coin_flip``), so that a following
        ``ss_loss`` call reuses the same solver.

        NOTE(review): an unrecognised ``solver_mode`` leaves ``y`` unassigned and
        raises UnboundLocalError at the return statement.
        """
        nsolvers = len(solvers)

        if solver_options.solver_mode == 'standalone':
            y = solvers[0].integrate(self.rhs_func, x = x, t = self.integration_time)

        elif solver_options.solver_mode == 'switch':
            if solver_options.switch_probs is not None:
                switch_probs = solver_options.switch_probs
            else:
                switch_probs = [1./nsolvers for _ in range(nsolvers)]
            solver_id = np.random.choice(range(nsolvers), p = switch_probs)
            # Remembered so ss_loss() integrates with the same solver.
            solver_options.switch_solver_id = solver_id

            y = solvers[solver_id].integrate(self.rhs_func, x = x, t = self.integration_time)

        elif solver_options.solver_mode == 'ensemble':
            # Single Bernoulli(ensemble_prob) draw: ensemble this call or fall back to solvers[0].
            coin_flip = torch.bernoulli(torch.tensor((1,)), solver_options.ensemble_prob)
            solver_options.ensemble_coin_flip = coin_flip

            if coin_flip :
                if solver_options.ensemble_weights is not None:
                    ensemble_weights = solver_options.ensemble_weights
                else:
                    ensemble_weights = [1./nsolvers for _ in range(nsolvers)]

                # Weighted sum of every solver's trajectory.
                for i, (wi, solver) in enumerate(zip(ensemble_weights, solvers)):
                    if i == 0:
                        y = wi * solver.integrate(self.rhs_func, x = x, t = self.integration_time)
                    else:
                        y += wi * solver.integrate(self.rhs_func, x = x, t = self.integration_time)
            else:
                y = solvers[0].integrate(self.rhs_func, x = x, t = self.integration_time)

        # Last time point; assumes integrate() returns a time-stacked rank-5
        # tensor (time, batch, C, H, W) -- TODO confirm against the solver implementation.
        return y[-1,:,:,:,:]


    def ss_loss(self, y, solvers, solver_options):
        """Regulariser: integrate onward from ``y`` over t in [1, 2] and measure the change.

        Uses the same solver choice as the preceding ``forward`` call, read from
        ``solver_options.switch_solver_id`` / ``ensemble_coin_flip``. Returns the
        batch mean of the per-sample L2 norm of ``z(2) - y``.
        """
        z0 = y
        rhs_func_ss = partial(self.rhs_func, ss_loss = True)
        # Shifted time window [1, 2].
        integration_time_ss = self.integration_time + 1

        nsolvers = len(solvers)

        # NOTE(review): `rhs_func_ss.func` is the bare `self.rhs_func`; the
        # `ss_loss=True` keyword bound by partial() above is discarded, so
        # ODEfunc's abs() branch never runs here. Confirm whether integrating
        # |f| was intended.
        if solver_options.solver_mode == 'standalone':
            z = solvers[0].integrate(rhs_func_ss.func, x = y, t = integration_time_ss)

        elif solver_options.solver_mode == 'switch':
            # NOTE(review): switch_probs is computed but unused; the solver id comes from forward().
            if solver_options.switch_probs is not None:
                switch_probs = solver_options.switch_probs
            else:
                switch_probs = [1./nsolvers for _ in range(nsolvers)]
            solver_id = solver_options.switch_solver_id

            z = solvers[solver_id].integrate(rhs_func_ss.func, x = y, t = integration_time_ss)

        elif solver_options.solver_mode == 'ensemble':
            # Reuse forward()'s coin flip so both passes agree.
            coin_flip = solver_options.ensemble_coin_flip

            if coin_flip :
                if solver_options.ensemble_weights is not None:
                    ensemble_weights = solver_options.ensemble_weights
                else:
                    ensemble_weights = [1./nsolvers for _ in range(nsolvers)]

                for i, (wi, solver) in enumerate(zip(ensemble_weights, solvers)):
                    if i == 0:
                        z = wi * solver.integrate(rhs_func_ss.func, x = y, t = integration_time_ss)
                    else:
                        z += wi * solver.integrate(rhs_func_ss.func, x = y, t = integration_time_ss)
            else:
                z = solvers[0].integrate(rhs_func_ss.func, x = y, t = integration_time_ss)

        # Change between the end of the extra window and the starting state, per sample.
        z = z[-1,:,:,:,:] - z0
        z = torch.norm(z.reshape((z.shape[0], -1)), dim = 1)
        z = torch.mean(z)

        return z
94 |
95 |
class MetaNODE(nn.Module):
    """Image classifier: downsampling stem -> body -> fully connected head.

    With ``is_odenet=True`` the body is a single ``MetaODEBlock`` driven by the
    solvers passed to ``forward``; otherwise it is six 64-channel ``ResBlock``
    layers, and the solver and loss arguments of ``forward`` are ignored.
    """

    def __init__(self, downsampling_method = 'conv', is_odenet = True, activation_type = 'relu', in_channels = 1):
        super(MetaNODE, self).__init__()

        self.is_odenet = is_odenet

        stem_layers = build_downsampling_layers(downsampling_method, in_channels)
        self.downsampling_layers = nn.Sequential(*stem_layers)
        head_layers = build_fc_layers()
        self.fc_layers = nn.Sequential(*head_layers)

        if is_odenet:
            body = [MetaODEBlock(activation_type)]
        else:
            body = [ResBlock(64, 64) for _ in range(6)]
        self.blocks = nn.ModuleList(body)


    def forward(self, x, solvers=None, solver_options=None, loss_options = None):
        """Return class logits for ``x``.

        When ``loss_options.ss_loss`` is truthy (ODE body only), each block's
        ``ss_loss`` is accumulated into ``self.ss_loss``. The accumulator is
        reset to 0 at the start of every call.
        """
        self.ss_loss = 0

        hidden = self.downsampling_layers(x)

        for block in self.blocks:
            if not self.is_odenet:
                hidden = block(hidden)
                continue

            hidden = block(hidden, solvers, solver_options)
            if (loss_options is not None) and loss_options.ss_loss:
                regulariser = block.ss_loss(hidden, solvers, solver_options)
                self.ss_loss += regulariser

        return self.fc_layers(hidden)

    def get_ss_loss(self):
        """Return the regulariser accumulated during the most recent forward pass."""
        return self.ss_loss
132 |
133 |
class ODEfunc(nn.Module):
    """Right-hand side f(t, x) of the ODE block.

    The computation is GroupNorm -> activation -> time-conditioned 3x3 conv ->
    GroupNorm -> activation -> time-conditioned 3x3 conv -> GroupNorm.
    ``self.nfe`` counts calls (function evaluations).

    Args:
        dim: number of input/output channels.
        activation_type: one of 'tanh', 'softplus', 'softsign', 'relu'.

    Raises:
        NotImplementedError: for any other ``activation_type``.

    Note:
        The requested activation used to be constructed and then ignored (ReLU was
        always applied). Models trained with a non-ReLU ``--activation`` before
        this fix were therefore actually trained with ReLU. 'relu' keeps the
        original in-place ReLU, so default models are unaffected.
    """

    def __init__(self, dim, activation_type = 'relu'):
        super(ODEfunc, self).__init__()

        if activation_type == 'tanh':
            activation = nn.Tanh()
        elif activation_type == 'softplus':
            activation = nn.Softplus()
        elif activation_type == 'softsign':
            activation = nn.Softsign()
        elif activation_type == 'relu':
            activation = nn.ReLU(inplace=True)
        else:
            raise NotImplementedError('{} activation is not implemented'.format(activation_type))

        self.norm1 = norm(dim)
        self.activation = activation
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x, ss_loss = False):
        """Evaluate f(t, x); with ``ss_loss=True`` the element-wise absolute value is returned."""
        self.nfe += 1
        out = self.norm1(x)
        out = self.activation(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.activation(out)
        out = self.conv2(t, out)
        out = self.norm3(out)

        if ss_loss:
            out = torch.abs(out)

        return out
172 |
def build_downsampling_layers(downsampling_method = 'conv', in_channels = 1):
    """Return the stem layers mapping ``in_channels`` input planes to 64 feature maps.

    Both variants start with an unpadded 3x3 convolution followed by two
    stride-2 stages:

    * 'conv': 4x4 stride-2 convolutions with GroupNorm + ReLU in between.
    * 'res': two downsampling ``ResBlock`` layers with 1x1 stride-2 shortcuts.

    Raises:
        NotImplementedError: for any other ``downsampling_method``. Previously
            this case surfaced as an UnboundLocalError.
    """
    if downsampling_method == 'conv':
        return [
            nn.Conv2d(in_channels, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
    if downsampling_method == 'res':
        return [
            nn.Conv2d(in_channels, 64, 3, 1),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]
    raise NotImplementedError('{} downsampling is not implemented'.format(downsampling_method))
191 |
192 |
def build_fc_layers():
    """Return the classifier head: GroupNorm -> ReLU -> global average pool -> flatten -> Linear(64, 10)."""
    head = [norm(64), nn.ReLU(inplace=True)]
    head.append(nn.AdaptiveAvgPool2d((1, 1)))
    head.append(Flatten())
    head.append(nn.Linear(64, 10))
    return head
196 |
197 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1."""
    options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
201 |
202 |
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding)."""
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
206 |
207 |
def norm(dim):
    """Return a GroupNorm over ``dim`` channels using at most 32 groups."""
    num_groups = dim if dim < 32 else 32
    return nn.GroupNorm(num_groups, dim)
210 |
211 |
class ResBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> ReLU -> 3x3 conv) twice, plus a skip.

    If ``downsample`` is given, it is applied to the normalised and activated
    input to build the skip connection; otherwise the raw input is used.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        activated = self.relu(self.norm1(x))
        skip = x if self.downsample is None else self.downsample(activated)

        residual = self.relu(self.norm2(self.conv1(activated)))
        residual = self.conv2(residual)

        return residual + skip
238 |
239 |
class ConcatConv2d(nn.Module):
    """Convolution conditioned on time by prepending a constant ``t`` channel to its input.

    ``transpose=True`` switches the wrapped layer to ``nn.ConvTranspose2d``.
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super().__init__()
        conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        # One extra input channel carries the time value.
        self._layer = conv_cls(
            dim_in + 1,
            dim_out,
            kernel_size=ksize,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, t, x):
        time_plane = torch.ones_like(x[:, :1, :, :]) * t
        return self._layer(torch.cat([time_plane, x], 1))
254 |
255 |
class Flatten(nn.Module):
    """Collapse every dimension after the first: (N, d1, d2, ...) -> (N, d1*d2*...)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        n_features = torch.tensor(x.shape[1:]).prod().item()
        return x.view(-1, n_features)
264 |
265 |
266 |
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
def one_hot(x, K):
    """Encode the integer labels in 1-D array ``x`` as a (len(x), K) matrix of 0/1 ints."""
    class_ids = np.arange(K)
    matches = x[:, None] == class_ids[None, :]
    return matches.astype(int)
6 |
7 |
def accuracy(model, dataset_loader, device, solver = None, solver_options = None):
    """Return top-1 accuracy of ``model`` over every batch yielded by ``dataset_loader``.

    Args:
        model: callable returning per-class scores of shape (batch, classes).
            It is called as ``model(x, solver, solver_options)`` when ``solver``
            is given, otherwise as ``model(x)``.
        dataset_loader: iterable of ``(x, y)`` batches with integer labels ``y``.
        device: device the inputs are moved to before the forward pass.
        solver: solvers forwarded to the model (a list, despite the singular name).
        solver_options: options forwarded together with ``solver``.

    Returns:
        Fraction of correctly classified samples among those actually evaluated,
        or 0.0 if the loader yields no samples. Dividing by the number of
        evaluated samples, rather than ``len(dataset)``, keeps the result unbiased
        when the loader skips a final partial batch.

    This function neither switches the model to eval mode nor disables autograd.
    """
    total_correct = 0
    total_seen = 0
    for x, y in dataset_loader:
        x = x.to(device)
        # Labels are already class indices; no one-hot round trip (which also
        # capped the number of classes at 10) is needed.
        target_class = y.cpu().numpy()

        if solver is not None:
            out = model(x, solver, solver_options).cpu().detach().numpy()
        else:
            out = model(x).cpu().detach().numpy()

        predicted_class = np.argmax(out, axis=1)
        total_correct += np.sum(predicted_class == target_class)
        total_seen += len(target_class)

    if total_seen == 0:
        return 0.0
    return total_correct / total_seen
25 |
26 |
def sn_test(model, test_loader, device, solvers, solver_options, nsteps_grid):
    """Measure test accuracy for each uniform step count in ``nsteps_grid``.

    Side effects: puts ``model`` in eval mode and calls ``freeze_params()`` on
    every solver. When ``nsteps_grid`` is non-empty, each solver's
    ``grid_constructor`` is left set to the uniform grid of its last entry.

    Returns:
        list of accuracies, one per entry of ``nsteps_grid``, in the same order.
    """
    model.eval()
    for frozen in solvers:
        frozen.freeze_params()

    def _uniform_grid(n):
        # n equal steps between the first and last requested time points.
        return lambda t: torch.linspace(t[0], t[-1], n + 1)

    results = []
    for nsteps in nsteps_grid:
        for solver in solvers:
            solver.grid_constructor = _uniform_grid(nsteps)

        with torch.no_grad():
            results.append(accuracy(model, test_loader, device, solvers, solver_options))

    return results
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from argparse import Namespace
4 | import logging
5 | import time
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import copy
11 |
12 | from decimal import Decimal
13 |
14 | import sys
15 | sys.path.append('/workspace/home/jgusak/neural-ode-sopa/')
16 |
17 | from sopa.src.solvers.utils import create_solver, noise_params
18 | from sopa.src.models.utils import fix_seeds, RunningAverageMeter
19 | from sopa.src.models.odenet_mnist.layers import MetaNODE
20 | from sopa.src.models.odenet_mnist.utils import makedirs, get_logger, count_parameters, learning_rate_with_decay
21 | from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator
22 | from sopa.src.models.odenet_mnist.metrics import accuracy
23 |
24 |
parser = argparse.ArgumentParser()
# Architecture. Note: argparse exposes '--downsampling-method' as args.downsampling_method.
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu')
parser.add_argument('--in_channels', type=int, default=1)

# ';'-separated solver specs, each ','-separated. The fields are converted by
# position: 0-1 -> str, 2 -> int, 3 -> float, 4 and later -> Decimal.
# Default None; the training loop iterates over this list, so it is effectively required.
parser.add_argument('--solvers',
                    type=lambda s:[tuple(map(lambda iparam: str(iparam[1]) if iparam[0] <= 1 else (
                        int(iparam[1]) if iparam[0]==2 else (
                            float(iparam[1]) if iparam[0] == 3 else Decimal(iparam[1]))),
                                             enumerate(item.split(',')))) for item in s.strip().split(';')],
                    default = None,
                    help='Each solver is represented with (method,parameterization,n_steps,step_size,u0,v0) \n' +
                    'If the solver has only one parameter u0, set v0 to -1; \n' +
                    'n_steps and step_size are exclusive parameters, only one of them can be != -1, \n'
                    'If n_steps = step_size = -1, automatic time grid_constructor is used \n;'
                    'For example, --solvers rk4,uv,2,-1,0.3,0.6;rk3,uv,-1,0.1,0.4,0.6;rk2,u,4,-1,0.3,-1')

# How multiple training solvers are combined (see MetaODEBlock.forward).
parser.add_argument('--solver_mode', type=str, choices=['switch', 'ensemble', 'standalone'], default='standalone')
parser.add_argument('--switch_probs',type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="--switch_probs 0.8,0.1,0.1")
parser.add_argument('--ensemble_weights', type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="ensemble_weights 0.6,0.2,0.2")
parser.add_argument('--ensemble_prob', type=float, default=1.)

# Solver-parameter noise applied per training step (a noise_type of None disables it).
# NOTE(review): the boolean flags below use type=eval on command-line text; that is
# acceptable for a local script but must never be fed untrusted input.
parser.add_argument('--noise_type', type=str, choices=['cauchy', 'normal'], default=None)
parser.add_argument('--noise_sigma', type=float, default = 0.001)
parser.add_argument('--noise_prob', type=float, default = 0.)
parser.add_argument('--minimize_rk2_error', type=eval, default=False, choices=[True, False])

# Only --nepochs_nn is read by this script's training loop; --nepochs_solver,
# --nstages, --minimize_rk2_error, --lr_uv and --debug are parsed but not used here.
parser.add_argument('--nepochs_nn', type=int, default=160)
parser.add_argument('--nepochs_solver', type=int, default=0)
parser.add_argument('--nstages', type=int, default=1)

# Optional ss_loss regulariser, weighted by --ss_loss_reg.
parser.add_argument('--ss_loss', type=eval, default=False, choices=[True, False])
parser.add_argument('--ss_loss_reg', type=float, default=0.1)

parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

parser.add_argument('--lr_uv', type=float, default=1e-3)
# Accepted values: 'float32' or 'float64' (validated after parsing).
parser.add_argument('--torch_dtype', type=str, default='float32')

# Machine-specific default path; override with --data_root on other hosts.
parser.add_argument('--data_root', type=str, default='/workspace/home/jgusak/neural-ode-sopa/.data/mnist')
parser.add_argument('--save', type=str, default='./experiment2')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)

parser.add_argument('--seed', type=int, default=502)


args = parser.parse_args()
80 |
81 |
# Map the --torch_dtype string onto the corresponding torch dtype.
try:
    dtype = {'float64': torch.float64, 'float32': torch.float32}[args.torch_dtype]
except KeyError:
    raise ValueError('torch_type should be either float64 or float32') from None
88 |
89 |
if __name__=="__main__":
    # Train a MetaNODE on MNIST with the solvers given by --solvers. After the
    # first batch of each epoch, the model is evaluated with every training
    # solver, and a per-solver best checkpoint is saved to --save.
    print(args.solvers)
    fix_seeds(args.seed)

    # Same dtype selection as at module level; repeated here.
    if args.torch_dtype == 'float64':
        dtype = torch.float64
    elif args.torch_dtype == 'float32':
        dtype = torch.float32
    else:
        raise ValueError('torch_type should be either float64 or float32')

    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=None)
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    ########### Create train / val solvers
    # NOTE(review): args.solvers defaults to None, in which case this
    # comprehension raises TypeError -- --solvers is effectively required.
    train_solvers = [create_solver(*solver_params, dtype =dtype, device = device) for solver_params in args.solvers]
    for solver in train_solvers:
        print(solver.u)
        solver.freeze_params()

    train_solver_options = Namespace(**{key:vars(args)[key] for key in ['solver_mode','switch_probs',
                                                                      'ensemble_prob','ensemble_weights']})
    # Validation always integrates with a single solver.
    val_solver_options = Namespace(solver_mode = 'standalone')

    ########## Build the model
    is_odenet = args.network == 'odenet'

    model = MetaNODE(downsampling_method = args.downsampling_method, is_odenet=is_odenet,
                     activation_type=args.activation, in_channels = args.in_channels)
    model.to(device)

    logger.info(model)

    ########### Create data loaders
    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size, data_root = args.data_root)

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    ########### Create criterion and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    loss_options = Namespace(ss_loss=args.ss_loss)

    # Step-wise LR schedule: decays at epochs 60 / 100 / 140.
    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001], lr0 = args.lr)

    optimizer = optim.RMSprop([{"params" : model.parameters(), 'lr' : args.lr},], lr=args.lr, weight_decay=args.weight_decay)

    ########### Train the model
    nsolvers = len(train_solvers)
    best_acc = [0]*nsolvers

    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()


    for itr in range(args.nepochs_nn * batches_per_epoch):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        optimizer.zero_grad()
        x, y = data_gen.__next__()
        x = x.to(device)
        y = y.to(device)

        ##### Noise params
        # Replace each solver's (u, v) with noise_params(u0, v0, ...) for this
        # step and rebuild its Butcher tableau; the values are restored to
        # (u0, v0) after optimizer.step() below.
        if args.noise_type is not None:
            for i in range(len(train_solvers)):
                train_solvers[i].u, train_solvers[i].v = noise_params(train_solvers[i].u0,
                                                                      train_solvers[i].v0,
                                                                      std = args.noise_sigma,
                                                                      bernoulli_p = args.noise_prob,
                                                                      noise_type = args.noise_type)
                train_solvers[i].build_ButcherTableau()

        ##### Forward pass
        if is_odenet:
            logits = model(x, train_solvers, train_solver_options, loss_options)
        else:
            logits = model(x)

        loss = criterion(logits, y)
        if (loss_options is not None) and loss_options.ss_loss:
            loss += args.ss_loss_reg * model.get_ss_loss()

        ##### Compute NFE-forward
        # rhs_func.nfe counts ODEfunc calls; it is read and reset here so the
        # count after backward() reflects only the backward pass.
        if is_odenet:
            nfe_forward = 0
            for i in range(len(model.blocks)):
                nfe_forward += model.blocks[i].rhs_func.nfe
                model.blocks[i].rhs_func.nfe = 0

        loss.backward()
        optimizer.step()

        ##### Compute NFE-backward
        if is_odenet:
            nfe_backward = 0
            for i in range(len(model.blocks)):
                nfe_backward += model.blocks[i].rhs_func.nfe
                model.blocks[i].rhs_func.nfe = 0

        ##### Denoise params
        if args.noise_type is not None:
            for i in range(len(train_solvers)):
                train_solvers[i].u, train_solvers[i].v = train_solvers[i].u0, train_solvers[i].v0
                train_solvers[i].build_ButcherTableau()

        batch_time_meter.update(time.time() - end)
        if is_odenet:
            f_nfe_meter.update(nfe_forward)
            b_nfe_meter.update(nfe_backward)
        end = time.time()

        # NOTE(review): evaluation runs when itr % batches_per_epoch == 0, i.e.
        # after the first batch of each epoch (itr 0 included). The model as it
        # stands after the final epoch is never evaluated or checkpointed.
        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = [0]*nsolvers
                val_acc = [0]*nsolvers

                # Evaluate the model with each training solver on its own.
                for solver_id, val_solver in enumerate(train_solvers):
                    train_acc_id = accuracy(model, train_eval_loader, device, [val_solver], val_solver_options)
                    val_acc_id = accuracy(model, test_loader, device, [val_solver], val_solver_options)

                    train_acc[solver_id] = train_acc_id
                    val_acc[solver_id] = val_acc_id

                    # Keep a separate best checkpoint per solver.
                    if val_acc_id > best_acc[solver_id]:
                        torch.save({'state_dict': model.state_dict(), 'args': args, 'solver_id':solver_id},
                                   os.path.join(args.save,'model_best_{}.pth'.format(solver_id)))
                        best_acc[solver_id] = val_acc_id
                    del train_acc_id, val_acc_id

                # Evaluation forward passes also increment nfe; discard them.
                if is_odenet:
                    for i in range(len(model.blocks)):
                        model.blocks[i].rhs_func.nfe = 0

                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "TrainAcc {} | TestAcc {} ".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc))
239 |
240 |
## How to run (one shell command; lines are joined with trailing backslashes)
# CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=3 python3 runner.py \
#   --data_root "/workspace/raid/data/datasets" \
#   --save "./experiment1" \
#   --network odenet \
#   --downsampling-method conv \
#   --solvers "rk2,u,-1,-1,0.6,-1;rk2,u,-1,-1,0.5,-1" \
#   --solver_mode switch \
#   --switch_probs 0.8,0.2 \
#   --nepochs_nn 160 \
#   --nepochs_solver 0 \
#   --nstages 1 \
#   --lr 0.1 \
#   --seed 502
255 |
256 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=4 python3 runner.py --data_root /workspace/raid/data/datasets/mnist --save ./experiment2_new --network 'odenet' --downsampling-method 'conv' --solvers "rk2,u,4,-1,0.5,-1;rk2,u,4,-1,1.0,-1" --solver_mode "switch" --activation "relu" --seed 702 --nepochs_nn 160 --nepochs_solver 0 --nstages 1 --lr 0.01 --ss_loss True
257 |
258 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=4 python3 runner.py --data_root /workspace/raid/data/datasets/mnist --save ./experiment2_new --network 'odenet' --downsampling-method 'conv' --solvers "rk2,u,1,-1,0.66666666,-1" --solver_mode standalone --activation relu --seed 702 --nepochs_nn 160 --nepochs_solver 0 --nstages 1 --lr 0.1 --noise_type 'cauchy' --noise_sigma 0.001 --noise_prob 1.
259 |
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/runner_new.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from argparse import Namespace
4 | import logging
5 | import time
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import copy
11 |
12 | from decimal import Decimal
13 | import wandb
14 | import sys
15 |
16 | sys.path.append('../../../../')
17 |
18 | from sopa.src.solvers.utils import create_solver
19 | from sopa.src.models.utils import fix_seeds, RunningAverageMeter
20 | from sopa.src.models.odenet_mnist.layers import MetaNODE
21 | from sopa.src.models.odenet_mnist.utils import makedirs, get_logger, count_parameters, learning_rate_with_decay
22 | from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator
23 | # from sopa.src.models.odenet_mnist.metrics import accuracy
24 | from sopa.src.models.odenet_mnist.train_validate import train, validate
25 |
def _parse_solvers(spec):
    """Parse a '--solvers' value into a list of solver tuples.

    ``spec`` holds ';'-separated solver specs; each spec holds ','-separated
    fields (method, parameterization, n_steps, step_size, u0, v0). Fields 0-1
    stay strings, field 2 becomes int, field 3 becomes float, and fields 4
    onward become ``Decimal``.
    """
    converters = (str, str, int, float)
    solvers = []
    for item in spec.strip().split(';'):
        fields = item.split(',')
        solvers.append(tuple(
            converters[i](field) if i < len(converters) else Decimal(field)
            for i, field in enumerate(fields)
        ))
    return solvers


parser = argparse.ArgumentParser()
# Architecture. argparse exposes '--downsampling-method' as args.downsampling_method.
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu')
parser.add_argument('--in_channels', type=int, default=1)

parser.add_argument('--solvers',
                    type=_parse_solvers,
                    default=None,
                    help='Each solver is represented with (method,parameterization,n_steps,step_size,u0,v0); \n'
                         'If the solver has only one parameter u0, set v0 to -1; \n'
                         'n_steps and step_size are exclusive parameters, only one of them can be != -1; \n'
                         'If n_steps = step_size = -1, automatic time grid_constructor is used; \n'
                         'For example, --solvers rk4,uv,2,-1,0.3,0.6;rk3,uv,-1,0.1,0.4,0.6;rk2,u,4,-1,0.3,-1')

parser.add_argument('--solver_mode', type=str, choices=['switch', 'ensemble', 'standalone'], default='standalone')
parser.add_argument('--val_solver_modes',
                    type=lambda s: s.strip().split(','),
                    default=['standalone', 'ensemble', 'switch'],
                    help='Solver modes to use for validation step')

parser.add_argument('--switch_probs', type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="--switch_probs 0.8,0.1,0.1")
parser.add_argument('--ensemble_weights', type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="ensemble_weights 0.6,0.2,0.2")
parser.add_argument('--ensemble_prob', type=float, default=1.)

# NOTE: the boolean flags use type=eval on command-line text. This is acceptable
# for a local training script, but must never receive untrusted input.
parser.add_argument('--noise_type', type=str, choices=['cauchy', 'normal'], default=None)
parser.add_argument('--noise_sigma', type=float, default=0.001)
parser.add_argument('--noise_prob', type=float, default=0.)
parser.add_argument('--minimize_rk2_error', type=eval, default=False, choices=[True, False])

parser.add_argument('--nepochs_nn', type=int, default=160)
parser.add_argument('--nepochs_solver', type=int, default=0)
parser.add_argument('--nstages', type=int, default=1)

parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

parser.add_argument('--lr_uv', type=float, default=1e-3)
parser.add_argument('--torch_dtype', type=str, default='float32')
parser.add_argument('--wandb_name', type=str, default='mnist_tmp')

parser.add_argument('--data_root', type=str, default='/gpfs/gpfs0/t.daulbaev/data/MNIST')
parser.add_argument('--save', type=str, default='../../../rk2_tmp')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)

parser.add_argument('--seed', type=int, default=502)
# Noise and adversarial attacks parameters:
parser.add_argument('--data_noise_std', type=float, default=0.,
                    help='Applies Norm(0, std) gaussian noise to each batch entry')
parser.add_argument('--eps_adv_training', type=float, default=0.3,
                    help='Epsilon for adversarial training')
parser.add_argument(
    "--adv_training_mode",
    default="clean",
    choices=["clean", "fgsm", "at"],  # "at_ls", "av", "fs", "nce", "nce_moco", "moco", "av_extra", "meta" not wired up yet
    help='Adversarial training method/mode; by default there is no adversarial training (clean). '
         'For further details see MegaAdversarial/train in this repository.'
)
parser.add_argument('--use_wandb', type=eval, default=True, choices=[True, False])
parser.add_argument('--use_logger', type=eval, default=False, choices=[True, False])
parser.add_argument('--ss_loss', type=eval, default=False, choices=[True, False])
97 | parser.add_argument('--ss_loss_reg', type=float, default=0.1)
98 | parser.add_argument('--timestamp', type=int, default=int(1e6 * time.time()))
99 |
100 | args = parser.parse_args()
101 |
102 | sys.path.append('../../')
103 |
104 | if args.use_logger:
105 | makedirs(args.save)
106 | logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=None)
107 | logger.info(args)
108 | if args.use_wandb:
109 | wandb.init(project=args.wandb_name, entity="sopa_node")
110 | makedirs(args.save)
111 | wandb.config.update(args)
112 | os.makedirs(os.path.join(args.save, str(args.timestamp)))
113 |
# Translate the --torch_dtype command-line string into the torch dtype object
# used for the solvers; any other value is rejected.
try:
    dtype = {'float64': torch.float64, 'float32': torch.float32}[args.torch_dtype]
except KeyError:
    raise ValueError('torch_type should be either float64 or float32') from None
120 |
if __name__ == "__main__":
    print(args.solvers)
    fix_seeds(args.seed)

    if args.solvers is None:
        # Without this check the solver list comprehension below fails with an
        # opaque "'NoneType' object is not iterable" TypeError.
        parser.error('--solvers is required, e.g. --solvers "rk2,u,1,-1,0.5,-1"')

    # `dtype` has already been resolved from --torch_dtype at module level.

    # get_logger() attaches handlers to the root logger on every call. When
    # --use_logger is set, the module-level setup has already created `logger`,
    # and creating it again here would emit every record twice.
    if not args.use_logger:
        makedirs(args.save)
        logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=None)
        logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    # wandb is only initialised when --use_wandb is set, and wandb.log() raises
    # before wandb.init(). Otherwise hand train/validate no wandb logger.
    wandb_logger = wandb if args.use_wandb else None

    ########### Create train / val solvers
    train_solvers = [create_solver(*solver_params, dtype=dtype, device=device)
                     for solver_params in args.solvers]
    for solver in train_solvers:
        solver.freeze_params()

    train_solver_options = Namespace(**{key: vars(args)[key]
                                        for key in ['solver_mode', 'switch_probs',
                                                    'ensemble_prob', 'ensemble_weights']})

    val_solver_modes = args.val_solver_modes

    ########## Build the model
    is_odenet = args.network == 'odenet'

    model = MetaNODE(downsampling_method=args.downsampling_method,
                     is_odenet=is_odenet,
                     activation_type=args.activation,
                     in_channels=args.in_channels)
    model.to(device)
    if args.use_wandb:
        wandb.watch(model)
    if args.use_logger:
        logger.info(model)

    ########### Create data loaders
    train_loader, test_loader, train_eval_loader = get_mnist_loaders(args.data_aug,
                                                                     args.batch_size,
                                                                     args.test_batch_size,
                                                                     data_root=args.data_root)

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    ########### Create criterion and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001], lr0=args.lr)

    optimizer = optim.RMSprop([{"params": model.parameters(), 'lr': args.lr}, ], lr=args.lr,
                              weight_decay=args.weight_decay)

    ########### Train the model
    best_acc = {'standalone': [0] * len(train_solvers),
                'ensemble': 0,
                'switch': 0}

    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()

    for itr in range(args.nepochs_nn * batches_per_epoch):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        # The first iteration of every epoch is logged and followed by validation.
        is_epoch_start = itr % batches_per_epoch == 0

        train(itr,
              model,
              data_gen,
              solvers=train_solvers,
              solver_options=train_solver_options,
              criterion=criterion,
              optimizer=optimizer,
              batch_time_meter=batch_time_meter,
              f_nfe_meter=f_nfe_meter,
              b_nfe_meter=b_nfe_meter,
              device=device,
              dtype=dtype,
              is_odenet=is_odenet,
              args=args,
              logger=logger if is_epoch_start else None,
              wandb_logger=wandb_logger if is_epoch_start else None)

        if is_epoch_start:
            best_acc = validate(best_acc,
                                itr,
                                model,
                                train_eval_loader,
                                test_loader,
                                batches_per_epoch,
                                solvers=train_solvers,
                                val_solver_modes=val_solver_modes,
                                batch_time_meter=batch_time_meter,
                                f_nfe_meter=f_nfe_meter,
                                b_nfe_meter=b_nfe_meter,
                                device=device,
                                dtype=dtype,
                                args=args,
                                logger=logger,
                                wandb_logger=wandb_logger)
246 |
247 | # # How to run
248 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=3 python3 runner.py
249 | # --data_root "/workspace/raid/data/datasets"
250 | # --save "./experiment1"
251 | # --network odenet
252 | # --downsampling_method conv
253 | # --solvers rk2,u,-1,-1,0.6,-1;rk2,u,-1,-1,0.5,-1
254 | # --solver_mode switch
255 | # --switch_probs 0.8,0.2
256 | # --nepochs_nn 160
257 | # --nepochs_solver 0
258 | # --nstages 1
259 | # --lr 0.1
260 | # --seed 502
261 |
262 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=1 python3 runner_new.py --data_root /workspace/raid/data/datasets/mnist --save ./experiment2_new --network 'odenet' --downsampling-method 'conv' --solvers "rk2,u,1,-1,0.5,-1;rk2,u,1,-1,1.0,-1" --solver_mode "switch" --activation "relu" --seed 702 --nepochs_nn 160 --nepochs_solver 0 --nstages 1 --lr 0.1
263 |
264 | # CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=4 python3 runner_new.py --data_root /workspace/raid/data/datasets/mnist --save ./experiment2_new --network 'odenet' --downsampling-method 'conv' --solvers "rk2,u,1,-1,0.66666666,-1" --solver_mode standalone --activation relu --seed 702 --nepochs_nn 160 --nepochs_solver 0 --nstages 1 --lr 0.1 --noise_type 'cauchy' --noise_sigma 0.001 --noise_prob 1.
265 |
# Recompute MNIST
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/runner_old.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | import time
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.utils.data import DataLoader
10 | import torchvision.datasets as datasets
11 | import torchvision.transforms as transforms
12 |
13 | from torchdiffeq._impl.rk_common import _ButcherTableau
14 |
15 | import sys
16 | sys.path.append('/workspace/home/jgusak/neural-ode-sopa/')
17 | from sopa.src.solvers.rk_parametric_old import rk_param_tableau, odeint_plus
18 |
19 | from sopa.src.models.odenet_mnist.layers import ResBlock, ODEfunc, build_downsampling_layers, build_fc_layers
20 | from sopa.src.models.odenet_mnist.utils import makedirs, get_logger, count_parameters, learning_rate_with_decay, RunningAverageMeter
21 | from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator
22 | from sopa.src.models.odenet_mnist.metrics import accuracy
23 |
parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--method', type=str, choices=['dopri5', 'adams', 'rk4', 'rk4_param', 'rk3_param', 'euler'],
                    default='rk4')
parser.add_argument('--step_size', type=float, default=None)
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=160)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

# Learning rate and initial values of the parametric RK tableau coefficients u/v.
parser.add_argument('--lr_uv', type=float, default=1e-3)
parser.add_argument('--parameterization', type=str, choices=['uv', 'u1', 'u2', 'u3'], default='uv')
parser.add_argument('--u0', type=float, default=1/3.)
parser.add_argument('--v0', type=float, default=2/3.)
parser.add_argument('--fix_param', action='store_true')

parser.add_argument('--data_root', type=str, default='/workspace/home/jgusak/neural-ode-sopa/.data/mnist')
parser.add_argument('--save', type=str, default='./experiment1')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)

parser.add_argument('--torch_seed', type=int, default=502)
parser.add_argument('--numpy_seed', type=int, default=502)

# Parse the real command line only when run as a script (see the usage example
# at the bottom of this file). When the module is merely imported (e.g. from a
# notebook, whose sys.argv belongs to the kernel), fall back to the defaults.
args = parser.parse_args() if __name__ == "__main__" else parser.parse_args([])
52 |
# Reproducibility: seed NumPy and PyTorch from the CLI seeds, and make cuDNN use
# deterministic kernels with its autotuner switched off.
np.random.seed(args.numpy_seed)
torch.manual_seed(args.torch_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Bind `odeint` once at import time: the plain torchdiffeq integrator by
# default, or its adjoint-method variant when --adjoint is set.
if not args.adjoint:
    from torchdiffeq import odeint
else:
    from torchdiffeq import odeint_adjoint as odeint
63 |
class ODEBlock(nn.Module):
    """Integrate ``odefunc`` over t in [0, 1] and return the state at t = 1.

    For ``method`` in {'rk4_param', 'rk3_param'} the Runge-Kutta Butcher
    tableau is built from coefficients ``u`` (and ``v`` for the 'uv'
    parameterization) via ``rk_param_tableau``:

    * with ``args.fix_param`` they are constant tensors ``u``/``v``;
    * otherwise ``u_``/``v_`` are trainable ``nn.Parameter``s, and ``u``/``v``
      are copies clamped into [eps, 1 - eps].

    Any other method is delegated to torchdiffeq's ``odeint`` with tolerances
    taken from the module-level ``args.tol``.

    NOTE(review): ``__init__`` logs through the module-level ``logger`` (only
    created under ``__main__``) for the parametric methods, and ``forward``
    reads the module-level ``args``.
    """

    def __init__(self, odefunc, args, device='cpu'):
        super(ODEBlock, self).__init__()

        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

        # Trainable tableau parameters are attributes of the block; the tableau
        # itself is expected to be recomputed before each forward step.
        self.step_size = args.step_size

        self.method = args.method
        self.fix_param = None
        self.parameterization = None
        self.u0, self.v0 = None, None
        self.u_, self.v_ = None, None
        self.u, self.v = None, None

        self.eps = torch.finfo(torch.float32).eps
        self.device = device

        if self.method in ['rk4_param', 'rk3_param']:
            self.fix_param = args.fix_param
            self.parameterization = args.parameterization
            self.u0 = args.u0

            if self.fix_param:
                self.u = torch.tensor(self.u0)

                if self.parameterization == 'uv':
                    self.v0 = args.v0
                    self.v = torch.tensor(self.v0)
            else:
                # Create the parameters directly on the target device. Calling
                # .to(device) on an nn.Parameter returns a non-leaf copy when the
                # device changes, which is neither registered on the module nor
                # optimisable:
                # https://discuss.pytorch.org/t/tensor-to-device-changes-is-leaf-causing-cant-optimize-a-non-leaf-tensor/37659
                self.u_ = nn.Parameter(torch.tensor(self.u0, device=self.device))
                # NOTE(review): u/v are detached from u_/v_, so a tableau built
                # from them does not propagate gradients back to u_/v_.
                self.u = torch.clamp(self.u_, self.eps, 1. - self.eps).detach().requires_grad_(True)

                if self.parameterization == 'uv':
                    self.v0 = args.v0
                    self.v_ = nn.Parameter(torch.tensor(self.v0, device=self.device))
                    self.v = torch.clamp(self.v_, self.eps, 1. - self.eps).detach().requires_grad_(True)

            logger.info('Init | u {} | v {}'.format(self.u.data, (self.v if self.v is None else self.v.data)))

            self.alpha, self.beta, self.c_sol = rk_param_tableau(self.u, self.v, device=self.device,
                                                                 parameterization=self.parameterization,
                                                                 method=self.method)
            self.tableau = _ButcherTableau(alpha=self.alpha,
                                           beta=self.beta,
                                           c_sol=self.c_sol,
                                           c_error=torch.zeros((len(self.c_sol),), device=self.device))

    def forward(self, x):
        """Return the solution at t = 1 for the initial state ``x``."""
        self.integration_time = self.integration_time.type_as(x)

        if self.method in ['rk4_param', 'rk3_param']:
            out = odeint_plus(self.odefunc, x, self.integration_time,
                              method=self.method, options={'tableau': self.tableau, 'step_size': self.step_size})
        else:
            out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol,
                         method=self.method, options={'step_size': self.step_size})

        # out[0] is the state at t = 0; out[1] is the state at t = 1.
        return out[1]

    @property
    def nfe(self):
        """Number of function evaluations recorded by the wrapped ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value
139 |
140 |
if __name__ == "__main__":
    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=None)
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    ########### Build model
    is_odenet = args.network == 'odenet'

    downsampling_layers = build_downsampling_layers(args.downsampling_method)
    fc_layers = build_fc_layers()

    feature_layers = [ODEBlock(ODEfunc(64), args, device)] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    ########### Create data loaders
    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size, data_root=args.data_root)

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    ########### Create criterion and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001], lr0=args.lr)

    # The tableau coefficients u/v are learned only for the parametric RK
    # methods without --fix_param.
    learn_uv = is_odenet and (args.method in ['rk4_param', 'rk3_param']) and not args.fix_param
    ode_blocks = [m for m in model.modules() if isinstance(m, ODEBlock)]

    if learn_uv:
        params_uv = [p for m in ode_blocks for p in (m.u_, m.v_) if p is not None]
        # A registered u_/v_ parameter is also yielded by model.parameters(), and
        # torch.optim rejects a tensor that appears in two parameter groups.
        # Keep u_/v_ only in their dedicated --lr_uv group.
        uv_ids = {id(p) for p in params_uv}
        params_net = [p for p in model.parameters() if id(p) not in uv_ids]
        optimizer = optim.SGD([{"params": params_net},
                               {"params": params_uv, 'lr': args.lr_uv}], lr=args.lr, momentum=0.9)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    ########### Train the model
    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    for itr in range(args.nepochs * batches_per_epoch):

        # Only the network parameter group (index 0) follows the step schedule;
        # the optional u/v group keeps the learning rate given by --lr_uv.
        optimizer.param_groups[0]['lr'] = lr_fn(itr)

        if learn_uv:
            ### Rebuild the tableau of every ODEBlock from the current u_/v_.
            for m in ode_blocks:
                # Recompute u/v from the trainable parameters so that the new
                # tableau is connected to u_/v_ in the autograd graph (the copies
                # made in ODEBlock.__init__ are detached).
                if m.u_ is not None:
                    m.u = torch.clamp(m.u_, m.eps, 1. - m.eps)
                if m.v_ is not None:
                    m.v = torch.clamp(m.v_, m.eps, 1. - m.eps)
                m.alpha, m.beta, m.c_sol = rk_param_tableau(m.u, m.v, device=device,
                                                            parameterization=args.parameterization,
                                                            method=args.method)
                m.tableau = _ButcherTableau(alpha=m.alpha, beta=m.beta, c_sol=m.c_sol,
                                            c_error=torch.zeros((len(m.c_sol),), device=device))

        optimizer.zero_grad()
        x, y = next(data_gen)
        x = x.to(device)
        y = y.to(device)

        logits = model(x)

        loss = criterion(logits, y)

        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        loss.backward()
        optimizer.step()

        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        batch_time_meter.update(time.time() - end)
        if is_odenet:
            f_nfe_meter.update(nfe_forward)
            b_nfe_meter.update(nfe_backward)
        end = time.time()

        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader, device)
                val_acc = accuracy(model, test_loader, device)
                if val_acc > best_acc:
                    torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
                    best_acc = val_acc

                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "TrainAcc {:.6f} | TestAcc {:.6f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc))
252 |
253 | #CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=3 python3 runner_old.py --network 'odenet' --data_root "/workspace/raid/data/datasets" --batch_size 128 --test_batch_size 1000 --nepochs 160 --lr 0.1 --save "./experiment1" --method 'euler' --step_size 1.0 --gpu 0 --torch_seed 502 --numpy_seed 502 --fix_param
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/train_validate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from argparse import Namespace
8 | from .metrics import accuracy
9 | from sopa.src.solvers.utils import noise_params
10 | from MegaAdversarial.src.attacks import (
11 | Clean,
12 | PGD,
13 | FGSM
14 | )
15 |
# Default keyword arguments for the MegaAdversarial attacks that `train` uses to
# craft adversarial training batches ("at" -> PGD, "fgsm" -> FGSM).
# NOTE(review): eps=0.3 looks like an L-inf budget on [0, 1]-scaled MNIST
# pixels, while lr=2/255 is a CIFAR-style step size. If PGD takes n_iter steps
# of size lr without a random start, it can move each pixel by at most
# 7 * 2/255 ~= 0.055, well below eps. Confirm this is intended.
CONFIG_PGD_TRAIN = {"eps": 0.3, "lr": 2.0 / 255, "n_iter": 7}
CONFIG_FGSM_TRAIN = {"eps": 0.3}
18 |
def train(itr,
          model,
          data_gen,
          solvers,
          solver_options,
          criterion,
          optimizer,
          batch_time_meter,
          f_nfe_meter,
          b_nfe_meter,
          device='cpu',
          dtype=torch.float32,
          is_odenet=True,
          args=None,
          logger=None,
          wandb_logger=None):
    """Run one optimisation step on the next batch drawn from ``data_gen``.

    The step:

    1. optionally perturbs each solver's tableau parameters (``args.noise_type``);
    2. replaces the batch by attacked inputs according to
       ``args.adv_training_mode`` ("clean", "fgsm" or "at" = PGD), using
       ``args.eps_adv_training`` as the attack budget;
    3. optionally adds N(0, ``args.data_noise_std``) noise to the inputs;
    4. runs forward / backward / ``optimizer.step()``, adding
       ``args.ss_loss_reg * model.get_ss_loss()`` when ``args.ss_loss`` is set;
    5. restores the solver parameters and updates the meters. For ODE nets,
       forward and backward NFE are read from ``model.blocks[i].rhs_func.nfe``.

    Args:
        itr: global iteration index (not used inside the step).
        model: network. When ``is_odenet`` it is called as
            ``model(x, solvers, solver_options, Namespace(ss_loss=...))``.
        data_gen: iterator yielding ``(x, y)`` batches.
        solvers, solver_options: solvers used for the forward pass.
        criterion: classification loss applied to ``(logits, y)``.
        optimizer: optimizer stepped once.
        batch_time_meter, f_nfe_meter, b_nfe_meter: running meters to update.
        device: device the batch is moved to.
        dtype: unused; kept for signature compatibility.
        is_odenet: whether ``model`` is an ODE net with NFE counters.
        args: parsed CLI namespace (noise, attack and loss options).
        logger: currently unused.
        wandb_logger: if given, receives this step's loss values.

    Raises:
        ValueError: if ``args.adv_training_mode`` is not "clean", "fgsm" or "at".
    """
    end = time.time()

    x, y = next(data_gen)
    x = x.to(device)
    y = y.to(device)

    ##### Noise params
    if args.noise_type is not None:
        for solver in solvers:
            solver.u, solver.v = noise_params(solver.u0,
                                              solver.v0,
                                              std=args.noise_sigma,
                                              bernoulli_p=args.noise_prob,
                                              noise_type=args.noise_type)
            solver.build_ButcherTableau()

    ##### Adversarial examples (budget taken from --eps_adv_training)
    if args.adv_training_mode == "clean":
        train_attack = Clean(model)
    elif args.adv_training_mode == "fgsm":
        train_attack = FGSM(model, **{**CONFIG_FGSM_TRAIN, "eps": args.eps_adv_training})
    elif args.adv_training_mode == "at":
        train_attack = PGD(model, **{**CONFIG_PGD_TRAIN, "eps": args.eps_adv_training})
    else:
        raise ValueError("Attack type not understood.")
    x, y = train_attack(x, y, {"solvers": solvers, "solver_options": solver_options})

    # Clear gradients only after the attack. Crafting FGSM/PGD examples may
    # back-propagate through the model, and any gradient it leaves on the
    # parameters must not leak into this update.
    optimizer.zero_grad()

    if is_odenet:
        # Model evaluations spent crafting the attack are not part of the
        # training step's forward/backward NFE.
        for block in model.blocks:
            block.rhs_func.nfe = 0

    # Add noise:
    if args.data_noise_std > 1e-12:
        with torch.no_grad():
            x = x + args.data_noise_std * torch.randn_like(x)

    ##### Forward pass
    if is_odenet:
        logits = model(x, solvers, solver_options, Namespace(ss_loss=args.ss_loss))
    else:
        logits = model(x)

    xentropy = criterion(logits, y)
    if args.ss_loss:
        ss_loss = model.get_ss_loss()
        loss = xentropy + args.ss_loss_reg * ss_loss
    else:
        ss_loss = 0.
        loss = xentropy
    if wandb_logger is not None:
        wandb_logger.log({"xentropy": xentropy.item(),
                          # float() turns a scalar tensor into a plain number for logging.
                          "ss_loss": float(ss_loss),
                          "loss": loss.item(),
                          "log_func": "train"})

    ##### Compute NFE-forward
    if is_odenet:
        nfe_forward = 0
        for block in model.blocks:
            nfe_forward += block.rhs_func.nfe
            block.rhs_func.nfe = 0

    loss.backward()
    optimizer.step()

    ##### Compute NFE-backward
    if is_odenet:
        nfe_backward = 0
        for block in model.blocks:
            nfe_backward += block.rhs_func.nfe
            block.rhs_func.nfe = 0

    ##### Denoise params: restore the unperturbed tableau
    if args.noise_type is not None:
        for solver in solvers:
            solver.u, solver.v = solver.u0, solver.v0
            solver.build_ButcherTableau()

    batch_time_meter.update(time.time() - end)
    if is_odenet:
        f_nfe_meter.update(nfe_forward)
        b_nfe_meter.update(nfe_backward)
116 |
117 |
def validate_standalone(best_acc,
                        itr,
                        model,
                        train_eval_loader,
                        test_loader,
                        batches_per_epoch,
                        solvers,
                        solver_options,
                        batch_time_meter,
                        f_nfe_meter,
                        b_nfe_meter,
                        device='cpu',
                        dtype=torch.float32,
                        args=None,
                        logger=None,
                        wandb_logger=None):
    """Evaluate every solver on its own and checkpoint per-solver improvements.

    For each solver index ``k``, train-eval and test accuracy are computed with
    ``[solvers[k]]`` alone. When the test accuracy beats ``best_acc[k]``, the
    entry is updated in place and the model is saved to
    ``<args.save>/<args.timestamp>/model_best_<k>.pth`` (also uploaded through
    ``wandb_logger.save`` when a wandb logger is given). Results are reported
    to ``logger`` / ``wandb_logger`` if provided. Finally, the NFE counters of
    all ``model.blocks`` are reset to 0.

    Returns:
        The ``best_acc`` list (mutated in place).
    """
    epoch = itr // batches_per_epoch
    run_dir = None if args is None else os.path.join(args.save, str(args.timestamp))

    with torch.no_grad():
        for solver_id, solver in enumerate(solvers):
            single = [solver]
            train_acc = accuracy(model, train_eval_loader, device, single, solver_options)
            test_acc = accuracy(model, test_loader, device, single, solver_options)

            if test_acc > best_acc[solver_id]:
                best_acc[solver_id] = test_acc

                ckpt_path = os.path.join(run_dir, 'model_best_{}.pth'.format(solver_id))
                torch.save({'state_dict': model.state_dict(),
                            'args': args,
                            'solver_id': solver_id,
                            'val_solver_mode': solver_options.solver_mode,
                            'acc': test_acc},
                           ckpt_path)
                if wandb_logger is not None:
                    wandb_logger.save(ckpt_path)

            if logger is not None:
                logger.info("Epoch {:04d} | SolverMode {} | SolverId {} | "
                            "TrainAcc {:.10f} | TestAcc {:.10f} | BestAcc {:.10f}".format(
                                epoch, solver_options.solver_mode, solver_id,
                                train_acc, test_acc, best_acc[solver_id]))
            if wandb_logger is not None:
                wandb_logger.log({
                    "epoch": epoch,
                    "solver_mode": solver_options.solver_mode,
                    "solver_id": solver_id,
                    "train_acc": train_acc,
                    "test_acc": test_acc,
                    "best_acc": best_acc[solver_id],
                    "log_func": "validate_standalone"
                })

    # Evaluation must not leak into the next training step's NFE statistics.
    for block in model.blocks:
        block.rhs_func.nfe = 0

    return best_acc
183 |
184 |
185 |
def validate_ensemble_switch(best_acc,
                             itr,
                             model,
                             train_eval_loader,
                             test_loader,
                             batches_per_epoch,
                             solvers,
                             solver_options,
                             batch_time_meter,
                             f_nfe_meter,
                             b_nfe_meter,
                             device='cpu',
                             dtype=torch.float32,
                             args=None,
                             logger=None,
                             wandb_logger=None):
    """Evaluate the model with all ``solvers`` combined (ensemble or switch mode).

    ``solver_options.solver_mode`` selects how the solvers are combined.

    When the test accuracy beats ``best_acc``, the model is saved under
    ``<args.save>/<args.timestamp>/``:

    * ``model_best.pth`` — the legacy name, shared by every ensemble/switch mode,
      so the last improving mode wins;
    * ``model_best_<solver_mode>.pth`` — a per-mode copy that other modes
      cannot overwrite.

    Both files are also uploaded via ``wandb_logger.save`` when a wandb logger
    is given. Results are reported to ``logger`` / ``wandb_logger`` if
    provided, and the NFE counters of ``model.blocks`` are reset to 0.

    Returns:
        The updated best test accuracy (a number).
    """
    with torch.no_grad():

        train_acc = accuracy(model, train_eval_loader, device, solvers, solver_options)
        val_acc = accuracy(model, test_loader, device, solvers, solver_options)

        if val_acc > best_acc:
            best_acc = val_acc

            checkpoint = {'state_dict': model.state_dict(),
                          'args': args,
                          'solver_id': None,
                          'val_solver_mode': solver_options.solver_mode,
                          'acc': val_acc}
            save_dir = os.path.join(args.save, str(args.timestamp))
            for filename in ('model_best.pth',
                             'model_best_{}.pth'.format(solver_options.solver_mode)):
                ckpt_path = os.path.join(save_dir, filename)
                torch.save(checkpoint, ckpt_path)
                if wandb_logger is not None:
                    wandb_logger.save(ckpt_path)

        if logger is not None:
            logger.info("Epoch {:04d} | SolverMode {} | SolverId {} | "
                        "TrainAcc {:.10f} | TestAcc {:.10f} | BestAcc {:.10f}".format(
                            itr // batches_per_epoch, solver_options.solver_mode, None,
                            train_acc, val_acc, best_acc))
        if wandb_logger is not None:
            wandb_logger.log({
                "epoch": itr // batches_per_epoch,
                "solver_mode": solver_options.solver_mode,
                "solver_id": None,
                "train_acc": train_acc,
                "test_acc": val_acc,
                "best_acc": best_acc,
                "log_func": "validate_ensemble_switch"
            })

    # Evaluation must not leak into the next training step's NFE statistics.
    for block in model.blocks:
        block.rhs_func.nfe = 0

    return best_acc
247 |
248 |
249 |
def validate(best_acc,
             itr,
             model,
             train_eval_loader,
             test_loader,
             batches_per_epoch,
             solvers,
             val_solver_modes,
             batch_time_meter,
             f_nfe_meter,
             b_nfe_meter,
             device='cpu',
             dtype=torch.float32,
             args=None,
             logger=None,
             wandb_logger=None):
    """Run validation for each requested solver mode and log timing/NFE stats.

    Args:
        best_acc: dict with keys 'standalone' (list, one entry per solver),
            'ensemble' and 'switch' (numbers). Updated and returned.
        val_solver_modes: iterable of modes, each one of 'standalone',
            'ensemble' or 'switch'.
        args: CLI namespace. ``ensemble_weights``/``ensemble_prob`` are read
            for 'ensemble' and ``switch_probs`` for 'switch'.
        (other arguments are forwarded to the per-mode validators)

    Returns:
        ``best_acc``.

    Raises:
        ValueError: if ``val_solver_modes`` contains an unknown mode. This is
            checked before any evaluation, so a typo no longer silently skips
            validation.
    """
    known_modes = ('standalone', 'ensemble', 'switch')
    unknown_modes = [mode for mode in val_solver_modes if mode not in known_modes]
    if unknown_modes:
        raise ValueError('Unknown validation solver mode(s) {}; expected any of {}'.format(
            unknown_modes, list(known_modes)))

    # Arguments shared by every per-mode validator call.
    common_kwargs = dict(solvers=solvers,
                         batch_time_meter=batch_time_meter,
                         f_nfe_meter=f_nfe_meter,
                         b_nfe_meter=b_nfe_meter,
                         device=device,
                         dtype=dtype,
                         args=args,
                         logger=logger,
                         wandb_logger=wandb_logger)

    for solver_mode in val_solver_modes:
        if solver_mode == 'standalone':
            val_solver_options = Namespace(solver_mode='standalone')
            best_acc['standalone'] = validate_standalone(best_acc['standalone'],
                                                         itr,
                                                         model,
                                                         train_eval_loader,
                                                         test_loader,
                                                         batches_per_epoch,
                                                         solver_options=val_solver_options,
                                                         **common_kwargs)
        elif solver_mode == 'ensemble':
            val_solver_options = Namespace(solver_mode='ensemble',
                                           ensemble_weights=args.ensemble_weights,
                                           ensemble_prob=args.ensemble_prob)
            best_acc['ensemble'] = validate_ensemble_switch(best_acc['ensemble'],
                                                            itr,
                                                            model,
                                                            train_eval_loader,
                                                            test_loader,
                                                            batches_per_epoch,
                                                            solver_options=val_solver_options,
                                                            **common_kwargs)
        else:  # 'switch' (guaranteed by the check above)
            val_solver_options = Namespace(solver_mode='switch', switch_probs=args.switch_probs)
            best_acc['switch'] = validate_ensemble_switch(best_acc['switch'],
                                                          itr,
                                                          model,
                                                          train_eval_loader,
                                                          test_loader,
                                                          batches_per_epoch,
                                                          solver_options=val_solver_options,
                                                          **common_kwargs)

    if logger is not None:
        logger.info("Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f}".format(
            itr // batches_per_epoch,
            batch_time_meter.val, batch_time_meter.avg,
            f_nfe_meter.avg, b_nfe_meter.avg))
    if wandb_logger is not None:
        wandb_logger.log({
            "epoch": itr // batches_per_epoch,
            "batch_time_val": batch_time_meter.val,
            "nfe": f_nfe_meter.avg,
            "nbe": b_nfe_meter.avg,
            "log_func": "validate"
        })
    return best_acc
--------------------------------------------------------------------------------
/sopa/src/models/odenet_mnist/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import logging
4 |
5 |
def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates, lr0):
    """Build a piecewise-constant learning-rate schedule.

    The base rate is ``lr0 * batch_size / batch_denom`` (linear batch-size
    scaling). ``boundary_epochs`` are converted to iteration counts, and
    ``decay_rates`` gives the multiplier for each interval. It needs one more
    entry than ``boundary_epochs``: the last entry applies after the final
    boundary.

    Returns:
        A function ``itr -> lr`` returning ``base * decay_rates[k]``, where ``k``
        is the index of the first boundary strictly greater than ``itr``, or
        ``len(boundary_epochs)`` when ``itr`` is past every boundary.
    """
    base_lr = lr0 * batch_size / batch_denom
    boundary_iters = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    schedule = [base_lr * rate for rate in decay_rates]

    def learning_rate_fn(itr):
        stage = next((k for k, boundary in enumerate(boundary_iters) if itr < boundary),
                     len(boundary_iters))
        return schedule[stage]

    return learning_rate_fn
19 |
20 |
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
23 |
24 |
def makedirs(dirname):
    """Create ``dirname`` and any missing parents; do nothing if it already exists.

    ``exist_ok=True`` removes the race between a separate existence check and
    the creation, e.g. when several runs share one output directory.
    """
    os.makedirs(dirname, exist_ok=True)
28 |
29 |
def get_logger(logpath, filepath, package_files=None, displaying=True, saving=True, debug=False):
    """Configure and return the root logger.

    Args:
        logpath: file that records are appended to when ``saving`` is True.
        filepath: if not None, the contents of this file are logged (the path
            itself is always logged first).
        package_files: optional iterable of extra file paths whose names and
            contents are logged.
        displaying: also emit records to the console (stderr).
        saving: also append records to ``logpath``.
        debug: use DEBUG instead of INFO as the level.

    Calling this more than once (e.g. from a module-level setup and again under
    ``__main__``) no longer stacks duplicate handlers on the root logger. An
    existing file handler for the same ``logpath``, or an existing plain console
    ``StreamHandler``, is reused, with its level updated.
    """
    logger = logging.getLogger()
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)

    if saving:
        log_abspath = os.path.abspath(logpath)
        file_handler = next((h for h in logger.handlers
                             if isinstance(h, logging.FileHandler) and h.baseFilename == log_abspath),
                            None)
        if file_handler is None:
            file_handler = logging.FileHandler(logpath, mode="a")
            logger.addHandler(file_handler)
        file_handler.setLevel(level)
    if displaying:
        # `type(...) is` rather than isinstance: FileHandler subclasses StreamHandler.
        console_handler = next((h for h in logger.handlers if type(h) is logging.StreamHandler), None)
        if console_handler is None:
            console_handler = logging.StreamHandler()
            logger.addHandler(console_handler)
        console_handler.setLevel(level)
    logger.info(filepath)

    if filepath is not None:
        with open(filepath, "r") as f:
            logger.info(f.read())

    for f in package_files or ():
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger
--------------------------------------------------------------------------------
/sopa/src/models/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 |
5 | from .odenet_mnist.layers import MetaNODE
6 |
def fix_seeds(seed=502):
    """Seed every RNG used in training and make cuDNN deterministic.

    Seeds NumPy, Python's ``random``, PyTorch (CPU) and all CUDA devices with
    ``seed``. It also sets torch's print precision to 10 digits, forces
    deterministic cuDNN kernels and disables cuDNN benchmarking.
    """
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.set_printoptions(precision=10)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
16 |
class RunningAverageMeter(object):
    """Keep the latest value and an exponential moving average of a stream.

    The first ``update`` seeds the average with that value. Later updates blend
    as ``avg = momentum * avg + (1 - momentum) * val``.
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all history: no latest value and an average of 0."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Record ``val`` as the latest value and fold it into the moving average."""
        is_first = self.val is None
        self.avg = val if is_first else self.momentum * self.avg + (1 - self.momentum) * val
        self.val = val
34 |
35 |
def load_model(path):
    """Rebuild a ``MetaNODE`` from a checkpoint written by the training scripts.

    The checkpoint must be a dict with at least 'state_dict' and 'args' (the
    argparse namespace used for training). Extra keys are ignored, such as
    'solver_id', 'val_solver_mode' and 'acc' written by the validation code.
    Previously the keys were unpacked positionally, which failed for any
    checkpoint that did not have exactly three entries.

    Returns:
        ``(model, model_args)``: the model on CPU with loaded weights, and the
        training arguments.

    NOTE: ``torch.load`` unpickles arbitrary objects; only load trusted checkpoints.
    """
    checkpoint = torch.load(path, map_location='cpu')
    state_dict = checkpoint['state_dict']
    model_args = checkpoint['args']

    is_odenet = model_args.network == 'odenet'

    # Checkpoints without --in_channels fall back to single-channel input.
    if not hasattr(model_args, 'in_channels'):
        model_args.in_channels = 1

    model_kwargs = dict(downsampling_method=model_args.downsampling_method,
                        is_odenet=is_odenet,
                        in_channels=model_args.in_channels)
    # Rebuild with the activation the model was trained with. Checkpoints
    # without --activation keep MetaNODE's default activation.
    if hasattr(model_args, 'activation'):
        model_kwargs['activation_type'] = model_args.activation

    model = MetaNODE(**model_kwargs)
    model.load_state_dict(state_dict)

    return model, model_args
--------------------------------------------------------------------------------
/sopa/src/solvers/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/sopa/src/solvers/euler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .rk_parametric import RKParametricSolver
5 |
6 |
class Euler(RKParametricSolver):
    """Explicit (forward) Euler method expressed as a one-stage RK solver.

    The tableau is fixed (c = [0], w = [[0]], b = [1]), so the solver has no
    trainable parameters. ``parameterization``, ``u0`` and ``v0`` are
    accepted so the constructor matches the parametric solvers, but they
    are ignored. Any remaining keyword arguments (time-grid options) are
    forwarded to ``RKParametricSolver``.
    """

    def __init__(self, parameterization=None, u0=None, v0=None, dtype=None, device=None, **kwargs):
        super().__init__(**kwargs)
        self.dtype = dtype
        self.device = device

        # No free parameters: every parameter slot is deliberately empty.
        self.parameterization = None
        self.u = self.u0 = None
        self.v = self.v0 = None

        self.build_ButcherTableau()

    def _const(self, value):
        """Return a 1-element tensor holding ``value`` with the solver's dtype/device."""
        return torch.tensor((value,), dtype=self.dtype, device=self.device)

    def _compute_c(self):
        self.c1 = self._const(0.)

    def _compute_b(self):
        self.b1 = self._const(1.)

    def _compute_w(self):
        self.w11 = self._const(0.)

    def _make_u_valid(self, eps):
        # Nothing to clamp: Euler has no ``u`` parameter.
        return None

    def _make_params_valid(self):
        # Nothing to validate: the tableau is constant.
        return None

    def _get_c(self):
        return torch.tensor([self.c1])

    def _get_w(self):
        return [torch.tensor([self.w11])]

    def _get_b(self):
        return torch.tensor([self.b1])

    def _get_t(self, t, dt):
        # The single stage is evaluated at the start of the interval.
        return t

    def _make_step(self, rhs_func, x, t, dt):
        """Return the Euler increment ``(f(t, x) * b1) * dt``."""
        slope = rhs_func(self._get_t(t, dt), x)
        return (slope * self.b1) * dt

    def freeze_params(self):
        # No trainable parameters to freeze.
        return None

    def unfreeze_params(self):
        # No trainable parameters to unfreeze.
        return None

    @property
    def order(self):
        return 1
82 |
--------------------------------------------------------------------------------
/sopa/src/solvers/rk_parametric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import abc
4 |
class RKParametricSolver(object, metaclass=abc.ABCMeta):
    """Abstract base for explicit Runge-Kutta solvers with a (possibly trainable) tableau.

    Subclasses supply the tableau (``_compute_c``/``_compute_b``/``_compute_w``
    plus the ``_get_*`` collectors), parameter validation and the single-step
    update ``_make_step``. This class supplies the time grid and the
    fixed-grid integration loop.

    At most one of ``n_steps``, ``step_size`` and ``grid_constructor`` may be
    given. If none is given, the integration grid is the requested times
    ``t`` themselves.

    Raises:
        ValueError: if two or more grid options are given.
    """

    def __init__(self, n_steps=None, step_size=None, grid_constructor=None):
        # Count the grid options that were actually supplied. This uses
        # ``is not None`` rather than ``filter(None.__ne__, ...)``:
        # ``None.__ne__(x)`` returns ``NotImplemented`` for non-None ``x``, and
        # using ``NotImplemented`` as a boolean is deprecated (and rejected by
        # newer Python versions).
        n_grid_args = sum(arg is not None for arg in (n_steps, step_size, grid_constructor))
        if n_grid_args >= 2:
            raise ValueError("n_steps, step_size and grid_constructor are pairwise exclusive arguments.")

        if n_steps is not None:
            self.grid_constructor = self._grid_constructor_from_n_steps(n_steps)
        elif step_size is not None:
            self.grid_constructor = self._grid_constructor_from_step_size(step_size)
        elif grid_constructor is not None:
            self.grid_constructor = grid_constructor
        else:
            # No option given: integrate exactly on the requested times.
            self.grid_constructor = lambda t: t

    def _grid_constructor_from_step_size(self, step_size):
        """Return a grid builder with spacing ``step_size`` from ``t[0]`` to ``t[-1]``.

        The last node is always pinned to ``t[-1]``. If the uniform grid
        overshoots, the final interval is shorter than ``step_size``. If float
        rounding leaves the last node slightly short of ``t[-1]``, it is
        snapped, so the endpoint check in ``integrate`` does not fail
        spuriously.
        """

        def _grid_constructor(t):
            start_time = t[0]
            end_time = t[-1]

            n_steps = torch.ceil((end_time - start_time) / step_size + 1).item()
            t_infer = torch.arange(0, n_steps).to(t) * step_size + start_time
            t_infer[-1] = t[-1]
            return t_infer

        return _grid_constructor

    def _grid_constructor_from_n_steps(self, n_steps):
        """Return a grid builder with ``n_steps`` equal intervals from ``t[0]`` to ``t[-1]``."""

        def _grid_constructor(t):
            start_time = t[0]
            end_time = t[-1]

            t_infer = torch.linspace(start_time, end_time, n_steps + 1).to(t)
            return t_infer

        return _grid_constructor

    @property
    @abc.abstractmethod
    def order(self):
        pass

    @abc.abstractmethod
    def freeze_params(self):
        pass

    @abc.abstractmethod
    def unfreeze_params(self):
        pass

    @abc.abstractmethod
    def _make_params_valid(self):
        pass

    def build_ButcherTableau(self, return_tableau=False):
        """(Re)compute the tableau coefficients from the current parameters.

        Args:
            return_tableau: if True, also return ``(c, w, b)`` as collected by
                the subclass's ``_get_c``/``_get_w``/``_get_b``.
        """
        self._make_params_valid()
        self._compute_c()
        self._compute_b()
        self._compute_w()

        if return_tableau:
            return self._collect_ButcherTableau()

    def _collect_ButcherTableau(self):
        """Return the tableau as ``(c, w, b)``."""
        c = self._get_c()
        w = self._get_w()
        b = self._get_b()
        return c, w, b

    @abc.abstractmethod
    def _make_step(self, rhs_func, x, t, dt):
        pass

    def integrate(self, rhs_func, x, t):
        """Integrate ``dx/dt = rhs_func(t, x)`` from ``t[0]`` and report the state at each ``t[j]``.

        Steps are taken on ``self.grid_constructor(t)``. States at requested
        times that fall strictly inside a grid interval are linearly
        interpolated between the interval's endpoints.

        Args:
            rhs_func: callable ``(t, x) -> dx/dt``.
            x: initial state at ``t[0]``.
            t: 1-D tensor of requested times; presumably increasing (no check
                is performed).

        Returns:
            ``torch.stack`` of the states at ``t``; element ``j`` is the state
            at ``t[j]``.
        """
        t = t.type_as(x[0])

        time_grid = self.grid_constructor(t)

        # The grid must start and end exactly on the requested interval.
        assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
        time_grid = time_grid.to(x[0])

        solution = [x]

        j = 1
        y0 = x
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self._make_step(rhs_func, y0, t0, t1 - t0)
            y1 = y0 + dy

            # Emit every requested time that this step [t0, t1] has reached.
            while j < len(t) and t1 >= t[j]:
                solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
                j += 1
            y0 = y1
        return torch.stack(solution)

    def _linear_interp(self, t0, t1, y0, y1, t):
        """Linearly interpolate between ``(t0, y0)`` and ``(t1, y1)`` at time ``t``.

        Endpoints are returned as-is to avoid round-off.
        """
        if t == t0:
            return y0
        if t == t1:
            return y1
        t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0])
        slope = (y1 - y0) / (t1 - t0)
        return y0 + slope * (t - t0)

    def print_is_requires_grad(self):
        """Print ``requires_grad`` for every instance attribute that has that flag."""
        print('\nIs requires grad? (RK solver)')

        for pname, p in self.__dict__.items():
            if hasattr(p, 'requires_grad'):
                print(pname, p.requires_grad)
132 |
--------------------------------------------------------------------------------
/sopa/src/solvers/rk_parametric_order2stage2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .rk_parametric import RKParametricSolver
5 |
def build_ButcherTableau_Midpoint(dtype=None, device=None):
    """Return the Butcher tableau ``(c, w, b)`` of the explicit midpoint method.

    ``w`` is a list of per-stage rows of the coupling matrix; row ``i`` has
    ``i + 1`` entries (diagonal included). All tensors use ``dtype``/``device``.
    """
    def make(values):
        return torch.tensor(values, dtype=dtype, device=device)

    nodes = make([0., 1/2.])
    coupling = [make(row) for row in ([0.], [1/2., 0.])]
    weights = make([0., 1.])
    return nodes, coupling, weights
11 |
12 |
def build_ButcherTableau_Heun(dtype=None, device=None):
    """Return the Butcher tableau ``(c, w, b)`` of Heun's second-order method.

    ``w`` is a list of per-stage rows of the coupling matrix; row ``i`` has
    ``i + 1`` entries (diagonal included). All tensors use ``dtype``/``device``.
    """
    def make(values):
        return torch.tensor(values, dtype=dtype, device=device)

    nodes = make([0., 1.])
    coupling = [make(row) for row in ([0.], [1., 0.])]
    weights = make([1/2., 1/2.])
    return nodes, coupling, weights
18 |
19 |
class RKOrder2Stage2(RKParametricSolver):
    """Two-stage, second-order explicit RK solver with one trainable node ``u``.

    Tableau: ``c = [0, u]``, ``w21 = u``, ``b = [1 - 1/(2u), 1/(2u)]``.
    ``u = 1/2`` gives the explicit midpoint rule and ``u = 1`` gives Heun's
    method.

    Args:
        parameterization: must be ``'u'``.
        u0: initial value of the trainable parameter ``u``.
        v0: unused; accepted so all solvers share one constructor signature.
        dtype, device: dtype/device of all tableau tensors.
        **kwargs: time-grid options forwarded to ``RKParametricSolver``.

    Raises:
        ValueError: if ``parameterization`` is not ``'u'``.
    """

    def __init__(self, parameterization='u', u0=None, v0=None, dtype=None, device=None, **kwargs):
        super(RKOrder2Stage2, self).__init__(**kwargs)
        self.dtype = dtype
        self.device = device

        if parameterization != 'u':
            raise ValueError('Unknown parameterization for RKOrder2Stage2 solver')
        self.parameterization = parameterization
        # ``u`` is trainable; ``u0`` keeps the initial value as a plain tensor.
        self.u = nn.Parameter(data=torch.tensor((u0,), dtype=self.dtype, device=self.device))
        self.u0 = torch.tensor((u0,), dtype=self.dtype, device=self.device)
        self.v = None
        self.v0 = None

        self.build_ButcherTableau()

    def _compute_c(self):
        self.c1 = torch.tensor((0.,), dtype=self.dtype, device=self.device)
        self.c2 = self.u_.clone()

    def _compute_b(self):
        # Order conditions b1 + b2 = 1 and b2 * c2 = 1/2 with c2 = u.
        self.b2 = 1. / (2 * self.u_)
        self.b1 = 1. - self.b2

    def _compute_w(self):
        self.w21 = self.c2
        self.w11, self.w22 = [torch.tensor((0.,), dtype=self.dtype, device=self.device) for _ in range(2)]

    def _make_u_valid(self, eps):
        # Keep u in [eps, 1] so that 1 / (2 * u_) stays finite.
        self.u_ = torch.clamp(self.u, eps, 1.)

    def _make_params_valid(self):
        """Clamp ``u`` into its admissible range and store the result as ``u_``.

        The clamp margin is the machine epsilon of the next-lower precision:
        float32's for float64 parameters, float16's for float32. Any other
        floating dtype uses its own epsilon, so ``eps`` is always bound.
        """
        if self.u.dtype == torch.float64:
            eps = torch.finfo(torch.float32).eps
        elif self.u.dtype == torch.float32:
            eps = torch.finfo(torch.float16).eps
        else:
            eps = torch.finfo(self.u.dtype).eps

        self._make_u_valid(eps)

    def _get_c(self):
        c = torch.tensor([self.c1, self.c2])
        return c

    def _get_w(self):
        w = [torch.tensor([self.w11])] + [torch.tensor([self.w21, self.w22])]
        return w

    def _get_b(self):
        b = torch.tensor([self.b1, self.b2])
        return b

    def _get_t(self, t, dt):
        """Return the stage times ``(t, t + c2 * dt)``."""
        t0 = t
        t1 = t + self.c2 * dt
        return (t0, t1)

    def _make_step(self, rhs_func, x, t, dt):
        """Return the increment ``dt * (b1 * k1 + b2 * k2)`` for one step."""
        t0, t1 = self._get_t(t, dt)

        k1 = rhs_func(t0, x)
        k2 = rhs_func(t1, x + k1 * self.w21 * dt)

        return (k1 * self.b1 + k2 * self.b2) * dt

    def freeze_params(self):
        self.u.requires_grad = False
        if self.v is not None:
            self.v.requires_grad = False

        self.build_ButcherTableau()  # recompute params to set non leaf requires_grad to False

    def unfreeze_params(self):
        self.u.requires_grad = True
        if self.v is not None:
            self.v.requires_grad = True

        self.build_ButcherTableau()  # recompute params to set non leaf requires_grad to True

    @property
    def order(self):
        return 2
--------------------------------------------------------------------------------
/sopa/src/solvers/rk_parametric_order3stage3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .rk_parametric import RKParametricSolver
5 |
6 |
class RKOrder3Stage3(RKParametricSolver):
    """Three-stage, third-order explicit RK solver with trainable nodes ``u`` and ``v``.

    Nodes are ``c = [0, u, v]``. The weights ``b`` and the coupling
    coefficients ``w`` are derived from ``u`` and ``v`` in ``_compute_b``
    and ``_compute_w``. With the defaults ``u0 = 1/3`` and ``v0 = 2/3`` the
    weights are ``b = [1/4, 0, 3/4]``.

    Args:
        parameterization: must be ``'uv'``.
        u0, v0: initial values of the trainable parameters ``u`` and ``v``.
        dtype, device: dtype/device of all tableau tensors.
        **kwargs: time-grid options forwarded to ``RKParametricSolver``.

    Raises:
        ValueError: if ``parameterization`` is not ``'uv'``.
    """

    def __init__(self, parameterization='uv', u0=1/3., v0=2/3., dtype=None, device=None, **kwargs):
        super(RKOrder3Stage3, self).__init__(**kwargs)
        self.dtype = dtype
        self.device = device

        if parameterization != 'uv':
            raise ValueError('Unknown parameterization for RKOrder3Stage3 solver')
        self.parameterization = parameterization
        # ``u``/``v`` are trainable; ``u0``/``v0`` keep the initial values as plain tensors.
        self.u = nn.Parameter(data=torch.tensor((u0,), dtype=self.dtype, device=self.device))
        self.u0 = torch.tensor((u0,), dtype=self.dtype, device=self.device)

        self.v = nn.Parameter(data=torch.tensor((v0,), dtype=self.dtype, device=self.device))
        self.v0 = torch.tensor((v0,), dtype=self.dtype, device=self.device)

        self.build_ButcherTableau()

    def _compute_c(self):
        self.c1 = torch.tensor((0.,), dtype=self.dtype, device=self.device)
        self.c2 = self.u_.clone()
        self.c3 = self.v_.clone()

    def _compute_b(self):
        # Divides by (v - u); _make_params_valid guarantees u_ != v_.
        v_sub_u = self.v_ - self.u_

        self.b2 = (2. - 3. * self.v_) / (6. * self.u_ * (-v_sub_u))
        self.b3 = (2. - 3. * self.u_) / (6. * self.v_ * v_sub_u)
        self.b1 = 1. - self.b2 - self.b3

    def _compute_w(self):
        # NOTE(review): divides by (2 - 3u); u = 2/3 is not excluded by the
        # clamping in _make_params_valid -- confirm whether it needs guarding.
        self.w32 = self.v_ * (self.v_ - self.u_) / (self.u_ * (2. - 3. * self.u_))
        self.w31 = self.c3 - self.w32
        self.w21 = self.c2

        self.w11, self.w22, self.w33 = [torch.tensor((0.,), dtype=self.dtype, device=self.device) for _ in range(3)]

    def _make_u_valid(self, eps):
        self.u_ = torch.clamp(self.u, eps, 1.)

    def _make_v_valid(self, eps):
        self.v_ = torch.clamp(self.v, eps, 1.)

    def _make_params_valid(self):
        """Clamp ``u`` and ``v`` into ``[eps, 1]`` (as ``u_``/``v_``) and keep them distinct.

        The margin ``eps`` is float32's epsilon for float64 parameters and
        float16's for float32. Any other floating dtype uses its own epsilon,
        so ``eps`` is always bound.
        """
        if self.u.dtype == torch.float64:
            eps = torch.finfo(torch.float32).eps
        elif self.u.dtype == torch.float32:
            eps = torch.finfo(torch.float16).eps
        else:
            eps = torch.finfo(self.u.dtype).eps

        self._make_u_valid(eps)
        self._make_v_valid(eps)

        # b2 and b3 divide by (v - u): nudge the nodes apart if they coincide.
        if self.u_ == self.v_:
            if self.u_ < 1. - eps:
                self.v_ = self.u_ + eps
            else:
                self.u_ = self.v_ - eps

    def _get_c(self):
        c = torch.tensor([self.c1, self.c2, self.c3])
        return c

    def _get_w(self):
        w = [torch.tensor([self.w11])] + [
            torch.tensor([self.w21, self.w22])] + [
            torch.tensor([self.w31, self.w32, self.w33])]
        return w

    def _get_b(self):
        b = torch.tensor([self.b1, self.b2, self.b3])
        return b

    def _get_t(self, t, dt):
        """Return the stage times ``(t, t + c2 * dt, t + c3 * dt)``."""
        t0 = t
        t1 = t + self.c2 * dt
        t2 = t + self.c3 * dt

        return (t0, t1, t2)

    def _make_step(self, rhs_func, x, t, dt):
        """Return the increment ``dt * (b1 * k1 + b2 * k2 + b3 * k3)`` for one step."""
        t0, t1, t2 = self._get_t(t, dt)

        k1 = rhs_func(t0, x)
        k2 = rhs_func(t1, x + k1 * self.w21 * dt)
        k3 = rhs_func(t2, x + (k1 * self.w31 + k2 * self.w32) * dt)

        return (k1 * self.b1 + k2 * self.b2 + k3 * self.b3) * dt

    def freeze_params(self):
        self.u.requires_grad = False
        if self.v is not None:
            self.v.requires_grad = False

        self.build_ButcherTableau()  # recompute params to set non leaf requires_grad to False

    def unfreeze_params(self):
        self.u.requires_grad = True
        if self.v is not None:
            self.v.requires_grad = True

        self.build_ButcherTableau()  # recompute params to set non leaf requires_grad to True

    @property
    def order(self):
        return 3
--------------------------------------------------------------------------------
/sopa/src/solvers/rk_parametric_order4stage4.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .rk_parametric import RKParametricSolver
5 |
def build_ButcherTableau_RKStandard(dtype=None, device=None):
    """Return the Butcher tableau ``(c, w, b)`` of the classical fourth-order RK method.

    ``w`` is a list of per-stage rows of the coupling matrix; row ``i`` has
    ``i + 1`` entries (diagonal included). Every tensor, including ``b``
    (which previously ignored these arguments), uses ``dtype``/``device``.
    """
    c = torch.tensor([0., 1/2., 1/2., 1.], dtype=dtype, device=device)
    w = [torch.tensor(w_i, dtype=dtype, device=device)
         for w_i in ([0.], [1/2., 0.], [0., 1/2., 0.], [0., 0., 1., 0.])]
    b = torch.tensor([1/6., 1/3., 1/3., 1/6.], dtype=dtype, device=device)
    return c, w, b
11 |
12 |
def build_ButcherTableau_RK38(dtype=None, device=None):
    """Return the Butcher tableau ``(c, w, b)`` of the 3/8-rule fourth-order RK method.

    ``w`` is a list of per-stage rows of the coupling matrix; row ``i`` has
    ``i + 1`` entries (diagonal included). All tensors use ``dtype``/``device``.
    """
    def make(values):
        return torch.tensor(values, dtype=dtype, device=device)

    rows = ([0.], [1/3., 0.], [-1/3., 1., 0.], [1., -1., 1., 0.])
    nodes = make([0., 1/3., 2/3., 1.])
    coupling = [make(row) for row in rows]
    weights = make([1/8., 3/8., 3/8., 1/8.])
    return nodes, coupling, weights
18 |
19 |
class RKOrder4Stage4(RKParametricSolver):
    """Four-stage, fourth-order explicit RK solver with a parametric tableau.

    Supported parameterizations:

    * ``'u1'``: ``c = [0, 1/2, 0, 1]``, ``b = [1/6 - u, 2/3, u, 1/6]``
    * ``'u2'``: ``c = [0, 1/2, 1/2, 1]``, ``b = [1/6, 2/3 - u, u, 1/6]``
    * ``'u3'``: ``c = [0, 1, 1/2, 1]``, ``b = [1/6, 1/6 - u, 2/3, u]``
    * ``'uv'``: ``c = [0, u, v, 1]`` with ``b`` derived from ``u`` and ``v``

    The remaining coefficients ``w`` are solved from ``b`` and ``c`` in
    ``_compute_w``. The defaults (``'u2'``, ``u0 = 1/3``) reproduce the
    classical RK4 tableau.

    Args:
        parameterization: one of ``'u1'``, ``'u2'``, ``'u3'``, ``'uv'``.
        u0: initial value of the trainable parameter ``u``.
        v0: initial value of ``v``; only used with ``'uv'``.
        dtype, device: dtype/device of all tableau tensors.
        **kwargs: time-grid options forwarded to ``RKParametricSolver``.

    Raises:
        ValueError: if ``parameterization`` is not one of the supported names.
    """

    _PARAMETERIZATIONS = ('u1', 'u2', 'u3', 'uv')

    def __init__(self, parameterization='u2', u0=1/3., v0=2/3., dtype=torch.float64, device='cpu', **kwargs):
        super(RKOrder4Stage4, self).__init__(**kwargs)
        self.dtype = dtype
        self.device = device

        # Fail early, consistent with the rk2/rk3 solvers. An unknown name
        # would otherwise leave c2/c3/b* unset and crash later with an
        # AttributeError.
        if parameterization not in self._PARAMETERIZATIONS:
            raise ValueError('Unknown parameterization for RKOrder4Stage4 solver')
        self.parameterization = parameterization
        # ``u``/``v`` are trainable; ``u0``/``v0`` keep the initial values as plain tensors.
        self.u = nn.Parameter(data=torch.tensor((u0,), dtype=self.dtype, device=self.device))
        self.u0 = torch.tensor((u0,), dtype=self.dtype, device=self.device)

        if self.parameterization == 'uv':
            self.v = nn.Parameter(data=torch.tensor((v0,), dtype=self.dtype, device=self.device))
            self.v0 = torch.tensor((v0,), dtype=self.dtype, device=self.device)
        else:
            self.v = None
            self.v0 = None

        self.build_ButcherTableau()

    def _compute_c(self):
        self.c1 = torch.tensor((0.,), dtype=self.dtype, device=self.device)

        if self.parameterization == 'u1':
            self.c2 = torch.tensor((0.5,), dtype=self.dtype, device=self.device)
            self.c3 = torch.tensor((0.,), dtype=self.dtype, device=self.device)

        elif self.parameterization == 'u2':
            self.c2 = torch.tensor((0.5,), dtype=self.dtype, device=self.device)
            self.c3 = torch.tensor((0.5,), dtype=self.dtype, device=self.device)

        elif self.parameterization == 'u3':
            self.c2 = torch.tensor((1.,), dtype=self.dtype, device=self.device)
            self.c3 = torch.tensor((0.5,), dtype=self.dtype, device=self.device)

        elif self.parameterization == 'uv':
            self.c2 = self.u_.clone()  # .clone() is needed, because without it self.c2 will be nn.Parameter
            self.c3 = self.v_.clone()

        self.c4 = torch.tensor((1.,), dtype=self.dtype, device=self.device)

    def _compute_b(self):
        if self.parameterization == 'u1':
            self.b1 = torch.tensor((1/6.,), dtype=self.dtype, device=self.device) - self.u_
            self.b2 = torch.tensor((2/3.,), dtype=self.dtype, device=self.device)
            self.b3 = self.u_.clone()
            self.b4 = torch.tensor((1/6.,), dtype=self.dtype, device=self.device)

        elif self.parameterization == 'u2':
            self.b1 = torch.tensor((1/6.,), dtype=self.dtype, device=self.device)
            self.b2 = torch.tensor((2/3.,), dtype=self.dtype, device=self.device) - self.u_
            self.b3 = self.u_.clone()
            self.b4 = torch.tensor((1/6.,), dtype=self.dtype, device=self.device)

        elif self.parameterization == 'u3':
            self.b1 = torch.tensor((1/6.,), dtype=self.dtype, device=self.device)
            self.b2 = torch.tensor((1/6.,), dtype=self.dtype, device=self.device) - self.u_
            self.b3 = torch.tensor((2/3.,), dtype=self.dtype, device=self.device)
            self.b4 = self.u_.clone()

        elif self.parameterization == 'uv':
            # Divides by u, v, (1 - u), (1 - v) and (v - u); _make_params_valid
            # keeps u_, v_ inside (0, 1) and apart from each other.
            sub_u = 1. - self.u_
            sub_v = 1. - self.v_
            v_sub_u = self.v_ - self.u_

            self.b2 = (2. * self.v_ - 1.) / (12 * self.u_ * sub_u * v_sub_u)
            self.b3 = (1. - 2 * self.u_) / (12 * self.v_ * sub_v * v_sub_u)
            self.b4 = (6. * self.u_ * self.v_ + 3. - 4. * self.u_ - 4. * self.v_) / (12 * sub_u * sub_v)
            self.b1 = 1. - self.b2 - self.b3 - self.b4

    def _compute_w(self):
        self.w43 = self.b3 * (1 - self.c3) / self.b4

        # 2x2 linear system A @ [w32, w42] = B
        A00 = self.b3 * self.c3 * self.c2
        A01 = self.b4 * self.c4 * self.c2
        A10 = self.b3
        A11 = self.b4

        B0 = 0.125 - self.b4 * self.c4 * self.c3 * self.w43
        B1 = self.b2 * (1 - self.c2)

        # Solve for w32, w42 with Cramer's rule (keeps autograd on the 1-element tensors).
        detA = A00 * A11 - A01 * A10
        detA0 = B0 * A11 - B1 * A01
        detA1 = A00 * B1 - A10 * B0

        self.w32 = detA0 / detA
        self.w42 = detA1 / detA

        # Row-sum conditions: each row of w sums to the stage node c_i.
        self.w41 = self.c4 - (self.w42 + self.w43)
        self.w31 = self.c3 - self.w32
        self.w21 = self.c2

        self.w11, self.w22, self.w33, self.w44 = [torch.tensor((0.,), dtype=self.dtype, device=self.device) for _ in range(4)]

    def _make_u_valid(self, eps):
        if self.v is not None:
            # 'uv' family: keep u inside (0, 1), off 0.5, on the same side of 0.5 as the raw u.
            if self.u < 0.5:
                self.u_ = torch.clamp(self.u, eps, 0.5 - eps)
            else:
                self.u_ = torch.clamp(self.u, 0.5 + eps, 1. - eps)
        else:
            self.u_ = torch.clamp(self.u, eps, 1. - eps)

    def _make_v_valid(self, eps):
        self.v_ = torch.clamp(self.v, eps, 1. - eps)

    def _make_params_valid(self):
        """Clamp the parameters into their admissible ranges (as ``u_``/``v_``).

        The margin ``eps`` is float32's epsilon for float64 parameters and
        float16's for float32. Any other floating dtype uses its own epsilon,
        so ``eps`` is always bound.
        """
        if self.u.dtype == torch.float64:
            eps = torch.finfo(torch.float32).eps
        elif self.u.dtype == torch.float32:
            eps = torch.finfo(torch.float16).eps
        else:
            eps = torch.finfo(self.u.dtype).eps

        self._make_u_valid(eps)

        if self.v is not None:
            self._make_v_valid(eps)

            # The 'uv' weights divide by (v - u): nudge coinciding nodes apart.
            if self.u_ == self.v_:
                if self.u_ < 1. - eps:
                    self.v_ = self.u_ + eps
                else:
                    self.u_ = self.v_ - eps

    def _get_c(self):
        c = torch.tensor([self.c1, self.c2, self.c3, self.c4])
        return c

    def _get_w(self):
        w = [torch.tensor([self.w11])] + [
            torch.tensor([self.w21, self.w22])] + [
            torch.tensor([self.w31, self.w32, self.w33])] + [
            torch.tensor([self.w41, self.w42, self.w43, self.w44])]

        return w

    def _get_b(self):
        b = torch.tensor([self.b1, self.b2, self.b3, self.b4])
        return b

    def _get_t(self, t, dt):
        """Return the four stage times ``t + c_i * dt``."""
        t0 = t
        t1 = t + self.c2 * dt
        t2 = t + self.c3 * dt
        t3 = t + self.c4 * dt

        return (t0, t1, t2, t3)

    def _make_step(self, rhs_func, x, t, dt):
        """Return the increment ``dt * sum_i b_i * k_i`` for one step."""
        t0, t1, t2, t3 = self._get_t(t, dt)

        k1 = rhs_func(t0, x)
        k2 = rhs_func(t1, x + k1 * self.w21 * dt)
        k3 = rhs_func(t2, x + (k1 * self.w31 + k2 * self.w32) * dt)
        k4 = rhs_func(t3, x + (k1 * self.w41 + k2 * self.w42 + k3 * self.w43) * dt)

        return (k1 * self.b1 + k2 * self.b2 + k3 * self.b3 + k4 * self.b4) * dt

    def freeze_params(self):
        self.u.requires_grad = False
        if self.v is not None:
            self.v.requires_grad = False

        self.build_ButcherTableau()  # recompute params to set non leaf requires_grad to False

    def unfreeze_params(self):
        self.u.requires_grad = True
        if self.v is not None:
            self.v.requires_grad = True

        self.build_ButcherTableau()  # recompute params to set non leaf requires_grad to True

    @property
    def order(self):
        return 4
--------------------------------------------------------------------------------
/sopa/src/solvers/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import copy
4 |
5 | from torch.distributions.normal import Normal
6 | from torch.distributions.cauchy import Cauchy
7 |
8 | from .rk_parametric_order4stage4 import RKOrder4Stage4
9 | from .rk_parametric_order3stage3 import RKOrder3Stage3
10 | from .rk_parametric_order2stage2 import RKOrder2Stage2
11 | from .euler import Euler
12 |
def create_solver(method, parameterization, n_steps, step_size, u0, v0, dtype, device):
    '''Build the Runge-Kutta solver selected by ``method``.

    Args:
        method: one of ``'euler'``, ``'rk2'``, ``'rk3'``, ``'rk4'``.
        parameterization: tableau parameterization forwarded to the solver.
        n_steps: number of grid steps, or -1 for "not set".
        step_size: grid step size, or -1 for "not set".
        u0, v0: initial tableau parameters. They are cast to ``np.float64``
            or ``np.float32`` when ``dtype`` is the matching torch dtype.
        dtype: torch dtype of the solver's tensors.
        device: torch device of the solver's tensors.

    Returns:
        The constructed solver instance.

    Raises:
        ValueError: if ``method`` is not recognised. Previously ``None`` was
            returned silently, which only failed later at first use.
    '''
    # -1 is the "unset" sentinel for both grid options.
    if n_steps == -1:
        n_steps = None

    if step_size == -1:
        step_size = None

    if dtype == torch.float64:
        u0, v0 = np.float64(u0), np.float64(v0)
    elif dtype == torch.float32:
        u0, v0 = np.float32(u0), np.float32(v0)

    solver_kwargs = dict(n_steps=n_steps,
                         step_size=step_size,
                         parameterization=parameterization,
                         u0=u0, v0=v0,
                         dtype=dtype, device=device)

    if method == 'euler':
        return Euler(**solver_kwargs)
    if method == 'rk2':
        return RKOrder2Stage2(**solver_kwargs)
    if method == 'rk3':
        return RKOrder3Stage3(**solver_kwargs)
    if method == 'rk4':
        return RKOrder4Stage4(**solver_kwargs)

    raise ValueError('Unknown solver method: {!r}'.format(method))
58 |
59 |
def sample_noise(mu, sigma, noise_type='cauchy', size=1, device='cpu', minimize_rk2_error=False):
    '''Draw ``size`` independent samples from a Cauchy or Normal distribution.

    Args:
        mu: location of the distribution.
        sigma: scale of the distribution.
        noise_type: ``'cauchy'`` or ``'normal'``.
        size: number of samples to draw.
        device: device of the returned tensor.
        minimize_rk2_error: if True, ``mu`` is ignored; the location is fixed
            at 2/3 and the scale is ``2/3 * sigma``.

    Returns:
        Tensor built from the ``size`` samples, placed on ``device``.

    Raises:
        ValueError: for an unknown ``noise_type``. The distribution was
            previously left unbound, which raised ``UnboundLocalError``.
    '''
    distributions = {'cauchy': Cauchy, 'normal': Normal}
    if noise_type not in distributions:
        raise ValueError('Unknown noise_type: {!r}'.format(noise_type))

    if minimize_rk2_error:
        loc, scale = 2/3., 2/3. * sigma
    else:
        loc, scale = mu, sigma
    d = distributions[noise_type](torch.tensor([loc]), torch.tensor([scale]))

    # One sample() call per element, as before, so the RNG stream is unchanged.
    return torch.tensor([d.sample() for _ in range(size)], device=device)
73 |
74 |
def noise_params(mean_u, mean_v=None, std=0.01, bernoulli_p=1.0, noise_type='cauchy', minimize_rk2_error=False):
    '''Noise solver parameters with Cauchy/Normal noise with probability ``bernoulli_p``.

    A Bernoulli(``bernoulli_p``) coin decides whether to perturb. If it
    comes up 0, ``(mean_u, mean_v)`` are returned unchanged (``v`` is
    ``None`` when ``mean_v`` is ``None``). Otherwise ``u`` is drawn with
    ``sample_noise`` and replaced by ``mean_u`` if it falls outside
    ``mean_u +/- 2*std``. ``v`` (when ``mean_v`` is given) is drawn the same
    way but is not range-checked.

    Args:
        mean_u: tensor around which ``u`` is sampled; its device is reused.
        mean_v: optional centre for ``v``.
        std: scale of the noise (its absolute value is used).
        bernoulli_p: probability of perturbing at all.
        noise_type: ``'cauchy'`` or ``'normal'``, forwarded to ``sample_noise``.
        minimize_rk2_error: forwarded to ``sample_noise``.

    Returns:
        ``(u, v)``.
    '''
    coin = torch.distributions.Bernoulli(torch.tensor([bernoulli_p], dtype=torch.float32))
    v = None
    device = mean_u.device

    if coin.sample():
        # as_tensor avoids copy-constructing when std is already a tensor.
        std = torch.as_tensor(std, device=device).abs()

        u = sample_noise(mean_u, std, noise_type=noise_type, size=1, device=device, minimize_rk2_error=minimize_rk2_error)
        # Reject samples outside mean_u +/- 2*std and fall back to the mean.
        if u <= mean_u - 2 * std or u >= mean_u + 2 * std:
            u = mean_u

        if mean_v is not None:
            v = sample_noise(mean_v, std, noise_type=noise_type, size=1, device=device, minimize_rk2_error=minimize_rk2_error)
    else:
        u = mean_u
        if mean_v is not None:
            v = mean_v

    return u, v
99 |
def sample_solver_by_noising_params(solver, std=0.01, bernoulli_p=1., noise_type='cauchy', minimize_rk2_error=False):
    '''Return a perturbed deep copy of ``solver``; ``solver`` itself is untouched.

    The copy's ``u`` and ``v`` are replaced by ``noise_params`` draws based
    on the copy's ``u0``/``v0``. Its tableau is then rebuilt, and the new
    ``u``/``v`` are printed.
    '''
    noisy_solver = copy.deepcopy(solver)
    noise_kwargs = dict(std=std,
                        bernoulli_p=bernoulli_p,
                        noise_type=noise_type,
                        minimize_rk2_error=minimize_rk2_error)
    noisy_u, noisy_v = noise_params(mean_u=noisy_solver.u0,
                                    mean_v=noisy_solver.v0,
                                    **noise_kwargs)
    noisy_solver.u, noisy_solver.v = noisy_u, noisy_v
    noisy_solver.build_ButcherTableau()
    print(noisy_solver.u, noisy_solver.v)
    return noisy_solver
111 |
def create_solver_ensemble_by_noising_params(solver, ensemble_size=1, kwargs_noise=None):
    '''Return ``solver`` followed by ``ensemble_size - 1`` noisy copies of it.

    Args:
        solver: the reference solver; it is always the first element.
        ensemble_size: total number of solvers wanted. Values below 2 yield
            ``[solver]``.
        kwargs_noise: keyword arguments forwarded to
            ``sample_solver_by_noising_params``. Defaults to none; ``None``
            replaces the former mutable ``{}`` default.

    Returns:
        List of solvers.
    '''
    if kwargs_noise is None:
        kwargs_noise = {}

    noisy = [sample_solver_by_noising_params(solver, **kwargs_noise)
             for _ in range(1, ensemble_size)]
    return [solver] + noisy
118 |
119 |
120 |
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------