├── .gitignore └── torchdyn ├── utils.py ├── module4-model ├── m4c.multiple_shooting_layers.ipynb └── m4g_gde_node_classification.ipynb ├── experimental ├── gde_node_classification_pyg.py └── latent_sde.py ├── README.md ├── module3-tasks └── m3a_image_classification.ipynb └── module1-neuralde └── m1b_crossing_trajectories.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *lightning_logs* 2 | *ipynb_checkpoints* 3 | -------------------------------------------------------------------------------- /torchdyn/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def smape(yhat, y): 5 | return torch.abs(yhat - y) / (torch.abs(yhat) + torch.abs(y)) / 2 -------------------------------------------------------------------------------- /torchdyn/module4-model/m4c.multiple_shooting_layers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7d38155d", 6 | "metadata": {}, 7 | "source": [ 8 | "### Multiple Shooting Layers" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 18, 14 | "id": "16f7165f", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "The autoreload extension is already loaded. 
To reload it, use:\n", 22 | " %reload_ext autoreload\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "import torchdyn\n", 28 | "import torch\n", 29 | "import torch.nn as nn\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "from torchdyn.core import MultipleShootingLayer, MultipleShootingProblem\n", 32 | "from torchdyn.numerics import Lorenz\n", 33 | "\n", 34 | "import torchdiffeq\n", 35 | "import time \n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 27, 43 | "id": "c39abef8", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "x0 = torch.randn(8, 3) + 15\n", 48 | "t_span = torch.linspace(0, 3, 3000)\n", 49 | "sys = Lorenz()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 28, 55 | "id": "967fbf3e", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "for sensitivity in ['autograd', 'adjoint', 'interpolated_adjoint']:\n", 60 | " mshooting = MultipleShootingProblem(sys, solver='zero', sensitivity=sensitivity)\n", 61 | " t_eval, sol = mshooting(x0, t_span)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 29, 67 | "id": "dddc49b3", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "for sensitivity in ['autograd', 'adjoint', 'interpolated_adjoint']:\n", 72 | " mshooting = MultipleShootingLayer(sys, solver='zero', sensitivity=sensitivity)\n", 73 | " t_eval, sol = mshooting(x0, t_span)" 74 | ] 75 | } 76 | ], 77 | "metadata": { 78 | "kernelspec": { 79 | "display_name": "torchdyn", 80 | "language": "python", 81 | "name": "torchdyn" 82 | }, 83 | "language_info": { 84 | "codemirror_mode": { 85 | "name": "ipython", 86 | "version": 3 87 | }, 88 | "file_extension": ".py", 89 | "mimetype": "text/x-python", 90 | "name": "python", 91 | "nbconvert_exporter": "python", 92 | "pygments_lexer": "ipython3", 93 | "version": "3.8.8" 94 | } 95 | }, 96 | "nbformat": 4, 97 | "nbformat_minor": 5 98 | } 99 | 
-------------------------------------------------------------------------------- /torchdyn/experimental/gde_node_classification_pyg.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch_geometric.transforms as T 6 | from torch_geometric.datasets import Planetoid 7 | from torch_geometric.nn import SplineConv 8 | from torchdyn.models import NeuralDE 9 | 10 | dataset = 'Cora' 11 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'datasets', dataset) 12 | dataset = Planetoid(path, dataset, transform=T.TargetIndegree()) 13 | data = dataset[0] 14 | 15 | data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool) 16 | data.train_mask[:data.num_nodes - 1000] = 1 17 | data.val_mask = None 18 | data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool) 19 | data.test_mask[data.num_nodes - 500:] = 1 20 | 21 | 22 | class GCNLayer(torch.nn.Module): 23 | def __init__(self, input_size, output_size): 24 | super(GCNLayer, self).__init__() 25 | 26 | if input_size != output_size: 27 | raise AttributeError('input size must equal output size') 28 | 29 | self.conv1 = SplineConv(input_size, output_size, dim=1, kernel_size=2).to(device) 30 | self.conv2 = SplineConv(input_size, output_size, dim=1, kernel_size=2).to(device) 31 | 32 | def forward(self, x): 33 | edge_index, edge_attr = data.edge_index, data.edge_attr 34 | x = self.conv1(x, edge_index, edge_attr) 35 | x = self.conv2(x, edge_index, edge_attr) 36 | return x 37 | 38 | 39 | class Net(torch.nn.Module): 40 | def __init__(self): 41 | super(Net, self).__init__() 42 | 43 | self.func = GCNLayer(input_size=64, output_size=64) 44 | 45 | self.conv1 = SplineConv(dataset.num_features, 64, dim=1, kernel_size=2).to(device) 46 | self.neuralDE = NeuralDE(self.func, solver='rk4', s_span=torch.linspace(0, 1, 3)).to(device) 47 | self.conv2 = SplineConv(64, dataset.num_classes, dim=1, 
kernel_size=2).to(device) 48 | 49 | def forward(self, x): 50 | edge_index, edge_attr = data.edge_index, data.edge_attr 51 | x = F.tanh(self.conv1(x, edge_index, edge_attr)) 52 | x = F.dropout(x, training=self.training) 53 | x = self.neuralDE(x) 54 | x = F.tanh(self.conv2(x, edge_index, edge_attr)) 55 | 56 | return F.log_softmax(x, dim=1) 57 | 58 | 59 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 60 | model, data = Net().to(device), data.to(device) 61 | optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-3) 62 | 63 | 64 | def train(): 65 | model.train() 66 | optimizer.zero_grad() 67 | F.nll_loss(model(data.x)[data.train_mask], data.y[data.train_mask]).backward() 68 | optimizer.step() 69 | 70 | 71 | def test(): 72 | model.eval() 73 | logits, accs = model(data.x), [] 74 | for _, mask in data('train_mask', 'test_mask'): 75 | pred = logits[mask].max(1)[1] 76 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 77 | accs.append(acc) 78 | return accs 79 | 80 | 81 | for epoch in range(1, 201): 82 | train() 83 | log = 'Epoch: {:03d}, Train: {:.4f}, Test: {:.4f}' 84 | print(log.format(epoch, *test())) 85 | -------------------------------------------------------------------------------- /torchdyn/README.md: -------------------------------------------------------------------------------- 1 | ### Applications and tutorials 2 | The current version of `torchdyn` contains various quickstart examples / tutorials which explore different aspects of continuous / implicit learning and related numerical methods. Most tutorials are kept up-to-date in case of API changes of underlying libraries. We automatically validate via a quick dry run in `test/validate_tutorials.py`. These are indicated by ✅. Older tutorials ⬜️ might require minimal API changes to get working and are not automatically validated, though the goal is to eventually extend testing. 
Working on ensuring older tutorials are still runnable represents a perfect opportunity to get involved in the project, requiring minimal familiarity with the codebase. 3 | 4 | We organize the tutorials in modules. `00_quickstart.ipynb` offers a general overview of `torchdyn` features, including some comments on design philosophy and goals of the library. 5 | 6 | Each module is then focused on a specific aspect of continuous or implicit learning. For the moment, we offer the following modules and tutorials: 7 | 8 | ### Module 1: Neural Differential Equations 9 | We empirically verify several properties of Neural ODEs, and develop strategies to alleviate some of their weaknesses. Augmentation, depth-variance and more are discussed here. 10 | 11 | * ✅ `m1a_neural_ode_cookbook`: here, we explore the API and how to define Neural ODE variants within `torchdyn` 12 | * ✅ `m1b_crossing_trajectories`: a standard benchmark problem, highlighting expressivity limitations of Neural ODEs, and how they can be addressed 13 | * ✅ `m1c_augmentation_strategies`: augmentation API for Neural ODEs 14 | * ✅ `m1d_higher_order`: higher-order Neural ODE variants for classification 15 | 16 | 17 | ### Module 2: Numerics and Optimization 18 | This module is concerned with the numerics behind neural and non-neural differential equations. We provide examples of the `torchdyn` numerics API, including advanced methods such as multiple shooting algorithms and hypersolvers. 19 | 20 | * ✅ `m2a_hypersolver_odeint`: solve ODEs with hybridized neural networks + ODE solvers: the hypersolver API 21 | * ✅ `m2b_multiple_shooting`: get familiar with `torchdyn`'s API dedicated to multiple shooting ODE solvers. 22 | * ✅ `m2c_hybrid_odeint`: learn how to simulate hybrid (potentially multi-mode) dynamical systems via `odeint_hybrid`. 
23 | * ✅ `m2d_generalized_adjoint`: introduce integral losses in your Neural ODE training [[18](https://arxiv.org/abs/2003.08063)] to track a sinusoidal signal 24 | 25 | ### Module 3: Tasks and Benchmarks 26 | Here, we showcase how `torchdyn` models can be used in various machine learning and control tasks. The focus is on developing the problem setting rather than applying complex models. 27 | 28 | * ⬜️ `m3a_image_classification`: convolutional Neural ODEs for digit classification on MNIST 29 | * ✅ `m3b_optimal_control`: direct optimal control of dynamical systems via the Neural ODE API. 30 | * ⬜️ `m3c_pde_optimal_control`: fast optimal control of a Timoshenko beam via Multiple Shooting Layers and root tracking. 31 | * ⬜️ `m3d_continuous_normalizing_flows`: density estimation with continuous normalizing flows. 32 | 33 | ### Module 4: Models 34 | This module offers an overview of several specialized continuous or implicit models. 35 | 36 | * ⬜️ `m4a_approximate_normalizing_flows`: recover densities with FFJORD variants of continuous normalizing flows [[19](https://arxiv.org/abs/1810.01367)] 37 | 38 | * ✅ `m4b_hypersolver_optimal_control`: speed up direct optimal control of ODEs with hypersolvers. 39 | * ⬜️ `m4c_multiple_shooting_layers`: apply multiple shooting layers to time series classification, speeding up Neural CDEs (WIP). 
40 | * ⬜️ `m4d_hamiltonian_networks`: learn dynamics of energy preserving systems with a simple implementation of `Hamiltonian Neural Networks` in `torchdyn` [[10](https://arxiv.org/abs/1906.01563)] 41 | * ⬜️ `m4e_lagrangian_networks`: learn dynamics of energy preserving systems with a simple implementation of `Lagrangian Neural Networks` in `torchdyn` [[12](https://arxiv.org/abs/2003.04630)] 42 | * ✅ `m4f_stable_neural_odes`: learn dynamics with `Stable Neural Flows`, a generalization of HNNs [[18](https://arxiv.org/abs/2003.08063)] 43 | * ⬜️ `m4g_gde_node_classification`: first steps into the world of Neural GDEs [[9](https://arxiv.org/abs/1911.07532)], or ODEs on graphs parametrized by graph neural networks (GNN). Classification on Cora. 44 | 45 | 46 | 47 | 48 | #### Goals 49 | 50 | Our current goals are to extend model zoo with pretrained Neural *DE variants and equilibrium models. -------------------------------------------------------------------------------- /torchdyn/experimental/latent_sde.py: -------------------------------------------------------------------------------- 1 | """A partial Re-implementation of Xuechen Li's work (https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py)""" 2 | 3 | 4 | import os 5 | import math 6 | import numpy as np 7 | from collections import namedtuple 8 | from matplotlib import pyplot as plt 9 | 10 | import torch 11 | from torch import nn, optim 12 | import pytorch_lightning as pl 13 | from torch.utils.data import Dataset, DataLoader 14 | from torch.distributions import Laplace 15 | 16 | from torchdyn.models import LatentNeuralSDE, LinearScheduler, EMAMetric 17 | from torchsde import BrownianPath 18 | 19 | 20 | class IrregularSineDataset(Dataset): 21 | def __init__(self, batch_size, num_batches): 22 | ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = self.make_irregular_sine_data() 23 | self.array = ys.view(-1).unsqueeze(0).repeat(batch_size*num_batches, 1) 24 | self.ts = ts 25 | self.ts_ext = 
ts_ext 26 | self.ts_vis = ts_vis 27 | self.ys = ys 28 | 29 | def __len__(self): return len(self.array) 30 | def __getitem__(self, i): return self.array[i] 31 | def s_span(self): return self.ts 32 | def s_ext_span(self): return self.ts_ext 33 | def v_span(self): return self.ts_vis 34 | def x_sample(self): return self.ys 35 | 36 | @staticmethod 37 | def make_irregular_sine_data(): 38 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 39 | Data = namedtuple('Data', ['ts_', 'ts_ext_', 'ts_vis_', 'ts', 'ts_ext', 'ts_vis', 'ys', 'ys_']) 40 | with torch.no_grad(): 41 | ts_ = np.sort(np.random.uniform(low=0.4, high=1.6, size=16)) 42 | ts_ext_ = np.array([0.] + list(ts_) + [2.0]) 43 | ts_vis_ = np.linspace(0., 2.0, 300) 44 | ys_ = np.sin(ts_ * (2. * math.pi))[:, None] * 0.8 45 | 46 | ts = torch.tensor(ts_).float().to(device) 47 | ts_ext = torch.tensor(ts_ext_).float() 48 | ts_vis = torch.tensor(ts_vis_).float() 49 | ys = torch.tensor(ys_).float().to(device) 50 | 51 | return Data(ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_) 52 | 53 | 54 | class FFunc(pl.LightningModule): 55 | """Posterior drift.""" 56 | def __init__(self): 57 | super(FFunc, self).__init__() 58 | self.net = nn.Sequential( 59 | nn.Linear(3, 200), 60 | nn.Tanh(), 61 | nn.Linear(200, 200), 62 | nn.Tanh(), 63 | nn.Linear(200, 1) 64 | ) 65 | 66 | def forward(self, t, y): 67 | if t.dim() == 0: 68 | t = float(t) * torch.ones_like(y) 69 | # Positional encoding in transformers; must use `t`, since the posterior is likely inhomogeneous. 
70 | inp = torch.cat((torch.sin(t), torch.cos(t), y), dim=-1) 71 | return self.net(inp) 72 | 73 | 74 | class HFunc(pl.LightningModule): 75 | """Prior drift""" 76 | def __init__(self, theta=1.0, mu=0.0): 77 | super(HFunc, self).__init__() 78 | self.theta = nn.Parameter(torch.tensor([[theta]]), requires_grad=False) 79 | self.mu = nn.Parameter(torch.tensor([[mu]]), requires_grad=False) 80 | 81 | def forward(self, t, y): 82 | return self.theta * (self.mu - y) 83 | 84 | 85 | class GFunc(pl.LightningModule): 86 | """Diffusion""" 87 | def __init__(self, sigma=0.5): 88 | super(GFunc, self).__init__() 89 | self.sigma = nn.Parameter(torch.tensor([[sigma]]), requires_grad=False) 90 | 91 | def forward(self, t, y): 92 | return self.sigma.repeat(y.size(0), 1) 93 | 94 | 95 | class Model(pl.LightningModule): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | sigma, theta, mu = 0.5, 1.0, 0.0 100 | options = {'trapezoidal_approx': False} 101 | 102 | self.f_func = FFunc() 103 | self.h_func = HFunc(theta=theta, mu=mu) 104 | self.g_func = GFunc(sigma=sigma) 105 | 106 | self.s_span = IrregularSineDataset(1, 1).s_span() 107 | 108 | self.lsde = LatentNeuralSDE(post_drift=self.f_func, diffusion=self.g_func, prior_drift=self.h_func, 109 | sigma=sigma, theta=theta, mu=mu, 110 | noise_type='diagonal', order=1, sensitivity='autograd', s_span=self.s_span, 111 | solver='srk', atol=1e-3, rtol=1e-3, intloss=None, options=options) 112 | 113 | def forward(self, eps: torch.Tensor, s_span=None): 114 | """ 115 | :param: Noise sample 116 | """ 117 | zs, log_ratio = self.lsde(eps, s_span) 118 | zs = zs.squeeze() 119 | # 120 | 121 | return zs, log_ratio 122 | 123 | 124 | class Learner(pl.LightningModule): 125 | def __init__(self, train_path): 126 | super().__init__() 127 | self.model = Model() 128 | 129 | dataset = IrregularSineDataset(1, 1) 130 | self.vis_span = dataset.v_span() 131 | self.x_sample = dataset.x_sample() 132 | self.s_span = dataset.s_span() 133 | self.s_ext_span = 
dataset.s_ext_span() 134 | 135 | self.train_path = train_path 136 | self.logp_metric = EMAMetric() 137 | self.log_ratio_metric = EMAMetric() 138 | self.loss_metric = EMAMetric() 139 | self.kl_scheduler = LinearScheduler(iters=10) 140 | 141 | self.scale = 0.05 142 | 143 | def forward(self, x): 144 | return self.model(x) 145 | 146 | def training_step(self, batch, batch_idx): 147 | # x, y = torch.split(batch, split_size_or_sections=1, dim=0) 148 | x = batch 149 | eps = torch.randn(batch.shape[0], 1) 150 | 151 | zs, log_ratio = self.model(eps=eps, s_span=self.s_ext_span) 152 | zs = zs[1:-1] 153 | 154 | likelihood = Laplace(loc=zs, scale=self.scale) 155 | 156 | # Bad Hack just in this case where every tensor in batch is identical 157 | logp = likelihood.log_prob(x.mean(dim=0).unsqueeze(1).to(self.device)).sum(dim=0).mean(dim=0) 158 | loss = -logp + log_ratio * self.kl_scheduler() 159 | 160 | # loss.backward() 161 | # self.optimizer.step() 162 | # self.scheduler.step() 163 | self.logp_metric.step(logp) 164 | self.log_ratio_metric.step(log_ratio) 165 | self.loss_metric.step(loss) 166 | 167 | logs = {'train_loss': loss} 168 | return {'loss': loss, 'log': logs} 169 | 170 | def on_epoch_end(self, vis_n_sim=1024): 171 | 172 | img_path = os.path.join(train_dir, f'global_step_{self.current_epoch}.png') 173 | ylims = (-1.75, 1.75) 174 | alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55] 175 | percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] 176 | sample_colors = ('#8c96c6', '#8c6bb1', '#810f7c') 177 | fill_color = '#9ebcda' 178 | mean_color = '#4d004b' 179 | num_samples = len(sample_colors) 180 | vis_idx = np.random.permutation(vis_n_sim) 181 | 182 | eps = torch.randn(vis_n_sim, 1) 183 | bm = BrownianPath(t0=self.vis_span[0], w0=torch.zeros(vis_n_sim, 1)) 184 | 185 | # -- Not used -- From show_prior option in original implementation 186 | # zs = self.model.sample_p(vis_span=self.vis_span, n_sim=vis_n_sim, eps=eps, bm=bm).squeeze() 
187 | # ts_vis_, zs_ = self.vis_span.cpu().numpy(), zs.cpu().numpy() 188 | # zs_ = np.sort(zs_, axis=1) 189 | 190 | zs = self.model.lsde.sample_q(vis_span=self.vis_span, n_sim=vis_n_sim, eps=eps, bm=bm).squeeze() 191 | samples = zs[:, vis_idx] 192 | s_span_vis_ = self.vis_span.cpu().detach().numpy() 193 | zs_ = zs.cpu().detach().numpy() 194 | samples_ = samples.cpu().detach().numpy() 195 | 196 | zs_ = np.sort(zs_, axis=1) 197 | 198 | with torch.no_grad(): 199 | 200 | plt.subplot(frameon=False) 201 | 202 | for alpha, percentile in zip(alphas, percentiles): 203 | idx = int((1 - percentile) / 2. * vis_n_sim) 204 | zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -idx] 205 | plt.fill_between(s_span_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color) 206 | 207 | plt.plot(s_span_vis_, zs_.mean(axis=1), color=mean_color) 208 | 209 | for j in range(num_samples): 210 | plt.plot(s_span_vis_, samples_[:, j], color=sample_colors[j], linewidth=1.0) 211 | 212 | num, ds = 12, 0.12 213 | s, x = torch.meshgrid( 214 | [torch.linspace(0.2, 1.8, num), torch.linspace(-1.5, 1.5, num)] 215 | ) 216 | 217 | s, x = s.reshape(-1, 1).to(self.device), x.reshape(-1, 1).to(self.device) 218 | 219 | ftx = self.model.lsde.defunc.f(s=s, x=x) 220 | ftx = ftx.cpu().reshape(num, num) 221 | 222 | ds = torch.zeros(num, num).fill_(ds) 223 | dx = ftx * ds 224 | ds_, dx_, = ds.cpu().detach().numpy(), dx.cpu().detach().numpy() 225 | s_, x_ = s.cpu().detach().numpy(), x.cpu().detach().numpy() 226 | 227 | plt.quiver(s_, x_, ds_, dx_, alpha=0.3, edgecolors='k', width=0.0035, scale=50) 228 | 229 | # Data. 
230 | plt.scatter(self.s_span.cpu().numpy(), self.x_sample.cpu().numpy(), marker='x', zorder=3, color='k', s=35) 231 | 232 | plt.ylim(ylims) 233 | plt.xlabel('$t$') 234 | plt.ylabel('$Y_t$') 235 | plt.tight_layout() 236 | plt.savefig(img_path, dpi=400) 237 | plt.close() 238 | 239 | def configure_optimizers(self): 240 | 241 | optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01) 242 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999) 243 | 244 | return [optimizer], [scheduler] 245 | 246 | def train_dataloader(self): 247 | 248 | batch_size = 512 249 | n_batches = 50 250 | 251 | return DataLoader(IrregularSineDataset(batch_size=batch_size, num_batches=n_batches), batch_size=batch_size) 252 | 253 | 254 | train_dir = os.path.join('images', 'ts_ext') 255 | trainer = pl.Trainer(gpus=0, max_epochs=200) 256 | trainer.fit(Learner(train_dir)) 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | -------------------------------------------------------------------------------- /torchdyn/module4-model/m4g_gde_node_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 27, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import math\n", 10 | "import numpy as np\n", 11 | "import scipy.sparse as sp\n", 12 | "import time\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.nn.functional as F\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "from torch.nn.parameter import Parameter\n", 18 | "from torch.nn.modules.module import Module\n", 19 | "\n", 20 | "import dgl\n", 21 | "import dgl.function as fn\n", 22 | "\n", 23 | "import dgl.data\n", 24 | "import networkx as nx\n", 25 | "\n", 26 | "from torchdyn.core import NeuralODE\n", 27 | "from torchdyn.nn import DataControl, DepthCat, Augmenter\n", 28 | "from torchdyn.datasets import *\n", 29 | "from torchdyn.utils import *" 30 | ] 31 | }, 32 | { 33 | 
"cell_type": "code", 34 | "execution_count": 28, 35 | "metadata": { 36 | "tags": [ 37 | "parameters" 38 | ] 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "# quick run for automated notebook validation\n", 43 | "dry_run = False" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "# Neural Graph Differential Equations" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## Semi-supervised node classification " 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "This notebook introduces `Neural GDEs` as a general high-performance model for graph structured data. Notebook `07_graph_differential_equations` is designed from the ground up as an introduction to Neural GDEs and therefore contains ample comments to provide insights on some of our design choices. To be accessible to practicioners/researchers without prior experience on GNNs, we discuss some features of `dgl` as well, one of the PyTorch ecosystems for geometric deep learning." 
65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "## Data preparation" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 29, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 81 | "\n", 82 | "# seed for repeatability\n", 83 | "torch.backends.cudnn.deterministic = True\n", 84 | "torch.backends.cudnn.benchmark = False\n", 85 | "\n", 86 | "torch.manual_seed(0)\n", 87 | "np.random.seed(0)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 30, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | " NumNodes: 2708\n", 100 | " NumEdges: 10556\n", 101 | " NumFeats: 1433\n", 102 | " NumClasses: 7\n", 103 | " NumTrainingSamples: 140\n", 104 | " NumValidationSamples: 500\n", 105 | " NumTestSamples: 1000\n", 106 | "Done loading data from cached files.\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "# dgl offers convenient access to GNN benchmark datasets via `dgl.data`...\n", 112 | "# other standard datasets (e.g. Citeseer / Pubmed) are also accessible via the dgl.data\n", 113 | "# API. 
The rest of the notebook is compatible with Cora / Citeseer / Pubmed with minimal\n", 114 | "# modification required.\n", 115 | "data = dgl.data.CoraGraphDataset()[0]" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 32, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "(7, 140, 500, 1000)" 127 | ] 128 | }, 129 | "execution_count": 32, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "# Cora is a node-classification dataset with 2708 nodes\n", 136 | "X = data.ndata['feat'].to(device)\n", 137 | "Y = data.ndata['label'].to(device)\n", 138 | "\n", 139 | "# In transductive semi-supervised node classification tasks on graphs, the model has access to all\n", 140 | "# node features but only a masked subset of the labels\n", 141 | "train_mask = data.ndata['train_mask']\n", 142 | "val_mask = data.ndata['val_mask']\n", 143 | "test_mask = data.ndata['test_mask']\n", 144 | "\n", 145 | "num_feats = X.shape[1]\n", 146 | "n_classes = 7\n", 147 | "\n", 148 | "# 140 training samples, 500 validation, 1000 test\n", 149 | "n_classes, train_mask.sum().item(), val_mask.sum().item(),test_mask.sum().item()" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "data.remove_edges()" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 35, 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "ename": "TypeError", 168 | "evalue": "'bool' object is not callable", 169 | "output_type": "error", 170 | "traceback": [ 171 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 172 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 173 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# add self-edge for each 
node\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mremove_edges\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselfloop_edges\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_edges_from\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnodes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnodes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDGLGraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 174 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/networkx/classes/function.py\u001b[0m in \u001b[0;36mselfloop_edges\u001b[0;34m(G, data, keys, default)\u001b[0m\n\u001b[1;32m 1194\u001b[0m )\n\u001b[1;32m 1195\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1196\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_multigraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1197\u001b[0m 
\u001b[0;32mif\u001b[0m \u001b[0mkeys\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1198\u001b[0m return (\n", 175 | "\u001b[0;31mTypeError\u001b[0m: 'bool' object is not callable" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "# add self-edge for each node\n", 181 | "g = data\n", 182 | "g.remove_edges(nx.selfloop_edges(g))\n", 183 | "g.add_edges_from(zip(g.nodes(), g.nodes()))\n", 184 | "g = dgl.DGLGraph(g)\n", 185 | "edges = g.edges()\n", 186 | "n_edges = g.number_of_edges()\n", 187 | "\n", 188 | "n_edges" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 7, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "# compute diagonal of normalization matrix D according to standard formula\n", 198 | "degs = g.in_degrees().float()\n", 199 | "norm = torch.pow(degs, -0.5)\n", 200 | "norm[torch.isinf(norm)] = 0\n", 201 | "# add to dgl.Graph in order for the norm to be accessible at training time\n", 202 | "g.ndata['norm'] = norm.unsqueeze(1).to(device)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "## Neural GCDE " 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "As Neural ODEs, GDEs require specification of an ODE function (`ODEFunc`), representing the set of layers that will be called repeatedly by the ODE solver.\n", 217 | "\n", 218 | "Here, we use the convolutional variant of Neural GDEs based on GCNs: `Neural GCDEs`. 
The only difference with alternative neural GDEs resides in the type of GNN layer utilized in the ODEFunc.\n", 219 | "\n", 220 | "For adaptive step GDEs (dopri5) we increase the hidden dimension to 64 to reduce the stiffness of the ODE and therefore the number of ODEFunc evaluations (`NFE`: Number of Function Evaluations)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "First, we define the auxiliary GNN model as a standard GCN. Luckily, in this example the graph is static and can thus be assigned during initialization. For varying graphs, additional bookkeeping is required." 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 8, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "def accuracy(y_hat:torch.Tensor, y:torch.Tensor):\n", 237 | " preds = torch.max(y_hat, 1)[1]\n", 238 | " return torch.mean((y == preds).float())" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 9, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "class GCNLayer(nn.Module):\n", 248 | " def __init__(self, g:dgl.DGLGraph, in_feats:int, out_feats:int, activation,\n", 249 | " dropout:int, bias:bool=True):\n", 250 | " super().__init__()\n", 251 | " self.g = g\n", 252 | " self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))\n", 253 | " if bias:\n", 254 | " self.bias = nn.Parameter(torch.Tensor(out_feats))\n", 255 | " else:\n", 256 | " self.bias = None\n", 257 | " self.activation = activation\n", 258 | " if dropout:\n", 259 | " self.dropout = nn.Dropout(p=dropout)\n", 260 | " else:\n", 261 | " self.dropout = 0.\n", 262 | " self.reset_parameters()\n", 263 | "\n", 264 | " def reset_parameters(self):\n", 265 | " stdv = 1. 
/ math.sqrt(self.weight.size(1))\n", 266 | " self.weight.data.uniform_(-stdv, stdv)\n", 267 | " if self.bias is not None:\n", 268 | " self.bias.data.uniform_(-stdv, stdv)\n", 269 | "\n", 270 | " def forward(self, h):\n", 271 | " if self.dropout:\n", 272 | " h = self.dropout(h)\n", 273 | " h = torch.mm(h, self.weight)\n", 274 | " # normalization by square root of src degree\n", 275 | " h = h * self.g.ndata['norm']\n", 276 | " self.g.ndata['h'] = h\n", 277 | " self.g.update_all(fn.copy_src(src='h', out='m'),\n", 278 | " fn.sum(msg='m', out='h'))\n", 279 | " h = self.g.ndata.pop('h')\n", 280 | " # normalization by square root of dst degree\n", 281 | " h = h * self.g.ndata['norm']\n", 282 | " # bias\n", 283 | " if self.bias is not None:\n", 284 | " h = h + self.bias\n", 285 | " if self.activation:\n", 286 | " h = self.activation(h)\n", 287 | " return h" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "Then, we construct the Neural GDE as follows:" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 34, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "func = nn.Sequential(GCNLayer(g=g, in_feats=64, out_feats=64, activation=nn.Softplus(), dropout=0.9),\n", 304 | " GCNLayer(g=g, in_feats=64, out_feats=64, activation=None, dropout=0.9)\n", 305 | " ).to(device)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 35, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "neuralDE = NeuralODE(func, solver='rk4', s_span=torch.linspace(0, 1, 3)).to(device)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 36, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "m = nn.Sequential(GCNLayer(g=g, in_feats=num_feats, out_feats=64, activation=None, dropout=0.4),\n", 324 | " neuralDE,\n", 325 | " GCNLayer(g=g, in_feats=64, out_feats=n_classes, activation=None, dropout=0.)\n", 326 | " ).to(device)" 327 | ] 
328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "## Training loop" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 37, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "class PerformanceContainer(object):\n", 343 | " \"\"\" Simple data class for metrics logging.\"\"\"\n", 344 | " def __init__(self, data:dict):\n", 345 | " self.data = data\n", 346 | " \n", 347 | " @staticmethod\n", 348 | " def deep_update(x, y):\n", 349 | " for key in y.keys():\n", 350 | " x.update({key: list(x[key] + y[key])})\n", 351 | " return x" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 38, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "opt = torch.optim.Adam(m.parameters(), lr=1e-3, weight_decay=5e-4)\n", 361 | "criterion = torch.nn.CrossEntropyLoss()\n", 362 | "logger = PerformanceContainer(data={'train_loss':[], 'train_accuracy':[],\n", 363 | " 'test_loss':[], 'test_accuracy':[],\n", 364 | " 'forward_time':[], 'backward_time':[],\n", 365 | " })\n" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 39, 371 | "metadata": {}, 372 | "outputs": [ 373 | { 374 | "name": "stdout", 375 | "output_type": "stream", 376 | "text": [ 377 | "[150], Loss: 1.457, Train Accuracy: 0.514, Test Accuracy: 0.377\n", 378 | "[300], Loss: 0.730, Train Accuracy: 0.907, Test Accuracy: 0.731\n", 379 | "[450], Loss: 0.542, Train Accuracy: 0.921, Test Accuracy: 0.766\n", 380 | "[600], Loss: 0.416, Train Accuracy: 0.950, Test Accuracy: 0.816\n", 381 | "[750], Loss: 0.557, Train Accuracy: 0.943, Test Accuracy: 0.810\n", 382 | "[900], Loss: 0.353, Train Accuracy: 0.964, Test Accuracy: 0.819\n", 383 | "[1050], Loss: 0.265, Train Accuracy: 0.971, Test Accuracy: 0.807\n", 384 | "[1200], Loss: 0.340, Train Accuracy: 0.964, Test Accuracy: 0.828\n", 385 | "[1350], Loss: 0.201, Train Accuracy: 0.971, Test Accuracy: 0.828\n", 386 | "[1500], Loss: 0.368, Train 
Accuracy: 0.971, Test Accuracy: 0.824\n", 387 | "[1650], Loss: 0.255, Train Accuracy: 0.979, Test Accuracy: 0.812\n", 388 | "[1800], Loss: 0.241, Train Accuracy: 0.971, Test Accuracy: 0.820\n", 389 | "[1950], Loss: 0.304, Train Accuracy: 0.979, Test Accuracy: 0.821\n", 390 | "[2100], Loss: 0.248, Train Accuracy: 0.971, Test Accuracy: 0.828\n", 391 | "[2250], Loss: 0.223, Train Accuracy: 0.979, Test Accuracy: 0.815\n", 392 | "[2400], Loss: 0.180, Train Accuracy: 0.979, Test Accuracy: 0.834\n", 393 | "[2550], Loss: 0.321, Train Accuracy: 0.986, Test Accuracy: 0.825\n", 394 | "[2700], Loss: 0.166, Train Accuracy: 0.986, Test Accuracy: 0.808\n", 395 | "[2850], Loss: 0.171, Train Accuracy: 0.986, Test Accuracy: 0.821\n", 396 | "[3000], Loss: 0.190, Train Accuracy: 0.986, Test Accuracy: 0.827\n", 397 | "[3150], Loss: 0.207, Train Accuracy: 0.993, Test Accuracy: 0.823\n", 398 | "[3300], Loss: 0.159, Train Accuracy: 0.986, Test Accuracy: 0.817\n", 399 | "[3450], Loss: 0.183, Train Accuracy: 0.993, Test Accuracy: 0.829\n", 400 | "[3600], Loss: 0.161, Train Accuracy: 0.986, Test Accuracy: 0.831\n", 401 | "[3750], Loss: 0.143, Train Accuracy: 0.986, Test Accuracy: 0.826\n", 402 | "[3900], Loss: 0.182, Train Accuracy: 0.986, Test Accuracy: 0.826\n", 403 | "[4050], Loss: 0.156, Train Accuracy: 0.993, Test Accuracy: 0.817\n", 404 | "[4200], Loss: 0.177, Train Accuracy: 0.986, Test Accuracy: 0.819\n", 405 | "[4350], Loss: 0.160, Train Accuracy: 0.993, Test Accuracy: 0.811\n", 406 | "[4500], Loss: 0.182, Train Accuracy: 0.986, Test Accuracy: 0.829\n", 407 | "[4650], Loss: 0.130, Train Accuracy: 0.986, Test Accuracy: 0.812\n", 408 | "[4800], Loss: 0.143, Train Accuracy: 0.986, Test Accuracy: 0.830\n", 409 | "[4950], Loss: 0.207, Train Accuracy: 0.993, Test Accuracy: 0.818\n" 410 | ] 411 | } 412 | ], 413 | "source": [ 414 | "steps = 5000\n", 415 | "verbose_step = 150\n", 416 | "num_grad_steps = 0\n", 417 | "\n", 418 | "for i in range(steps): # looping over epochs\n", 419 | " 
m.train()\n", 420 | " outputs = m(X)\n", 421 | " y_pred = outputs\n", 422 | " loss = criterion(y_pred[train_mask], Y[train_mask])\n", 423 | " opt.zero_grad()\n", 424 | " \n", 425 | " start_time = time.time()\n", 426 | " loss.backward()\n", 427 | " \n", 428 | " opt.step()\n", 429 | " num_grad_steps += 1\n", 430 | "\n", 431 | " with torch.no_grad():\n", 432 | " m.eval()\n", 433 | "\n", 434 | " # calculating outputs again with zeroed dropout\n", 435 | " y_pred = m(X)\n", 436 | "\n", 437 | " train_loss = loss.item()\n", 438 | " train_acc = accuracy(y_pred[train_mask], Y[train_mask]).item()\n", 439 | " test_acc = accuracy(y_pred[test_mask], Y[test_mask]).item()\n", 440 | " test_loss = criterion(y_pred[test_mask], Y[test_mask]).item()\n", 441 | " logger.deep_update(logger.data, dict(train_loss=[train_loss], train_accuracy=[train_acc],\n", 442 | " test_loss=[test_loss], test_accuracy=[test_acc])\n", 443 | " )\n", 444 | "\n", 445 | " if num_grad_steps % verbose_step == 0:\n", 446 | " print('[{}], Loss: {:3.3f}, Train Accuracy: {:3.3f}, Test Accuracy: {:3.3f}'.format(num_grad_steps,\n", 447 | " train_loss,\n", 448 | " train_acc,\n", 449 | " test_acc,\n", 450 | " ))" 451 | ] 452 | } 453 | ], 454 | "metadata": { 455 | "kernelspec": { 456 | "display_name": "torchdyn", 457 | "language": "python", 458 | "name": "torchdyn" 459 | }, 460 | "language_info": { 461 | "codemirror_mode": { 462 | "name": "ipython", 463 | "version": 3 464 | }, 465 | "file_extension": ".py", 466 | "mimetype": "text/x-python", 467 | "name": "python", 468 | "nbconvert_exporter": "python", 469 | "pygments_lexer": "ipython3", 470 | "version": "3.8.5" 471 | }, 472 | "latex_envs": { 473 | "LaTeX_envs_menu_present": true, 474 | "autoclose": false, 475 | "autocomplete": true, 476 | "bibliofile": "biblio.bib", 477 | "cite_by": "apalike", 478 | "current_citInitial": 1, 479 | "eqLabelWithNumbers": true, 480 | "eqNumInitial": 1, 481 | "hotkeys": { 482 | "equation": "Ctrl-E", 483 | "itemize": "Ctrl-I" 484 | }, 485 | 
"labels_anchors": false, 486 | "latex_user_defs": false, 487 | "report_style_numbering": false, 488 | "user_envs_cfg": false 489 | }, 490 | "varInspector": { 491 | "cols": { 492 | "lenName": 16, 493 | "lenType": 16, 494 | "lenVar": 40 495 | }, 496 | "kernels_config": { 497 | "python": { 498 | "delete_cmd_postfix": "", 499 | "delete_cmd_prefix": "del ", 500 | "library": "var_list.py", 501 | "varRefreshCmd": "print(var_dic_list())" 502 | }, 503 | "r": { 504 | "delete_cmd_postfix": ") ", 505 | "delete_cmd_prefix": "rm(", 506 | "library": "var_list.r", 507 | "varRefreshCmd": "cat(var_dic_list()) " 508 | } 509 | }, 510 | "types_to_exclude": [ 511 | "module", 512 | "function", 513 | "builtin_function_or_method", 514 | "instance", 515 | "_Feature" 516 | ], 517 | "window_display": false 518 | } 519 | }, 520 | "nbformat": 4, 521 | "nbformat_minor": 4 522 | } 523 | -------------------------------------------------------------------------------- /torchdyn/module3-tasks/m3a_image_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Image Classification" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this notebook we explore standard image classification on MNIST and CIFAR10 with convolutional Neural ODE variants.\n", 15 | "* Depth-invariant neural ODE\n", 16 | "* Galerkin neural ODE (GalNODE)\n", 17 | "\n", 18 | "In the following notebooks we'll further develop intuition around `augmentation` strategies that can be easily applied to the models below with the flexible `torchdyn` API. Here, we use simple `0-augmentation`." 
19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from torchdyn.core import NeuralODE\n", 28 | "from torchdyn.nn import DataControl, DepthCat, Augmenter, GalConv2d, Fourier" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "import torch\n", 38 | "import torch.nn as nn\n", 39 | "from torch.utils.data import DataLoader\n", 40 | "from torchvision import datasets, transforms\n", 41 | "\n", 42 | "import pytorch_lightning as pl\n", 43 | "from pytorch_lightning.loggers import WandbLogger\n", 44 | "from pytorch_lightning.metrics.functional import accuracy" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": { 51 | "tags": [ 52 | "parameters" 53 | ] 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "# quick run for automated notebook validation\n", 58 | "dry_run = False" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 6, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", 80 | "Using downloaded and verified file: ../../data/mnist_data/MNIST/raw/train-images-idx3-ubyte.gz\n", 81 | "Extracting ../../data/mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/mnist_data/MNIST/raw\n", 82 | "\n", 83 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", 84 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz\n" 85 | ] 86 | }, 87 | { 88 | "data": { 89 | 
"application/vnd.jupyter.widget-view+json": { 90 | "model_id": "a01fe54d2fc941268ceab26caac7e2d1", 91 | "version_major": 2, 92 | "version_minor": 0 93 | }, 94 | "text/plain": [ 95 | " 0%| | 0/28881 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m )\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 333 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloader, val_dataloaders, datamodule)\u001b[0m\n\u001b[1;32m 456\u001b[0m )\n\u001b[1;32m 457\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 458\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 459\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 460\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 334 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, model)\u001b[0m\n\u001b[1;32m 754\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 755\u001b[0m \u001b[0;31m# dispatch `start_training` or `start_evaluating` or `start_predicting`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 756\u001b[0;31m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 757\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 758\u001b[0m \u001b[0;31m# plugin will finalized fitting (e.g. ddp_spawn will load trained model)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 335 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mdispatch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 795\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_predicting\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 796\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 797\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 798\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 799\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrun_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 336 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mstart_training\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'pl.Trainer'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_evaluating\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'pl.Trainer'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 337 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mstart_training\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'pl.Trainer'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;31m# double dispatch to initiate the training loop\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 144\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_results\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 145\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_evaluating\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'pl.Trainer'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 338 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredicting\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 806\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_predict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 807\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 808\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_pre_training_routine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 339 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 867\u001b[0m 
\u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_epoch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 868\u001b[0m \u001b[0;31m# run train epoch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 869\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_training_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 870\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 871\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_steps\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_steps\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglobal_step\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 340 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mrun_training_epoch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[0;31m# ------------------------------------\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 498\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_batch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 499\u001b[0;31m \u001b[0mbatch_output\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_training_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloader_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 500\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 501\u001b[0m \u001b[0;31m# when returning -1 from train_step, we end epoch early\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 341 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mrun_training_batch\u001b[0;34m(self, batch, batch_idx, dataloader_idx)\u001b[0m\n\u001b[1;32m 736\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0;31m# optimizer step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 738\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_step_and_backward_closure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 739\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 740\u001b[0m \u001b[0;31m# revert back to previous state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 342 | 
"\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 433\u001b[0m \u001b[0;31m# model hook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 434\u001b[0;31m model_ref.optimizer_step(\n\u001b[0m\u001b[1;32m 435\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurrent_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 436\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 343 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)\u001b[0m\n\u001b[1;32m 1401\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1402\u001b[0m \"\"\"\n\u001b[0;32m-> 1403\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptimizer_closure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1404\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1405\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 344 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure, *args, **kwargs)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0mprofiler_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf\"optimizer_step_and_closure_{self._optimizer_idx}\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 214\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprofiler_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprofiler_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 215\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_total_optimizer_step_calls\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 345 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py\u001b[0m in \u001b[0;36m__optimizer_step\u001b[0;34m(self, closure, profiler_name, **kwargs)\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;32mwith\u001b[0m 
\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprofiler_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mCallable\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 346 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, opt_idx, lambda_closure, **kwargs)\u001b[0m\n\u001b[1;32m 327\u001b[0m )\n\u001b[1;32m 328\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmake_optimizer_step\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 329\u001b[0;31m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 330\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 347 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mrun_optimizer_step\u001b[0;34m(self, optimizer, optimizer_idx, lambda_closure, **kwargs)\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 335\u001b[0m ) -> None:\n\u001b[0;32m--> 336\u001b[0;31m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlambda_closure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurrent_epoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 348 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, lambda_closure, **kwargs)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlambda_closure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 194\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 349 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/torch/optim/optimizer.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0mprofile_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"Optimizer.step#{}.step\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprofile_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m 
\u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 350 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 351 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/torch/optim/adamw.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mclosure\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m 
\u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mgroup\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_groups\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 352 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mtrain_step_and_backward_closure\u001b[0;34m()\u001b[0m\n\u001b[1;32m 730\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 731\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_step_and_backward_closure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 732\u001b[0;31m result = self.training_step_and_backward(\n\u001b[0m\u001b[1;32m 733\u001b[0m \u001b[0msplit_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhiddens\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 734\u001b[0m )\n", 353 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mtraining_step_and_backward\u001b[0;34m(self, split_batch, batch_idx, opt_idx, 
optimizer, hiddens)\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step_and_backward\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 822\u001b[0m \u001b[0;31m# lightning module hook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 823\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhiddens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 824\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_curr_step_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 825\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 354 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, split_batch, batch_idx, opt_idx, hiddens)\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0mmodel_ref\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[0;32mwith\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 290\u001b[0;31m \u001b[0mtraining_step_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 291\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 355 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 204\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 205\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 356 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 155\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 156\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 357 | 
"\u001b[0;32m\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0my_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCrossEntropyLoss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_hat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mepoch_progress\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miters\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader_len\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 358 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or 
_global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 359 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 360 | 
"\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 361 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 441\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 443\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_conv_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 445\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 362 | "\u001b[0;32m~/.cache/pypoetry/virtualenvs/torchdyn-voYSR01p-py3.8/lib/python3.8/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36m_conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 437\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstride\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 438\u001b[0m _pair(0), self.dilation, self.groups)\n\u001b[0;32m--> 439\u001b[0;31m return F.conv2d(input, weight, bias, self.stride,\n\u001b[0m\u001b[1;32m 440\u001b[0m self.padding, self.dilation, self.groups)\n\u001b[1;32m 441\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 363 | "\u001b[0;31mTypeError\u001b[0m: conv2d() received an invalid combination of arguments - got (tuple, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:\n * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)\n didn't match because some of the arguments have invalid types: (!tuple!, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)\n * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)\n didn't match 
because some of the arguments have invalid types: (!tuple!, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)\n" 364 | ] 365 | } 366 | ], 367 | "source": [ 368 | "learn = Learner(model)\n", 369 | "trainer = pl.Trainer(max_epochs=3,\n", 370 | " progress_bar_refresh_rate=1,\n", 371 | " gpus=1\n", 372 | " )\n", 373 | "\n", 374 | "trainer.fit(learn)" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": {}, 380 | "source": [ 381 | "3 epochs are not enough. Feel free to keep training and using all kinds of scheduling and optimization tricks :)" 382 | ] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "metadata": {}, 387 | "source": [ 388 | "## Galerkin Data-Controlled Conv Neural ODE (IL-Augmentation)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 12, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "func = nn.Sequential(DataControl(),\n", 398 | " DepthCat(1),\n", 399 | " GalConv2d(10+10, 12, 3, padding=1, expfunc=Fourier(5)),\n", 400 | " nn.Softplus(),\n", 401 | " DataControl(),\n", 402 | " DepthCat(1),\n", 403 | " GalConv2d(22, 10, 3, padding=1, expfunc=Fourier(5)),\n", 404 | " nn.Tanh()\n", 405 | " )\n", 406 | "\n", 407 | "neuralDE = NeuralODE(func, \n", 408 | " solver='dopri5',\n", 409 | " sensitivity='adjoint',\n", 410 | " s_span=torch.linspace(0, 1, 2)).to(device)\n", 411 | "\n", 412 | "model = nn.Sequential(Augmenter(augment_idx=1, augment_func=nn.Conv2d(1, 9, 3, padding=1)),\n", 413 | " neuralDE,\n", 414 | " nn.Conv2d(10, 1, 3, padding=1),\n", 415 | " nn.Flatten(), \n", 416 | " nn.Linear(28*28, 10)).to(device)\n" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 13, 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "name": "stderr", 426 | "output_type": "stream", 427 | "text": [ 428 | "GPU available: True, used: True\n", 429 | "TPU available: False, using: 0 TPU cores\n", 430 | "CUDA_VISIBLE_DEVICES: [0]\n", 431 | "\n", 432 | " | Name | Type | 
Params\n", 433 | "-------------------------------------\n", 434 | "0 | model | Sequential | 49 K \n" 435 | ] 436 | }, 437 | { 438 | "data": { 439 | "application/vnd.jupyter.widget-view+json": { 440 | "model_id": "0e8f6943122c4163b035e33a668edd1a", 441 | "version_major": 2, 442 | "version_minor": 0 443 | }, 444 | "text/plain": [ 445 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" 446 | ] 447 | }, 448 | "metadata": {}, 449 | "output_type": "display_data" 450 | }, 451 | { 452 | "name": "stdout", 453 | "output_type": "stream", 454 | "text": [ 455 | "\n" 456 | ] 457 | }, 458 | { 459 | "data": { 460 | "text/plain": [ 461 | "1" 462 | ] 463 | }, 464 | "execution_count": 13, 465 | "metadata": {}, 466 | "output_type": "execute_result" 467 | } 468 | ], 469 | "source": [ 470 | "learn = Learner(model)\n", 471 | "trainer = pl.Trainer(max_epochs=3,\n", 472 | " progress_bar_refresh_rate=1,\n", 473 | " )\n", 474 | "\n", 475 | "trainer.fit(learn)" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "metadata": {}, 481 | "source": [ 482 | "3 epochs are not enough. 
Feel free to keep training and using all kinds of scheduling and optimization tricks :)" 483 | ] 484 | } 485 | ], 486 | "metadata": { 487 | "kernelspec": { 488 | "display_name": "torchdyn", 489 | "language": "python", 490 | "name": "torchdyn" 491 | }, 492 | "language_info": { 493 | "codemirror_mode": { 494 | "name": "ipython", 495 | "version": 3 496 | }, 497 | "file_extension": ".py", 498 | "mimetype": "text/x-python", 499 | "name": "python", 500 | "nbconvert_exporter": "python", 501 | "pygments_lexer": "ipython3", 502 | "version": "3.8.5" 503 | }, 504 | "latex_envs": { 505 | "LaTeX_envs_menu_present": true, 506 | "autoclose": false, 507 | "autocomplete": true, 508 | "bibliofile": "biblio.bib", 509 | "cite_by": "apalike", 510 | "current_citInitial": 1, 511 | "eqLabelWithNumbers": true, 512 | "eqNumInitial": 1, 513 | "hotkeys": { 514 | "equation": "Ctrl-E", 515 | "itemize": "Ctrl-I" 516 | }, 517 | "labels_anchors": false, 518 | "latex_user_defs": false, 519 | "report_style_numbering": false, 520 | "user_envs_cfg": false 521 | }, 522 | "varInspector": { 523 | "cols": { 524 | "lenName": 16, 525 | "lenType": 16, 526 | "lenVar": 40 527 | }, 528 | "kernels_config": { 529 | "python": { 530 | "delete_cmd_postfix": "", 531 | "delete_cmd_prefix": "del ", 532 | "library": "var_list.py", 533 | "varRefreshCmd": "print(var_dic_list())" 534 | }, 535 | "r": { 536 | "delete_cmd_postfix": ") ", 537 | "delete_cmd_prefix": "rm(", 538 | "library": "var_list.r", 539 | "varRefreshCmd": "cat(var_dic_list()) " 540 | } 541 | }, 542 | "types_to_exclude": [ 543 | "module", 544 | "function", 545 | "builtin_function_or_method", 546 | "instance", 547 | "_Feature" 548 | ], 549 | "window_display": false 550 | } 551 | }, 552 | "nbformat": 4, 553 | "nbformat_minor": 4 554 | } 555 | -------------------------------------------------------------------------------- /torchdyn/module1-neuralde/m1b_crossing_trajectories.ipynb: -------------------------------------------------------------------------------- 
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Approximating a Reflection Map" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this tutorial we show how to painlessly train a neural ODE for approximating the **reflection map** \n", 15 | "\n", 16 | "$$\n", 17 | " y = -x\n", 18 | "$$\n", 19 | "\n", 20 | "This tutorial also serves as a warning about the limitations of *vanilla* ODE models, which should always be considered when designing your task-specific architecture.\\\n", 21 | "In fact, vanilla Neural ODEs cannot approximate (in 1D) functions which require the flows to cross, e.g. the reflection map $y=-x$, as crossing trajectories would break the uniqueness of solutions (and thus the determinism). As we show later, one way to overcome this issue is to employ **data-controlled** models." 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from torchdyn.core import NeuralODE\n", 31 | "from torchdyn.nn import DataControl, DepthCat, Augmenter\n", 32 | "from torchdyn.utils import *" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "**Data**" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "The dataset contains `n_points` pairs `(x, -x)`, with `x` evenly spaced in `[-1, 1]` (from `(-1, 1)` to `(1, -1)`)." 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import torch\n", 56 | "import torch.utils.data as data\n", 57 | "\n", 58 | "n_points = 100\n", 59 | "X = torch.linspace(-1,1, n_points).reshape(-1,1)\n", 60 | "y = -X\n", 61 | "\n", 62 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 63 | "X_train, y_train = torch.Tensor(X).to(device), torch.Tensor(y).to(device)\n", 64 | "\n", 65 | "bs = len(X)\n", 66 | "train = data.TensorDataset(X_train, y_train)\n", 67 |
"trainloader = data.DataLoader(train, batch_size=bs, shuffle=False)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "**Learner**" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 37, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "import torch.nn as nn\n", 84 | "import pytorch_lightning as pl\n", 85 | "\n", 86 | "\n", 87 | "class Learner(pl.LightningModule):\n", 88 | " def __init__(self, t_span, model:nn.Module, settings:dict={}):\n", 89 | " super().__init__()\n", 90 | " self.model, self.t_span = model, t_span\n", 91 | " \n", 92 | " def forward(self, x):\n", 93 | " return self.model(x)\n", 94 | " \n", 95 | " def training_step(self, batch, batch_idx):\n", 96 | " x, y = batch \n", 97 | " t_eval, yhat = self.model(x, self.t_span)\n", 98 | " yhat = yhat[-1] # select last point of solution trajectory\n", 99 | " loss = nn.MSELoss()(yhat, y)\n", 100 | " return {'loss': loss} \n", 101 | " \n", 102 | " def configure_optimizers(self):\n", 103 | " return torch.optim.Adam(self.model.parameters(), lr=0.01)\n", 104 | "\n", 105 | " def train_dataloader(self):\n", 106 | " return trainloader" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "## Uncontrolled Neural ODE models" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "We first consider the following neural ODE variants: `depth-invariant` and `depth-variant` (\"cat\"). As we expect, these models will **NOT** be able to approximate the reflection map." 
121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 38, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "# vanilla depth-invariant\n", 130 | "func = nn.Sequential(\n", 131 | " nn.Linear(1, 64),\n", 132 | " nn.Tanh(),\n", 133 | " nn.Linear(64,1)\n", 134 | " ).to(device)\n", 135 | "\n", 136 | "\n", 137 | "# vanilla depth-variant\n", 138 | "func_dv = nn.Sequential(DepthCat(1),\n", 139 | " nn.Linear(2, 64),\n", 140 | " nn.Tanh(),\n", 141 | " nn.Linear(64,1)\n", 142 | " ).to(device)\n", 143 | "\n", 144 | "funcs = [func, func_dv]\n", 145 | "\n", 146 | "t_span = torch.linspace(0,1,100)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 39, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "name": "stderr", 156 | "output_type": "stream", 157 | "text": [ 158 | "GPU available: True, used: True\n", 159 | "TPU available: False, using: 0 TPU cores\n", 160 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", 161 | "\n", 162 | " | Name | Type | Params\n", 163 | "------------------------------------\n", 164 | "0 | model | NeuralODE | 193 \n", 165 | "------------------------------------\n", 166 | "193 Trainable params\n", 167 | "0 Non-trainable params\n", 168 | "193 Total params\n", 169 | "0.001 Total estimated model params size (MB)\n" 170 | ] 171 | }, 172 | { 173 | "name": "stdout", 174 | "output_type": "stream", 175 | "text": [ 176 | "Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.\n" 177 | ] 178 | }, 179 | { 180 | "data": { 181 | "application/vnd.jupyter.widget-view+json": { 182 | "model_id": "611741801b6e4d27a7c93c0dc57a5d9d", 183 | "version_major": 2, 184 | "version_minor": 0 185 | }, 186 | "text/plain": [ 187 | "Training: 0it [00:00, ?it/s]" 188 | ] 189 | }, 190 | "metadata": {}, 191 | "output_type": "display_data" 192 | }, 193 | { 194 | "ename": "AttributeError", 195 | "evalue": "'tuple' object has no attribute 'detach'", 196 | 
"output_type": "error", 197 | "traceback": [ 198 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 199 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", 200 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;31m# plot the learned flows\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msubplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mtraj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrajectory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_span\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0mplot_traj_vf_1D\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_span\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_grid\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m30\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_span\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 201 | "\u001b[0;31mAttributeError\u001b[0m: 'tuple' object has no attribute 'detach'" 202 | ] 203 | }, 204 | { 205 | "data": { 206 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAD8CAYAAADUv3dIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAANIklEQVR4nO3cYajd9X3H8fdHs6zM2TqWWyhJWi2Ls8ENdBfnKKwO3Yg+SB50lASk6wiGdrMMWgYOhyv2UVfWQSFbmzFxLVRr+6BcaEqgnSJI47yitSZiuU1dTSrz1jqfSNWw7x6c4zi9Jt5/47nfu3vyfsGF8/+f3z3n+8+5952Tc84/qSokST0uWO8BJOl8YnQlqZHRlaRGRleSGhldSWpkdCWp0arRTXJXkueTPHmW65Pk80mWkjyR5OrpjylJs2HIM927gV1vcv2NwI7x1wHgn9/6WJI0m1aNblU9CPzsTZbsAb5UI0eBS5K8a1oDStIs2TSF29gKPDuxfXK877mVC5McYPRsmIsuuuj3rrjiiincvST1evTRR39aVXPn8r3TiO5gVXUIOAQwPz9fi4uLnXcvSVOR5D/P9Xun8emFU8D2ie1t432SpBWmEd0F4MPjTzFcC7xUVW94aUGSNODlhST3ANcBW5KcBP4O+BWAqvoCcBi4CVgCXgb+fK2GlaSNbtXoVtW+Va4v4C+nNpEkzTDPSJOkRkZXkhoZXUlqZHQlqZHRlaRGRleSGhldSWpkdCWpkdGVpEZGV5IaGV1JamR0JamR0ZWkRkZXkhoZXUlqZHQlqZHRlaRGRleSGhldSWpkdCWpkdGVpEZGV5IaGV1JamR0JamR0ZWkRkZXkhoZXUlqZHQlqZHRlaRGRleSGhldSWpkdCWpkdGVpEZGV5IaGV1JajQoukl2JXk6yVKS285w/buT3J/ksSRPJLlp+qNK0sa3anSTXAgcBG4EdgL7kuxcsexvgfuq6ipgL/BP0x5UkmbBkGe61wBLVXWiql4F7gX2rFhTwNvHl98B/GR6I0rS7BgS3a3AsxPbJ8f7Jn0KuDnJSeAw8PEz3VCSA0kWkywuLy+fw7iStLFN6420fcDdVbUNuAn4cpI33HZVHaqq+aqan5ubm9JdS9LGMSS6p4DtE9vbxvsm7QfuA6iq7wJvA7ZMY0BJmiVDovsIsCPJZUk2M3qjbGHFmh8D1wMkeR+j6Pr6gSStsGp0q+o0cCtwBHiK0acUjiW5M8nu8bJPArck+R5wD/CRqqq1GlqSNqpNQxZV1WFGb5BN7rtj4vJx4P3THU2SZo9npElSI6MrSY2MriQ1MrqS1MjoSlIjoytJjYyuJDUyupLUyOhKUiOjK0mNjK4kNTK6ktTI6EpSI6MrSY2MriQ1MrqS1MjoSlIjoytJjYyuJDUyupLUyOhKUiOjK0mNjK4kNTK6ktTI6EpSI6MrSY2MriQ1MrqS1MjoSlIjoytJjYyuJDUyupLUyOhKUiOjK0mNBkU3ya4kTydZSnLbWdZ8KMnxJMeSfGW6Y0rSbNi02oIkFwIHgT8GTgKPJFmoquMTa3YAfwO8v6peTPLOtRpYkjayIc90rwGWqupEVb0K3AvsWbHmFuBgVb0IUFXPT3dMSZoNQ6K7FXh2YvvkeN+ky4HLkzyU5GiSXWe6oSQHkiwmWVxeXj63iSVpA5vWG2m
bgB3AdcA+4F+SXLJyUVUdqqr5qpqfm5ub0l1L0sYxJLqngO0T29vG+yadBBaq6rWq+hHwA0YRliRNGBLdR4AdSS5LshnYCyysWPMNRs9ySbKF0csNJ6Y3piTNhlWjW1WngVuBI8BTwH1VdSzJnUl2j5cdAV5Ichy4H/jrqnphrYaWpI0qVbUudzw/P1+Li4vrct+S9FYkebSq5s/lez0jTZIaGV1JamR0JamR0ZWkRkZXkhoZXUlqZHQlqZHRlaRGRleSGhldSWpkdCWpkdGVpEZGV5IaGV1JamR0JamR0ZWkRkZXkhoZXUlqZHQlqZHRlaRGRleSGhldSWpkdCWpkdGVpEZGV5IaGV1JamR0JamR0ZWkRkZXkhoZXUlqZHQlqZHRlaRGRleSGhldSWpkdCWp0aDoJtmV5OkkS0lue5N1H0xSSeanN6IkzY5Vo5vkQuAgcCOwE9iXZOcZ1l0M/BXw8LSHlKRZMeSZ7jXAUlWdqKpXgXuBPWdY92ngM8DPpzifJM2UIdHdCjw7sX1yvO//JLka2F5V33yzG0pyIMliksXl5eVfelhJ2uje8htpSS4APgd8crW1VXWoquaran5ubu6t3rUkbThDonsK2D6xvW2873UXA1cCDyR5BrgWWPDNNEl6oyHRfQTYkeSyJJuBvcDC61dW1UtVtaWqLq2qS4GjwO6qWlyTiSVpA1s1ulV1GrgVOAI8BdxXVceS3Jlk91oPKEmzZNOQRVV1GDi8Yt8dZ1l73VsfS5Jmk2ekSVIjoytJjYyuJDUyupLUyOhKUiOjK0mNjK4kNTK6ktTI6EpSI6MrSY2MriQ1MrqS1MjoSlIjoytJjYyuJDUyupLUyOhKUiOjK0mNjK4kNTK6ktTI6EpSI6MrSY2MriQ1MrqS1MjoSlIjoytJjYyuJDUyupLUyOhKUiOjK0mNjK4kNTK6ktTI6EpSI6MrSY0GRTfJriRPJ1lKctsZrv9EkuNJnkjynSTvmf6okrTxrRrdJBcCB4EbgZ3AviQ7Vyx7DJivqt8Fvg78/bQHlaRZMOSZ7jXAUlWdqKpXgXuBPZMLqur+qnp5vHkU2DbdMSVpNgyJ7lbg2Yntk+N9Z7Mf+NaZrkhyIMliksXl5eXhU0rSjJjqG2lJbgbmgc+e6fqqOlRV81U1Pzc3N827lqQNYdOANaeA7RPb28b7fkGSG4DbgQ9U1SvTGU+SZsuQZ7qPADuSXJZkM7AXWJhckOQq4IvA7qp6fvpjStJsWDW6VXUauBU4AjwF3FdVx5LcmWT3eNlngV8Hvpbk8SQLZ7k5STqvDXl5gao6DBxese+Oics3THkuSZpJnpEmSY2MriQ1MrqS1MjoSlIjoytJjYyuJDUyupLUyOhKUiOjK0mNjK4kNTK6ktTI6EpSI6MrSY2MriQ1MrqS1MjoSlIjoytJjYyuJDUyupLUyOhKUiOjK0mNjK4kNTK6ktTI6EpSI6MrSY2MriQ1MrqS1MjoSlIjoytJjYyuJDUyupLUyOhKUiOjK0mNjK4kNTK6ktRoUHST7ErydJKlJLed4fpfTfLV8fUPJ7l06pNK0gxYNbpJLgQOAjcCO4F9SXauWLYfeLGqfgv4R+Az0x5UkmbBkGe61wBLVXWiql4F7gX2rFizB/i38eWvA9cnyfTGlKTZsGnAmq3AsxPbJ4HfP9uaqjqd5CXgN4GfTi5KcgA4MN58JcmT5zL0BraFFX8m5wGP+fxwvh3zb5/rNw6J7tRU1SHgEECSxaqa77z/9eYxnx885tmXZPFcv3fIywungO0T29vG+864Jskm4B3AC+c6lCTNqiHRfQTYkeSyJJuBvcDCijULwJ+NL/8p8O9VVdMbU5Jmw6ovL4xfo70VOAJcCNxVVceS3AksVtUC8K/Al5MsAT9jFObVHHoLc29UHvP5wWOefed8vPEJqST18Yw0SWpkdCWp0ZpH93w8hXjAMX8iyfEkTyT5TpL3rMec07TaMU+s+2CSSrKhP1405Hi
TfGj8OB9L8pXuGadtwM/1u5Pcn+Sx8c/2Tesx5zQluSvJ82c7pyAjnx//mTyR5OpVb7Sq1uyL0RtvPwTeC2wGvgfsXLHmL4AvjC/vBb66ljOt9dfAY/4j4NfGlz92PhzzeN3FwIPAUWB+vede48d4B/AY8Bvj7Xeu99wNx3wI+Nj48k7gmfWeewrH/YfA1cCTZ7n+JuBbQIBrgYdXu821fqZ7Pp5CvOoxV9X9VfXyePMoo88+b2RDHmeATzP6fzl+3jncGhhyvLcAB6vqRYCqer55xmkbcswFvH18+R3ATxrnWxNV9SCjT2SdzR7gSzVyFLgkybve7DbXOrpnOoV469nWVNVp4PVTiDeqIcc8aT+jvyk3slWPefzPru1V9c3OwdbIkMf4cuDyJA8lOZpkV9t0a2PIMX8KuDnJSeAw8PGe0dbVL/v73nsasH5RkpuBeeAD6z3LWkpyAfA54CPrPEqnTYxeYriO0b9kHkzyO1X13+s51BrbB9xdVf+Q5A8YfXb/yqr6n/Ue7P+TtX6mez6eQjzkmElyA3A7sLuqXmmaba2sdswXA1cCDyR5htFrXwsb+M20IY/xSWChql6rqh8BP2AU4Y1qyDHvB+4DqKrvAm9j9B/hzLJBv++T1jq65+MpxKsec5KrgC8yCu5Gf60PVjnmqnqpqrZU1aVVdSmj17F3V9U5/6ch62zIz/U3GD3LJckWRi83nGiccdqGHPOPgesBkryPUXSXW6fstwB8ePwphmuBl6rquTf9joZ3/25i9Lf8D4Hbx/vuZPRLB6MH5mvAEvAfwHvX+x3LhmP+NvBfwOPjr4X1nnmtj3nF2gfYwJ9eGPgYh9FLKseB7wN713vmhmPeCTzE6JMNjwN/st4zT+GY7wGeA15j9K+X/cBHgY9OPM4Hx38m3x/yc+1pwJLUyDPSJKmR0ZWkRkZXkhoZXUlqZHQlqZHRlaRGRleSGv0vMzgPTisUvzQAAAAASUVORK5CYII=\n", 207 | "text/plain": [ 208 | "
" 209 | ] 210 | }, 211 | "metadata": { 212 | "needs_background": "light" 213 | }, 214 | "output_type": "display_data" 215 | } 216 | ], 217 | "source": [ 218 | "import matplotlib.pyplot as plt\n", 219 | "\n", 220 | "plt.figure(figsize=(12,4))\n", 221 | "plot_settings = {'n_grid':30, 'x_span':[-1,1], 'device':device}\n", 222 | "\n", 223 | "for i, f in enumerate(funcs):\n", 224 | " # define the model\n", 225 | " model = NeuralODE(f, solver='tsit5', sensitivity='interpolated_adjoint', atol=1e-3, rtol=1e-3).to(device)\n", 226 | " # train the neural ODE\n", 227 | " learn = Learner(t_span, model)\n", 228 | " trainer = pl.Trainer(min_epochs=100, max_epochs=200, gpus=1)\n", 229 | " trainer.fit(learn)\n", 230 | " \n", 231 | " # plot the learned flows\n", 232 | " plt.subplot(1,2,1+i)\n", 233 | " traj = model.cpu().trajectory(X_train.cpu(), t_span).detach()\n", 234 | " plot_traj_vf_1D(model, t_span, traj, n_grid=30, x_span=[-1,1], device=device);" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "## Controlled Neural ODE models\n", 242 | "\n", 243 | "Following the work in [Massaroli S., Poli M., et al., 2020](https://arxiv.org/abs/2002.08071), we can easily approximate the reflection map by leveraging **data-controlled Neural ODEs**. Data-control allows the Neural ODE to learn a family of vector fields instead of a single one, by conditioning the vector field `f` on the initial condition `x`." 
244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 6, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "# define the data-controlled model\n", 253 | "f = nn.Sequential(DataControl(),\n", 254 | " nn.Linear(2, 64),\n", 255 | " nn.Tanh(),\n", 256 | " nn.Linear(64,1)\n", 257 | ").to(device)\n", 258 | "\n", 259 | "model = NeuralODE(f, solver='dopri5').to(device)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 7, 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "name": "stderr", 269 | "output_type": "stream", 270 | "text": [ 271 | "GPU available: True, used: True\n", 272 | "TPU available: False, using: 0 TPU cores\n", 273 | "CUDA_VISIBLE_DEVICES: [0]\n", 274 | "\n", 275 | " | Name | Type | Params\n", 276 | "-----------------------------------\n", 277 | "0 | model | NeuralDE | 257 \n" 278 | ] 279 | }, 280 | { 281 | "data": { 282 | "application/vnd.jupyter.widget-view+json": { 283 | "model_id": "be7d5f252708455b8784852bb74e1112", 284 | "version_major": 2, 285 | "version_minor": 0 286 | }, 287 | "text/plain": [ 288 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" 289 | ] 290 | }, 291 | "metadata": {}, 292 | "output_type": "display_data" 293 | }, 294 | { 295 | "name": "stdout", 296 | "output_type": "stream", 297 | "text": [ 298 | "\n" 299 | ] 300 | }, 301 | { 302 | "data": { 303 | "text/plain": [ 304 | "1" 305 | ] 306 | }, 307 | "execution_count": 7, 308 | "metadata": {}, 309 | "output_type": "execute_result" 310 | } 311 | ], 312 | "source": [ 313 | "# train the neural ODE\n", 314 | "learn = Learner(model)\n", 315 | "trainer = pl.Trainer(min_epochs=200, max_epochs=250, gpus=1)\n", 316 | "trainer.fit(learn)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "**Plots**" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 8, 329 | "metadata": {}, 330 | 
"outputs": [], 331 | "source": [ 332 | "# evaluate the trajectories of each data point\n", 333 | "s_span = torch.linspace(0,1,100)\n", 334 | "traj = model.trajectory(X_train, s_span).cpu().detach()" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 9, 340 | "metadata": {}, 341 | "outputs": [ 342 | { 343 | "data": { 344 | "text/plain": [ 345 | "Text(0.5, 1.0, 'Depth-Trajectories of Controlled Neural ODEs')" 346 | ] 347 | }, 348 | "execution_count": 9, 349 | "metadata": {}, 350 | "output_type": "execute_result" 351 | }, 352 | { 353 | "data": { 354 | "image/png": "\n", 355 | "text/plain": [ 356 | "
" 357 | ] 358 | }, 359 | "metadata": { 360 | "needs_background": "light" 361 | }, 362 | "output_type": "display_data" 363 | } 364 | ], 365 | "source": [ 366 | "# plot the depth evolution of the data\n", 367 | "fig = plt.figure(figsize=(6,3))\n", 368 | "ax = fig.add_subplot(111)\n", 369 | "ax.plot(s_span, traj[:,::5,0], color='blue', alpha=.3);\n", 370 | "ax.set_xlabel(r\"$s$ [Depth]\")\n", 371 | "ax.set_ylabel(r\"$z(s)$ [State]\")\n", 372 | "ax.set_title(r\"Depth-Trajectories of Controlled Neural ODEs\")" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 10, 378 | "metadata": {}, 379 | "outputs": [ 380 | { 381 | "data": { 382 | "image/png": "\n", 383 | "text/plain": [ 384 | "
" 385 | ] 386 | }, 387 | "metadata": { 388 | "needs_background": "light" 389 | }, 390 | "output_type": "display_data" 391 | } 392 | ], 393 | "source": [ 394 | "# plot the evolution of the data in the s-x-h space\n", 395 | "from mpl_toolkits.mplot3d import Axes3D\n", 396 | "\n", 397 | "n_grid=30\n", 398 | "x_span=[-1,1]\n", 399 | "fig = plt.figure(figsize=(6,6)) ; ax =Axes3D(fig)\n", 400 | "ss = torch.linspace(s_span[0], s_span[-1], n_grid)\n", 401 | "xx = torch.linspace(x_span[0], x_span[-1], n_grid)\n", 402 | "S, X = torch.meshgrid(ss,xx) ; \n", 403 | "u_traj = traj[0,:,0].repeat(traj.shape[1],1)\n", 404 | "e = torch.abs(y.T - traj[:,:,0])\n", 405 | "color = plt.cm.plasma(e.numpy())\n", 406 | "for i in range(traj.shape[1]):\n", 407 | " tr = ax.scatter(s_span, u_traj[:,i], traj[:,i,0],\n", 408 | " c=color[:,i],alpha=1, cmap=color[:,i], zdir='z')\n", 409 | "norm = mpl.colors.Normalize(e.min(),e.max())\n", 410 | "plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap='plasma'),\n", 411 | " label='Approximation Error', orientation='horizontal')\n", 412 | "ax.set_xlabel(r\"$s$ [depth]\"); ax.set_ylabel(r\"$x$\"); ax.set_zlabel(r\"$z(s)$\") ; \n", 413 | "ax.xaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", 414 | "ax.yaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n", 415 | "ax.zaxis._axinfo[\"grid\"]['color'] = (1,1,1,0)" 416 | ] 417 | } 418 | ], 419 | "metadata": { 420 | "kernelspec": { 421 | "display_name": "torchdyn", 422 | "language": "python", 423 | "name": "torchdyn" 424 | }, 425 | "language_info": { 426 | "codemirror_mode": { 427 | "name": "ipython", 428 | "version": 3 429 | }, 430 | "file_extension": ".py", 431 | "mimetype": "text/x-python", 432 | "name": "python", 433 | "nbconvert_exporter": "python", 434 | "pygments_lexer": "ipython3", 435 | "version": "3.8.5" 436 | }, 437 | "latex_envs": { 438 | "LaTeX_envs_menu_present": true, 439 | "autoclose": false, 440 | "autocomplete": true, 441 | "bibliofile": "biblio.bib", 442 | "cite_by": "apalike", 443 | 
"current_citInitial": 1, 444 | "eqLabelWithNumbers": true, 445 | "eqNumInitial": 1, 446 | "hotkeys": { 447 | "equation": "Ctrl-E", 448 | "itemize": "Ctrl-I" 449 | }, 450 | "labels_anchors": false, 451 | "latex_user_defs": false, 452 | "report_style_numbering": false, 453 | "user_envs_cfg": false 454 | }, 455 | "varInspector": { 456 | "cols": { 457 | "lenName": 16, 458 | "lenType": 16, 459 | "lenVar": 40 460 | }, 461 | "kernels_config": { 462 | "python": { 463 | "delete_cmd_postfix": "", 464 | "delete_cmd_prefix": "del ", 465 | "library": "var_list.py", 466 | "varRefreshCmd": "print(var_dic_list())" 467 | }, 468 | "r": { 469 | "delete_cmd_postfix": ") ", 470 | "delete_cmd_prefix": "rm(", 471 | "library": "var_list.r", 472 | "varRefreshCmd": "cat(var_dic_list()) " 473 | } 474 | }, 475 | "types_to_exclude": [ 476 | "module", 477 | "function", 478 | "builtin_function_or_method", 479 | "instance", 480 | "_Feature" 481 | ], 482 | "window_display": false 483 | } 484 | }, 485 | "nbformat": 4, 486 | "nbformat_minor": 4 487 | } 488 | --------------------------------------------------------------------------------