├── LICENSE ├── Plots ├── ResNet110_error.png ├── ResNet110_loss.png ├── ResNet56_error.png ├── ResNet56_error_2.png └── ResNet56_loss.png ├── README.md ├── ReZNet-6x_faster_ResNet_training_via_ReZero.ipynb ├── ReZero-Deep_Fast_NeuralNetwork.ipynb ├── ReZero-Deep_Fast_Transformer.ipynb └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 tbachlechner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Plots/ResNet110_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tbachlechner/ReZero-examples/fe7e7ef080df6555018bec3102613c9e4c0d1f1d/Plots/ResNet110_error.png -------------------------------------------------------------------------------- /Plots/ResNet110_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tbachlechner/ReZero-examples/fe7e7ef080df6555018bec3102613c9e4c0d1f1d/Plots/ResNet110_loss.png -------------------------------------------------------------------------------- /Plots/ResNet56_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tbachlechner/ReZero-examples/fe7e7ef080df6555018bec3102613c9e4c0d1f1d/Plots/ResNet56_error.png -------------------------------------------------------------------------------- /Plots/ResNet56_error_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tbachlechner/ReZero-examples/fe7e7ef080df6555018bec3102613c9e4c0d1f1d/Plots/ResNet56_error_2.png -------------------------------------------------------------------------------- /Plots/ResNet56_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tbachlechner/ReZero-examples/fe7e7ef080df6555018bec3102613c9e4c0d1f1d/Plots/ResNet56_loss.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReZero-examples 2 | 3 | This repo contains examples demonstrating the power of the ReZero architecture, see the [paper](https://arxiv.org/pdf/2003.04887.pdf). 
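
The core idea behind ReZero is a residual connection gated by a single learned scalar that is initialized to zero, so every block starts as the identity map and only gradually "turns on" during training. A minimal PyTorch sketch of the idea (the `ReZeroBlock` module below is illustrative only, not part of the `rezero` package):

```python
import torch
import torch.nn as nn

class ReZeroBlock(nn.Module):
    """Illustrative residual block: x_{i+1} = x_i + alpha_i * F(x_i), alpha_i = 0 at init."""
    def __init__(self, dim):
        super().__init__()
        self.layer = nn.Linear(dim, dim)
        # The gate starts at zero, so the block is exactly the identity at initialization
        self.resweight = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return x + self.resweight * torch.relu(self.layer(x))

x = torch.randn(4, 16)
block = ReZeroBlock(16)
out = block(x)
# At initialization resweight == 0, so the block passes x through unchanged
assert torch.allclose(out, x)
```

Because each block is the identity at initialization, signals propagate unchanged through arbitrarily deep stacks, which is what makes very deep networks trainable from the start.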
4 | 5 | The official ReZero repo is [here](https://github.com/majumderb/rezero). 6 | 7 | - [ReZero speeds up superconvergence](https://github.com/tbachlechner/ReZero-Superconvergence/blob/master/Faster_SuperC.ipynb) 8 | 9 | - [ReZNet - Faster ResNet training via ReZero](https://github.com/tbachlechner/ReZero-examples/blob/master/ReZNet-6x_faster_ResNet_training_via_ReZero.ipynb) 10 |

11 | ![ResNet56_error](Plots/ResNet56_error.png) 12 |

13 | 14 | Final validation errors: Vanilla 7.74, FixUp 7.5, ReZero 6.38 ([see](https://github.com/tbachlechner/Fixup)). 15 | 16 | - [Training a 128-layer ReZero Transformer on WikiText-2 language modeling](https://github.com/tbachlechner/ReZero-examples/blob/master/ReZero-Deep_Fast_Transformer.ipynb) 17 | - [Training a 10,000-layer ReZero fully connected network on CIFAR-10](https://github.com/tbachlechner/ReZero-examples/blob/master/ReZero-Deep_Fast_NeuralNetwork.ipynb) 18 | 19 | # Contribute 20 | 21 | If you find that ReZero or a similar architecture improves the performance of your application, you are invited to share a demonstration here. 22 | 23 | # Install 24 | 25 | To install ReZero via pip, use ```pip install rezero``` 26 | 27 | # Usage 28 | We provide custom ReZero Transformer layers (RZTX). 29 | 30 | For example, this will create a Transformer encoder: 31 | ```py 32 | import torch 33 | import torch.nn as nn 34 | from rezero.transformer import RZTXEncoderLayer 35 | 36 | encoder_layer = RZTXEncoderLayer(d_model=512, nhead=8) 37 | transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 38 | src = torch.rand(10, 32, 512) 39 | out = transformer_encoder(src) 40 | ``` 41 | 42 | # Citation 43 | If you find `rezero` useful for your research, please cite our paper: 44 | ```bibtex 45 | @inproceedings{BacMajMaoCotMcA20, 46 | title = "ReZero is All You Need: Fast Convergence at Large Depth", 47 | author = "Bachlechner, Thomas and 48 | Majumder, Bodhisattwa Prasad and 49 | Mao, Huanru Henry and 50 | Cottrell, Garrison W. 
and 51 | McAuley, Julian", 52 | booktitle = "arXiv", 53 | year = "2020", 54 | url = "https://arxiv.org/abs/2003.04887" 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /ReZNet-6x_faster_ResNet_training_via_ReZero.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "###### ReZNet: (Up to 6x) faster ResNet training with ReZero\n", 8 | "\n", 9 | "In this notebook we will examine how the [ReZero](https://arxiv.org/abs/2003.04887) architecture addition enables or accelerates training in deep [ResNet](https://arxiv.org/pdf/1512.03385.pdf) networks. We will find, for example, that for a ResNet110 the number of epochs needed to reach 50% accuracy decreases by a factor of 6 upon implementing ReZero. In this particular example the accuracy after convergence also improves with ReZero. The architecture here differs importantly from [Fixup](https://arxiv.org/pdf/1901.09321.pdf) and [SkipInit](https://arxiv.org/pdf/2002.10444.pdf) in that the skip connection is implemented **after** the nonlinearity to preserve signal propagation.\n", 10 | "\n", 11 | "The official ReZero repo is [here](https://github.com/majumderb/rezero).\n", 12 | "\n", 13 | "This notebook is heavily inspired by [Yerlan Idelbayev's beautiful ResNet implementation](https://github.com/akamaster/pytorch_resnet_cifar10).\n", 14 | "\n", 15 | "Running time of the notebook: 15 minutes on a laptop with a single RTX 2060 GPU.\n", 16 | "\n", 17 | "Note: This notebook was evaluated with PyTorch 1.4; the test accuracies may differ slightly for other versions." 
18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "######################################################################\n", 27 | "# Import and set manual seed\n", 28 | "\n", 29 | "import time\n", 30 | "import torch\n", 31 | "import torch.nn as nn\n", 32 | "import torch.nn.parallel\n", 33 | "import torch.optim\n", 34 | "import torch.utils.data\n", 35 | "import torchvision.transforms as transforms\n", 36 | "import torchvision.datasets as datasets\n", 37 | "import torch.nn.functional as F\n", 38 | "import torch.nn.init as init\n", 39 | "\n", 40 | "torch.manual_seed(0)\n", 41 | "\n", 42 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 43 | "\n", 44 | "######################################################################\n", 45 | "# Define ResNet model as in \n", 46 | "# https://github.com/akamaster/pytorch_resnet_cifar10\n", 47 | "\n", 48 | "def _weights_init(m):\n", 49 | "    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):\n", 50 | "        init.kaiming_normal_(m.weight)\n", 51 | "\n", 52 | "class LambdaLayer(nn.Module):\n", 53 | "    def __init__(self, lambd):\n", 54 | "        super(LambdaLayer, self).__init__()\n", 55 | "        self.lambd = lambd\n", 56 | "\n", 57 | "    def forward(self, x):\n", 58 | "        return self.lambd(x)\n", 59 | "\n", 60 | "class BasicBlock(nn.Module):\n", 61 | "    expansion = 1\n", 62 | "\n", 63 | "    def __init__(self, in_planes, planes, stride=1, option='A', rezero = True):\n", 64 | "        super(BasicBlock, self).__init__()\n", 65 | "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", 66 | "        self.bn1 = nn.BatchNorm2d(planes)\n", 67 | "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n", 68 | "        self.bn2 = nn.BatchNorm2d(planes)\n", 69 | "        self.rezero = rezero\n", 70 | "        if self.rezero:\n", 71 | "            self.resweight = nn.Parameter(torch.Tensor([0]), 
requires_grad=True)\n", 72 | " \n", 73 | " self.shortcut = nn.Sequential()\n", 74 | " if stride != 1 or in_planes != planes:\n", 75 | " self.shortcut = LambdaLayer(lambda x:\n", 76 | " F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), \"constant\", 0))\n", 77 | " \n", 78 | " def forward(self, x):\n", 79 | " out = F.relu(self.bn1(self.conv1(x)))\n", 80 | " out = self.bn2(self.conv2(out))\n", 81 | " \n", 82 | " if self.rezero == True:\n", 83 | " # In a ReZero ResNet the skip connection is after the nonlinearity\n", 84 | " out = self.resweight * F.relu(out) + self.shortcut(x)\n", 85 | " elif self.rezero == False:\n", 86 | " # In a vanilla ResNet the skip connection is before the nonlinearity\n", 87 | " out = F.relu(out + self.shortcut(x))\n", 88 | " return out\n", 89 | "\n", 90 | "\n", 91 | "class ResNet(nn.Module):\n", 92 | " def __init__(self, block, num_blocks, num_classes=10, rezero = False):\n", 93 | " super(ResNet, self).__init__()\n", 94 | " self.in_planes = 16\n", 95 | "\n", 96 | " self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)\n", 97 | " self.bn1 = nn.BatchNorm2d(16)\n", 98 | " self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, rezero = rezero)\n", 99 | " self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, rezero = rezero)\n", 100 | " self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, rezero = rezero)\n", 101 | " self.linear = nn.Linear(64, num_classes)\n", 102 | "\n", 103 | " self.apply(_weights_init)\n", 104 | "\n", 105 | " def _make_layer(self, block, planes, num_blocks, stride, rezero = False):\n", 106 | " strides = [stride] + [1]*(num_blocks-1)\n", 107 | " layers = []\n", 108 | " for stride in strides:\n", 109 | " layers.append(block(self.in_planes, planes, stride, rezero = rezero))\n", 110 | " self.in_planes = planes * block.expansion\n", 111 | "\n", 112 | " return nn.Sequential(*layers)\n", 113 | "\n", 114 | " def forward(self, x):\n", 115 | " out = 
F.relu(self.bn1(self.conv1(x)))\n", 116 | " out = self.layer1(out)\n", 117 | " out = self.layer2(out)\n", 118 | " out = self.layer3(out)\n", 119 | " out = F.avg_pool2d(out, out.size()[3])\n", 120 | " out = out.view(out.size(0), -1)\n", 121 | " out = self.linear(out)\n", 122 | " return out\n", 123 | "\n", 124 | "######################################################################\n", 125 | "# Define various variants\n", 126 | "\n", 127 | "def resnet20(rezero = False):\n", 128 | " return ResNet(BasicBlock, [3, 3, 3], rezero = rezero)\n", 129 | "\n", 130 | "\n", 131 | "def resnet56(rezero = False):\n", 132 | " return ResNet(BasicBlock, [9, 9, 9], rezero = rezero)\n", 133 | "\n", 134 | "\n", 135 | "def resnet110(rezero = False):\n", 136 | " return ResNet(BasicBlock, [18, 18, 18], rezero = rezero)\n", 137 | "\n", 138 | "\n", 139 | "def test(net):\n", 140 | " import numpy as np\n", 141 | " total_params = 0\n", 142 | "\n", 143 | " for x in filter(lambda p: p.requires_grad, net.parameters()):\n", 144 | " total_params += np.prod(x.data.numpy().shape)\n", 145 | " print(\"Total number of params {:2.3f}M\".format(total_params/1e6))\n", 146 | " print(\"Total layers\", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters()))))\n", 147 | "\n", 148 | "######################################################################\n", 149 | "# Define function to train\n", 150 | "\n", 151 | "def train(train_loader, model, criterion, optimizer, epoch,print_freq,lr_scheduler):\n", 152 | " \"\"\"\n", 153 | " Run one train epoch\n", 154 | " \"\"\"\n", 155 | " batch_time = AverageMeter()\n", 156 | " data_time = AverageMeter()\n", 157 | " losses = AverageMeter()\n", 158 | " top1 = AverageMeter()\n", 159 | "\n", 160 | " # switch to train mode\n", 161 | " model.train()\n", 162 | "\n", 163 | " end = time.time()\n", 164 | " for i, (input, target) in enumerate(train_loader):\n", 165 | "\n", 166 | " # measure data loading time\n", 167 | " data_time.update(time.time() - 
end)\n", 168 | "\n", 169 | " target = target.cuda()\n", 170 | " input_var = input.cuda()\n", 171 | " target_var = target\n", 172 | "\n", 173 | " # compute output\n", 174 | " output = model(input_var)\n", 175 | " loss = criterion(output, target_var)\n", 176 | "\n", 177 | " # compute gradient and do SGD step\n", 178 | " optimizer.zero_grad()\n", 179 | " loss.backward()\n", 180 | " optimizer.step()\n", 181 | "\n", 182 | " output = output.float()\n", 183 | " loss = loss.float()\n", 184 | " # measure accuracy and record loss\n", 185 | " prec1 = accuracy(output.data, target)[0]\n", 186 | " losses.update(loss.item(), input.size(0))\n", 187 | " top1.update(prec1.item(), input.size(0))\n", 188 | "\n", 189 | " # measure elapsed time\n", 190 | " batch_time.update(time.time() - end)\n", 191 | " end = time.time()\n", 192 | "\n", 193 | "\n", 194 | " \n", 195 | " if i % print_freq == 0: \n", 196 | " print('| epoch {:3d} | {:4d}/{:4d} batches | '\n", 197 | " 'lr {:02.2f} | ms/batch {:4.0f} | '\n", 198 | " 'loss {loss.avg:1.3f} | Top 1 accuracy {top1.avg:2.2f} %'.format(\n", 199 | " epoch+1, i, len(train_loader), lr_scheduler.get_lr()[0],\n", 200 | " 1000*batch_time.avg,loss=losses,top1=top1))\n", 201 | "\n", 202 | "\n", 203 | "def validate(val_loader, model, criterion):\n", 204 | " \"\"\"\n", 205 | " Run evaluation\n", 206 | " \"\"\"\n", 207 | " batch_time = AverageMeter()\n", 208 | " losses = AverageMeter()\n", 209 | " top1 = AverageMeter()\n", 210 | "\n", 211 | " # switch to evaluate mode\n", 212 | " model.eval()\n", 213 | "\n", 214 | " end = time.time()\n", 215 | " with torch.no_grad():\n", 216 | " for i, (input, target) in enumerate(val_loader):\n", 217 | " target = target.cuda()\n", 218 | " input_var = input.cuda()\n", 219 | " target_var = target.cuda()\n", 220 | "\n", 221 | "\n", 222 | " # compute output\n", 223 | " output = model(input_var)\n", 224 | " loss = criterion(output, target_var)\n", 225 | "\n", 226 | " output = output.float()\n", 227 | " loss = loss.float()\n", 
228 | "\n", 229 | " # measure accuracy and record loss\n", 230 | " prec1 = accuracy(output.data, target)[0]\n", 231 | " losses.update(loss.item(), input.size(0))\n", 232 | " top1.update(prec1.item(), input.size(0))\n", 233 | "\n", 234 | " # measure elapsed time\n", 235 | " batch_time.update(time.time() - end)\n", 236 | " end = time.time()\n", 237 | " \n", 238 | "\n", 239 | " return losses.avg, top1.avg\n", 240 | "\n", 241 | "class AverageMeter(object):\n", 242 | " \"\"\"Computes and stores the average and current value\"\"\"\n", 243 | " def __init__(self):\n", 244 | " self.reset()\n", 245 | "\n", 246 | " def reset(self):\n", 247 | " self.val = 0\n", 248 | " self.avg = 0\n", 249 | " self.sum = 0\n", 250 | " self.count = 0\n", 251 | "\n", 252 | " def update(self, val, n=1):\n", 253 | " self.val = val\n", 254 | " self.sum += val * n\n", 255 | " self.count += n\n", 256 | " self.avg = self.sum / self.count\n", 257 | "\n", 258 | "\n", 259 | "def accuracy(output, target, topk=(1,)):\n", 260 | " \"\"\"Computes the precision@k for the specified values of k\"\"\"\n", 261 | " maxk = max(topk)\n", 262 | " batch_size = target.size(0)\n", 263 | "\n", 264 | " _, pred = output.topk(maxk, 1, True, True)\n", 265 | " pred = pred.t()\n", 266 | " correct = pred.eq(target.view(1, -1).expand_as(pred))\n", 267 | "\n", 268 | " res = []\n", 269 | " for k in topk:\n", 270 | " correct_k = correct[:k].view(-1).float().sum(0)\n", 271 | " res.append(correct_k.mul_(100.0 / batch_size))\n", 272 | " return res\n", 273 | "\n", 274 | "######################################################################\n", 275 | "# Package model setup and training into one simple function\n", 276 | "\n", 277 | "def setup_and_train(model,batch_size = 128, lr = 0.1,momentum = 0.9,\n", 278 | " weight_decay = 1e-4,epochs = 200,print_freq = 50):\n", 279 | " model = model.to(device)\n", 280 | " start_epoch = 0\n", 281 | " normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 282 | " std=[0.229, 0.224, 
0.225])\n", 283 | "\n", 284 | "\n", 285 | "    train_loader = torch.utils.data.DataLoader(\n", 286 | "        datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([\n", 287 | "            transforms.RandomHorizontalFlip(),\n", 288 | "            transforms.RandomCrop(32, 4),\n", 289 | "            transforms.ToTensor(),\n", 290 | "            normalize,\n", 291 | "        ]), download=True),\n", 292 | "        batch_size=batch_size, shuffle=True,\n", 293 | "        num_workers=1, pin_memory=True)\n", 294 | "\n", 295 | "    val_loader = torch.utils.data.DataLoader(\n", 296 | "        datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([\n", 297 | "            transforms.ToTensor(),\n", 298 | "            normalize,\n", 299 | "        ])),\n", 300 | "        batch_size=128, shuffle=False,\n", 301 | "        num_workers=1, pin_memory=True)\n", 302 | "\n", 303 | "    # define loss function (criterion) and optimizer\n", 304 | "    criterion = nn.CrossEntropyLoss().cuda()\n", 305 | "\n", 306 | "\n", 307 | "    \n", 308 | "    optimizer = torch.optim.SGD(model.parameters(), lr,\n", 309 | "                                momentum=momentum,\n", 310 | "                                weight_decay=weight_decay)\n", 311 | "\n", 312 | "    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,\n", 313 | "                                        milestones=[100, 150], last_epoch=start_epoch - 1)\n", 314 | "    best_prec1 = 0\n", 315 | "    for epoch in range(start_epoch, epochs):\n", 316 | "        epoch_start_time = time.time()\n", 317 | "        print('-'*95)\n", 318 | "        train(train_loader, model, criterion, optimizer, epoch,print_freq,lr_scheduler)\n", 319 | "        lr_scheduler.step()\n", 320 | "\n", 321 | "        # evaluate on validation set\n", 322 | "        loss, prec1 = validate(val_loader, model, criterion)\n", 323 | "\n", 324 | "        # remember best prec@1 \n", 325 | "        is_best = prec1 > best_prec1\n", 326 | "        best_prec1 = max(prec1, best_prec1)\n", 327 | "\n", 328 | "        print('-'*95)\n", 329 | "        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:1.3f} | '\n", 330 | "                'valid precision {:3.2f}% (best: {:3.2f}%) '.format(epoch+1, (time.time() - epoch_start_time),\n", 331 | "                                                                    loss, 
prec1,best_prec1))\n", 332 | "\n" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "## ResNet20\n", 340 | "\n", 341 | "First, we train a ResNet20 with ReZero for one epoch, and then train a ResNet20 without ReZero until it achieves the same accuracy.\n", 342 | "\n", 343 | "* In this example ReZero accelerates initial training by a factor of about 2." 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 2, 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "name": "stdout", 353 | "output_type": "stream", 354 | "text": [ 355 | "Total number of params 0.270M\n", 356 | "Total layers 20\n", 357 | "Files already downloaded and verified\n", 358 | "-----------------------------------------------------------------------------------------------\n", 359 | "| epoch   1 |    0/ 391 batches | lr 0.20 | ms/batch  322 | loss 2.382 | Top 1 accuracy  5.47 %\n", 360 | "| epoch   1 |  130/ 391 batches | lr 0.20 | ms/batch   35 | loss 1.826 | Top 1 accuracy 31.22 %\n", 361 | "| epoch   1 |  260/ 391 batches | lr 0.20 | ms/batch   33 | loss 1.665 | Top 1 accuracy 38.30 %\n", 362 | "| epoch   1 |  390/ 391 batches | lr 0.20 | ms/batch   33 | loss 1.557 | Top 1 accuracy 42.68 %\n", 363 | "-----------------------------------------------------------------------------------------------\n", 364 | "| end of epoch   1 | time: 14.14s | valid loss 1.507 | valid precision 49.52% (best: 49.52%) \n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "model = resnet20(rezero=True)\n", 370 | "test(model)\n", 371 | "setup_and_train(model, batch_size = 128, lr = 0.2, epochs = 1, print_freq = 130)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 3, 377 | "metadata": {}, 378 | "outputs": [ 379 | { 380 | "name": "stdout", 381 | "output_type": "stream", 382 | "text": [ 383 | "Total number of params 0.270M\n", 384 | "Total layers 20\n", 385 | "Files already downloaded and verified\n", 386 | 
"-----------------------------------------------------------------------------------------------\n", 387 | "| epoch   1 |    0/ 391 batches | lr 0.20 | ms/batch   71 | loss 3.223 | Top 1 accuracy  9.38 %\n", 388 | "| epoch   1 |  130/ 391 batches | lr 0.20 | ms/batch   31 | loss 2.108 | Top 1 accuracy 23.83 %\n", 389 | "| epoch   1 |  260/ 391 batches | lr 0.20 | ms/batch   31 | loss 1.895 | Top 1 accuracy 29.94 %\n", 390 | "| epoch   1 |  390/ 391 batches | lr 0.20 | ms/batch   31 | loss 1.780 | Top 1 accuracy 34.10 %\n", 391 | "-----------------------------------------------------------------------------------------------\n", 392 | "| end of epoch   1 | time: 13.35s | valid loss 1.556 | valid precision 43.43% (best: 43.43%) \n", 393 | "-----------------------------------------------------------------------------------------------\n", 394 | "| epoch   2 |    0/ 391 batches | lr 0.20 | ms/batch   79 | loss 1.460 | Top 1 accuracy 46.09 %\n", 395 | "| epoch   2 |  130/ 391 batches | lr 0.20 | ms/batch   31 | loss 1.441 | Top 1 accuracy 46.95 %\n", 396 | "| epoch   2 |  260/ 391 batches | lr 0.20 | ms/batch   31 | loss 1.374 | Top 1 accuracy 49.59 %\n", 397 | "| epoch   2 |  390/ 391 batches | lr 0.20 | ms/batch   31 | loss 1.327 | Top 1 accuracy 51.55 %\n", 398 | "-----------------------------------------------------------------------------------------------\n", 399 | "| end of epoch   2 | time: 13.39s | valid loss 1.260 | valid precision 55.54% (best: 55.54%) \n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "model = resnet20(rezero=False)\n", 405 | "test(model)\n", 406 | "setup_and_train(model, batch_size = 128, lr = 0.2, epochs = 2, print_freq = 130)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": {}, 412 | "source": [ 413 | "## ResNet56\n", 414 | "\n", 415 | "Next, we train a ResNet56 with ReZero for one epoch, and then train a ResNet56 without ReZero until it achieves the same accuracy.\n", 416 | "\n", 417 | "* In this example ReZero accelerates initial training by a factor of about 3." 
418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 4, 423 | "metadata": {}, 424 | "outputs": [ 425 | { 426 | "name": "stdout", 427 | "output_type": "stream", 428 | "text": [ 429 | "Total number of params 0.853M\n", 430 | "Total layers 56\n", 431 | "Files already downloaded and verified\n", 432 | "-----------------------------------------------------------------------------------------------\n", 433 | "| epoch 1 | 0/ 391 batches | lr 0.10 | ms/batch 133 | loss 2.377 | Top 1 accuracy 3.12 %\n", 434 | "| epoch 1 | 130/ 391 batches | lr 0.10 | ms/batch 94 | loss 1.800 | Top 1 accuracy 31.95 %\n", 435 | "| epoch 1 | 260/ 391 batches | lr 0.10 | ms/batch 94 | loss 1.631 | Top 1 accuracy 39.12 %\n", 436 | "| epoch 1 | 390/ 391 batches | lr 0.10 | ms/batch 94 | loss 1.494 | Top 1 accuracy 44.63 %\n", 437 | "-----------------------------------------------------------------------------------------------\n", 438 | "| end of epoch 1 | time: 39.11s | valid loss 1.339 | valid precision 53.50% (best: 53.50%) \n" 439 | ] 440 | } 441 | ], 442 | "source": [ 443 | "model = resnet56(rezero=True)\n", 444 | "test(model)\n", 445 | "setup_and_train(model, batch_size = 128, lr = 0.1, epochs = 1, print_freq = 130)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 5, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "name": "stdout", 455 | "output_type": "stream", 456 | "text": [ 457 | "Total number of params 0.853M\n", 458 | "Total layers 56\n", 459 | "Files already downloaded and verified\n", 460 | "-----------------------------------------------------------------------------------------------\n", 461 | "| epoch 1 | 0/ 391 batches | lr 0.10 | ms/batch 130 | loss 11.472 | Top 1 accuracy 8.59 %\n", 462 | "| epoch 1 | 130/ 391 batches | lr 0.10 | ms/batch 89 | loss 2.912 | Top 1 accuracy 10.69 %\n", 463 | "| epoch 1 | 260/ 391 batches | lr 0.10 | ms/batch 89 | loss 2.573 | Top 1 accuracy 12.43 %\n", 464 | "| epoch 1 | 390/ 391 batches 
| lr 0.10 | ms/batch 89 | loss 2.388 | Top 1 accuracy 16.01 %\n", 465 | "-----------------------------------------------------------------------------------------------\n", 466 | "| end of epoch 1 | time: 37.08s | valid loss 1.904 | valid precision 26.63% (best: 26.63%) \n", 467 | "-----------------------------------------------------------------------------------------------\n", 468 | "| epoch 2 | 0/ 391 batches | lr 0.10 | ms/batch 134 | loss 1.885 | Top 1 accuracy 27.34 %\n", 469 | "| epoch 2 | 130/ 391 batches | lr 0.10 | ms/batch 90 | loss 1.875 | Top 1 accuracy 29.13 %\n", 470 | "| epoch 2 | 260/ 391 batches | lr 0.10 | ms/batch 90 | loss 1.805 | Top 1 accuracy 31.76 %\n", 471 | "| epoch 2 | 390/ 391 batches | lr 0.10 | ms/batch 90 | loss 1.749 | Top 1 accuracy 34.07 %\n", 472 | "-----------------------------------------------------------------------------------------------\n", 473 | "| end of epoch 2 | time: 37.38s | valid loss 1.682 | valid precision 38.90% (best: 38.90%) \n", 474 | "-----------------------------------------------------------------------------------------------\n", 475 | "| epoch 3 | 0/ 391 batches | lr 0.10 | ms/batch 138 | loss 1.634 | Top 1 accuracy 36.72 %\n", 476 | "| epoch 3 | 130/ 391 batches | lr 0.10 | ms/batch 90 | loss 1.537 | Top 1 accuracy 42.96 %\n", 477 | "| epoch 3 | 260/ 391 batches | lr 0.10 | ms/batch 90 | loss 1.462 | Top 1 accuracy 46.00 %\n", 478 | "| epoch 3 | 390/ 391 batches | lr 0.10 | ms/batch 90 | loss 1.400 | Top 1 accuracy 48.51 %\n", 479 | "-----------------------------------------------------------------------------------------------\n", 480 | "| end of epoch 3 | time: 37.47s | valid loss 1.262 | valid precision 54.77% (best: 54.77%) \n", 481 | "-----------------------------------------------------------------------------------------------\n", 482 | "| epoch 4 | 0/ 391 batches | lr 0.10 | ms/batch 135 | loss 1.263 | Top 1 accuracy 55.47 %\n", 483 | "| epoch 4 | 130/ 391 batches | lr 0.10 | ms/batch 91 | loss 
1.183 | Top 1 accuracy 57.22 %\n", 484 | "| epoch   4 |  260/ 391 batches | lr 0.10 | ms/batch   90 | loss 1.147 | Top 1 accuracy 58.76 %\n", 485 | "| epoch   4 |  390/ 391 batches | lr 0.10 | ms/batch   90 | loss 1.102 | Top 1 accuracy 60.59 %\n", 486 | "-----------------------------------------------------------------------------------------------\n", 487 | "| end of epoch   4 | time: 37.46s | valid loss 0.997 | valid precision 64.64% (best: 64.64%) \n" 488 | ] 489 | } 490 | ], 491 | "source": [ 492 | "model = resnet56(rezero=False)\n", 493 | "test(model)\n", 494 | "setup_and_train(model, batch_size = 128, lr = 0.1, epochs = 4, print_freq = 130)" 495 | ] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "metadata": {}, 500 | "source": [ 501 | "## ResNet110\n", 502 | "\n", 503 | "Next, we train a ResNet110 with ReZero for one epoch, and then train a ResNet110 without ReZero until it achieves the same accuracy.\n", 504 | "\n", 505 | "* In this example ReZero accelerates initial training by a factor of about 6." 
506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": 6, 511 | "metadata": {}, 512 | "outputs": [ 513 | { 514 | "name": "stdout", 515 | "output_type": "stream", 516 | "text": [ 517 | "Total number of params 1.728M\n", 518 | "Total layers 110\n", 519 | "Files already downloaded and verified\n", 520 | "-----------------------------------------------------------------------------------------------\n", 521 | "| epoch 1 | 0/ 391 batches | lr 0.10 | ms/batch 223 | loss 2.402 | Top 1 accuracy 9.38 %\n", 522 | "| epoch 1 | 130/ 391 batches | lr 0.10 | ms/batch 189 | loss 1.793 | Top 1 accuracy 32.82 %\n", 523 | "| epoch 1 | 260/ 391 batches | lr 0.10 | ms/batch 189 | loss 1.589 | Top 1 accuracy 41.12 %\n", 524 | "| epoch 1 | 390/ 391 batches | lr 0.10 | ms/batch 189 | loss 1.440 | Top 1 accuracy 47.20 %\n", 525 | "-----------------------------------------------------------------------------------------------\n", 526 | "| end of epoch 1 | time: 78.27s | valid loss 1.641 | valid precision 50.65% (best: 50.65%) \n" 527 | ] 528 | } 529 | ], 530 | "source": [ 531 | "model = resnet110(rezero=True)\n", 532 | "test(model)\n", 533 | "setup_and_train(model, batch_size = 128, lr = 0.1, epochs = 1, print_freq = 130)" 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": 7, 539 | "metadata": {}, 540 | "outputs": [ 541 | { 542 | "name": "stdout", 543 | "output_type": "stream", 544 | "text": [ 545 | "Total number of params 1.728M\n", 546 | "Total layers 110\n", 547 | "Files already downloaded and verified\n", 548 | "-----------------------------------------------------------------------------------------------\n", 549 | "| epoch 1 | 0/ 391 batches | lr 0.10 | ms/batch 222 | loss 12.921 | Top 1 accuracy 10.16 %\n", 550 | "| epoch 1 | 130/ 391 batches | lr 0.10 | ms/batch 178 | loss 3.905 | Top 1 accuracy 11.36 %\n", 551 | "| epoch 1 | 260/ 391 batches | lr 0.10 | ms/batch 177 | loss 3.072 | Top 1 accuracy 13.86 %\n", 552 | "| epoch 1 | 390/ 
391 batches | lr 0.10 | ms/batch 177 | loss 2.731 | Top 1 accuracy 17.16 %\n", 553 | "-----------------------------------------------------------------------------------------------\n", 554 | "| end of epoch 1 | time: 73.58s | valid loss 1.948 | valid precision 27.01% (best: 27.01%) \n", 555 | "-----------------------------------------------------------------------------------------------\n", 556 | "| epoch 2 | 0/ 391 batches | lr 0.10 | ms/batch 223 | loss 2.013 | Top 1 accuracy 25.00 %\n", 557 | "| epoch 2 | 130/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.936 | Top 1 accuracy 27.21 %\n", 558 | "| epoch 2 | 260/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.892 | Top 1 accuracy 29.00 %\n", 559 | "| epoch 2 | 390/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.854 | Top 1 accuracy 30.55 %\n", 560 | "-----------------------------------------------------------------------------------------------\n", 561 | "| end of epoch 2 | time: 73.73s | valid loss 2.182 | valid precision 36.19% (best: 36.19%) \n", 562 | "-----------------------------------------------------------------------------------------------\n", 563 | "| epoch 3 | 0/ 391 batches | lr 0.10 | ms/batch 223 | loss 1.756 | Top 1 accuracy 35.16 %\n", 564 | "| epoch 3 | 130/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.721 | Top 1 accuracy 35.91 %\n", 565 | "| epoch 3 | 260/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.694 | Top 1 accuracy 37.33 %\n", 566 | "| epoch 3 | 390/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.668 | Top 1 accuracy 38.35 %\n", 567 | "-----------------------------------------------------------------------------------------------\n", 568 | "| end of epoch 3 | time: 73.83s | valid loss 1.855 | valid precision 41.56% (best: 41.56%) \n", 569 | "-----------------------------------------------------------------------------------------------\n", 570 | "| epoch 4 | 0/ 391 batches | lr 0.10 | ms/batch 223 | loss 1.597 | Top 1 accuracy 47.66 %\n", 571 | "| epoch 4 | 130/ 391 batches | lr 0.10 | 
ms/batch 178 | loss 1.550 | Top 1 accuracy 43.44 %\n", 572 | "| epoch 4 | 260/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.534 | Top 1 accuracy 43.91 %\n", 573 | "| epoch 4 | 390/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.517 | Top 1 accuracy 44.51 %\n", 574 | "-----------------------------------------------------------------------------------------------\n", 575 | "| end of epoch 4 | time: 73.94s | valid loss 1.819 | valid precision 39.79% (best: 41.56%) \n", 576 | "-----------------------------------------------------------------------------------------------\n", 577 | "| epoch 5 | 0/ 391 batches | lr 0.10 | ms/batch 223 | loss 1.387 | Top 1 accuracy 50.78 %\n", 578 | "| epoch 5 | 130/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.438 | Top 1 accuracy 47.84 %\n", 579 | "| epoch 5 | 260/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.419 | Top 1 accuracy 48.65 %\n", 580 | "| epoch 5 | 390/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.405 | Top 1 accuracy 49.18 %\n", 581 | "-----------------------------------------------------------------------------------------------\n", 582 | "| end of epoch 5 | time: 73.97s | valid loss 1.450 | valid precision 50.40% (best: 50.40%) \n", 583 | "-----------------------------------------------------------------------------------------------\n", 584 | "| epoch 6 | 0/ 391 batches | lr 0.10 | ms/batch 224 | loss 1.421 | Top 1 accuracy 46.88 %\n", 585 | "| epoch 6 | 130/ 391 batches | lr 0.10 | ms/batch 179 | loss 1.353 | Top 1 accuracy 51.16 %\n", 586 | "| epoch 6 | 260/ 391 batches | lr 0.10 | ms/batch 179 | loss 1.330 | Top 1 accuracy 51.77 %\n", 587 | "| epoch 6 | 390/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.316 | Top 1 accuracy 52.22 %\n", 588 | "-----------------------------------------------------------------------------------------------\n", 589 | "| end of epoch 6 | time: 74.00s | valid loss 1.273 | valid precision 55.72% (best: 55.72%) \n", 590 | 
"-----------------------------------------------------------------------------------------------\n", 591 | "| epoch 7 | 0/ 391 batches | lr 0.10 | ms/batch 225 | loss 1.224 | Top 1 accuracy 50.78 %\n", 592 | "| epoch 7 | 130/ 391 batches | lr 0.10 | ms/batch 179 | loss 1.245 | Top 1 accuracy 55.36 %\n", 593 | "| epoch 7 | 260/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.241 | Top 1 accuracy 55.46 %\n", 594 | "| epoch 7 | 390/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.226 | Top 1 accuracy 56.03 %\n", 595 | "-----------------------------------------------------------------------------------------------\n", 596 | "| end of epoch 7 | time: 73.95s | valid loss 1.222 | valid precision 58.21% (best: 58.21%) \n", 597 | "-----------------------------------------------------------------------------------------------\n", 598 | "| epoch 8 | 0/ 391 batches | lr 0.10 | ms/batch 222 | loss 1.029 | Top 1 accuracy 64.84 %\n", 599 | "| epoch 8 | 130/ 391 batches | lr 0.10 | ms/batch 179 | loss 1.161 | Top 1 accuracy 58.47 %\n", 600 | "| epoch 8 | 260/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.155 | Top 1 accuracy 58.92 %\n", 601 | "| epoch 8 | 390/ 391 batches | lr 0.10 | ms/batch 178 | loss 1.143 | Top 1 accuracy 59.37 %\n", 602 | "-----------------------------------------------------------------------------------------------\n", 603 | "| end of epoch 8 | time: 73.99s | valid loss 1.488 | valid precision 54.72% (best: 58.21%) \n" 604 | ] 605 | } 606 | ], 607 | "source": [ 608 | "model = resnet110(rezero=False)\n", 609 | "test(model)\n", 610 | "setup_and_train(model, batch_size = 128, lr = 0.1, epochs = 8, print_freq = 130)" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": null, 616 | "metadata": {}, 617 | "outputs": [], 618 | "source": [] 619 | } 620 | ], 621 | "metadata": { 622 | "kernelspec": { 623 | "display_name": "Python 3", 624 | "language": "python", 625 | "name": "python3" 626 | }, 627 | "language_info": { 628 | "codemirror_mode": { 629 
| "name": "ipython", 630 | "version": 3 631 | }, 632 | "file_extension": ".py", 633 | "mimetype": "text/x-python", 634 | "name": "python", 635 | "nbconvert_exporter": "python", 636 | "pygments_lexer": "ipython3", 637 | "version": "3.7.6" 638 | } 639 | }, 640 | "nbformat": 4, 641 | "nbformat_minor": 4 642 | } 643 | -------------------------------------------------------------------------------- /ReZero-Deep_Fast_NeuralNetwork.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Training a 10,000-layer ReZero neural network on CIFAR-10 data\n", 8 | "\n", 9 | "In this notebook we will see how the [ReZero](https://arxiv.org/abs/2003.04887) architecture addition enables training of very deep networks. In particular, we will load the CIFAR-10 dataset via `torchvision` and train a deep fully connected network with various architecture additions to fit the data. We will compare four different methods: a vanilla fully connected network, residual connections, LayerNorm, and ReZero. We use identical hyperparameters (except for weight initialization, see below) and ReLU activations for each case. The architectures are described in [Table 1](https://arxiv.org/abs/2003.04887).\n", 10 | "\n", 11 | "Running time of the notebook: about 6 minutes on a laptop with a single RTX 2060 GPU (and several hours for training the 10,000-layer fully connected network at the end).\n", 12 | "\n", 13 | "### Define the model\n", 14 | "\n", 15 | "We now define the `DeepNN` model and several functions that load and prepare the data. 
Finally, we arrive at the function `setup_and_train`, which defines, trains, and evaluates the model and takes the following parameters as input:\n", 16 | "\n", 17 | "`version` : Defines the architecture: `'Vanilla'`, `'Residual'`, `'LayerNorm'`, or `'ReZero'`.\n", 18 | "\n", 19 | "`epochs` : Number of epochs to train\n", 20 | "\n", 21 | "`depth` : Depth of the network\n", 22 | "\n", 23 | "`width` : Width of the network\n", 24 | "\n", 25 | "`lr` : Learning rate" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 1, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "data": { 35 | "text/plain": [ 36 | "" 37 | ] 38 | }, 39 | "execution_count": 1, 40 | "metadata": {}, 41 | "output_type": "execute_result" 42 | } 43 | ], 44 | "source": [ 45 | "######################################################################\n", 46 | "# Import and set manual seed\n", 47 | "\n", 48 | "import torch\n", 49 | "import torchvision\n", 50 | "import torchvision.transforms as transforms\n", 51 | "import torch.nn as nn\n", 52 | "import torch.nn.functional as F\n", 53 | "import torch.optim as optim\n", 54 | "import time\n", 55 | "import numpy as np\n", 56 | "import matplotlib.pyplot as plt\n", 57 | "\n", 58 | "torch.manual_seed(0)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 2, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "Files already downloaded and verified\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "########################################################################\n", 76 | "# Download and define the training set.\n", 77 | "\n", 78 | "batchsize = 100\n", 79 | "transform = transforms.Compose([transforms.ToTensor(),\n", 80 | " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", 81 | "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n", 82 | " download=True, transform=transform)\n", 83 | "trainloader = torch.utils.data.DataLoader(trainset, 
batch_size=batchsize,\n", 84 | " shuffle=True, num_workers=2)\n", 85 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 86 | "\n", 87 | "\n", 88 | "########################################################################\n", 89 | "# Define the input-output Jacobian\n", 90 | "\n", 91 | "def get_jacobian(model, x):\n", 92 | " nc = x.size()[0]\n", 93 | " ny = x.size()[2]\n", 94 | " nx = x.size()[1]\n", 95 | " noutputs = 10\n", 96 | " x = x.reshape(nc*nx*ny)\n", 97 | " x = x.repeat(noutputs,1)\n", 98 | " x.requires_grad_(True)\n", 99 | " y = model(x.reshape(noutputs,nc,nx,ny))\n", 100 | " y.backward(torch.eye(noutputs).to(device))\n", 101 | " return x.grad.data\n", 102 | "\n", 103 | "\n", 104 | "########################################################################\n", 105 | "# Define fully connected network with ReLU activations and\n", 106 | "# architectures: 'ReZero', 'LayerNorm', 'Residual', 'Vanilla'\n", 107 | "\n", 108 | "class DeepNN(nn.Module):\n", 109 | " def __init__(self, lr, width, depth, version):\n", 110 | " super(DeepNN, self).__init__()\n", 111 | " self.linear_input = nn.Linear(3*32*32, width)\n", 112 | " self.linear_layers = nn.ModuleList([nn.Linear(width, width) for i in range(depth)])\n", 113 | " self.linear_output = nn.Linear(width, 10)\n", 114 | " self.version = version\n", 115 | " if self.version == 'ReZero':\n", 116 | " self.resweight = nn.Parameter(torch.zeros(depth), requires_grad=True)\n", 117 | " if self.version == 'LayerNorm':\n", 118 | " self.ln = torch.nn.LayerNorm((width))\n", 119 | " # Initialize:\n", 120 | " torch.nn.init.kaiming_normal_(self.linear_input.weight, a=0, mode='fan_in', nonlinearity='relu')\n", 121 | " for i in range(depth):\n", 122 | " if self.version == 'ReZero':\n", 123 | " torch.nn.init.xavier_normal_(self.linear_layers[i].weight, gain=torch.sqrt(torch.tensor(2.)))\n", 124 | " elif self.version == 'Vanilla':\n", 125 | " torch.nn.init.xavier_normal_(self.linear_layers[i].weight, 
gain=torch.sqrt(torch.tensor(2.)))\n", 126 | " elif self.version == 'Residual':\n", 127 | " # See https://arxiv.org/abs/1712.08969\n", 128 | " torch.nn.init.xavier_normal_(self.linear_layers[i].weight, gain=torch.sqrt(torch.tensor(0.25)))\n", 129 | " elif self.version == 'LayerNorm':\n", 130 | " torch.nn.init.xavier_normal_(self.linear_layers[i].weight, gain=torch.sqrt(torch.tensor(2.)))\n", 131 | " def forward(self, x):\n", 132 | " x = x.view(-1, 3*32*32)\n", 133 | " x = F.relu(self.linear_input(x))\n", 134 | " for i, layer in enumerate(self.linear_layers):\n", 135 | " if self.version == 'ReZero':\n", 136 | " x = x + self.resweight[i] * torch.relu(layer(x))\n", 137 | " elif self.version == 'Vanilla':\n", 138 | " x = F.relu(layer(x))\n", 139 | " elif self.version == 'Residual':\n", 140 | " x = x + F.relu(layer(x))\n", 141 | " elif self.version == 'LayerNorm':\n", 142 | " x = self.ln(F.relu(layer(x)))\n", 143 | " x = self.linear_output(x)\n", 144 | " return x\n", 145 | "\n", 146 | "def setup_and_train(epochs, lr, width, depth, version, plt_jacobian = True):\n", 147 | " ######################################################################\n", 148 | " # Model setup\n", 149 | " model = DeepNN(lr, width, depth, version)\n", 150 | " model.to(device);\n", 151 | " ######################################################################\n", 152 | " # Define criterion and optimizer\n", 153 | " criterion = torch.nn.CrossEntropyLoss()\n", 154 | " optimizer = torch.optim.Adagrad(model.parameters(), lr = lr)\n", 155 | " scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=1.0)\n", 156 | " \n", 157 | " ######################################################################\n", 158 | " # Train the model\n", 159 | " model.train()\n", 160 | " for epoch in range(epochs): # loop over the dataset multiple times\n", 161 | " epoch_start_time = time.time()\n", 162 | " running_loss = 0.0\n", 163 | " log_interval = 
100\n", 164 | " for batch, data in enumerate(trainloader, 0):\n", 165 | " # get the inputs; data is a list of [inputs, labels]\n", 166 | " inputs, labels = data[0].to(device), data[1].to(device)\n", 167 | "\n", 168 | " # zero the parameter gradients\n", 169 | " optimizer.zero_grad()\n", 170 | "\n", 171 | " # forward + backward + optimize\n", 172 | " outputs = model(inputs)\n", 173 | " loss = criterion(outputs, labels)\n", 174 | " loss.backward()\n", 175 | " optimizer.step()\n", 176 | " # print statistics\n", 177 | " running_loss += loss.item()\n", 178 | " cur_loss = running_loss / (batch+1)\n", 179 | " print('| end of epoch {:3d} | time / epoch {:5.2f}s | loss {:5.2f}'.format\n", 180 | " (epoch+1, (time.time() - epoch_start_time),cur_loss))\n", 181 | " running_loss = 0.\n", 182 | " \n", 183 | " if plt_jacobian == True:\n", 184 | " d_collected = list()\n", 185 | " u_collected = list()\n", 186 | " for i in range(100):\n", 187 | " src = torch.randn(3, 32, 32).to(device)\n", 188 | " J = get_jacobian(model,src)\n", 189 | " v, d, u = torch.svd(J.to('cpu'))\n", 190 | " d_collected.append(d.numpy().tolist())\n", 191 | " u_collected.append(u.numpy().tolist())\n", 192 | " d_ = np.asarray(d_collected).flatten()\n", 193 | " print('-' * 55)\n", 194 | " print('Mean sq singular value of io Jacobian:', \"%0.3f\" % np.mean(d_**2))\n", 195 | " fig, ax = plt.subplots()\n", 196 | " opacity=.7\n", 197 | " plt.ylim((0,1))\n", 198 | " plt.xlim((-7,4))\n", 199 | " ax.hist(np.log(d_)/np.log(10), bins = 10, alpha = opacity, label = 'NN model: ' + version,\n", 200 | " density = True)\n", 201 | " ax.legend(loc='upper left')\n", 202 | " ax.set_xlabel('log (io-Jacobian singular values)')\n", 203 | " plt.show()\n" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "### Vanilla network\n", 211 | "\n", 212 | "A vanilla network converges very slowly. There is a large spread in the singular value spectrum of the input-output Jacobian." 
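The ReZero update inside `DeepNN.forward` above can be isolated in a few lines. A minimal standalone sketch (illustrative sizes, not one of the notebook's cells; the notebook itself uses `depth = 64`) showing that a stack of ReZero blocks is exactly the identity map at initialization, because every residual branch is scaled by a `resweight` entry initialized to zero:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
width, depth = 256, 4  # small illustrative sizes

# One linear layer per block, plus one zero-initialized scalar per block.
layers = nn.ModuleList(nn.Linear(width, width) for _ in range(depth))
resweight = nn.Parameter(torch.zeros(depth))

x0 = torch.randn(8, width)
x = x0
for i, layer in enumerate(layers):
    # ReZero residual update: x_{i+1} = x_i + alpha_i * ReLU(W_i x_i)
    x = x + resweight[i] * torch.relu(layer(x))

# With resweight == 0 the whole stack is the identity,
# so signals propagate undistorted no matter how deep the network is.
assert torch.equal(x, x0)
```

During training the `resweight` scalars move away from zero, letting each block gradually contribute; this is what distinguishes ReZero from a plain residual connection, whose branch contributes at full strength from the first step.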
213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 3, 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "name": "stdout", 222 | "output_type": "stream", 223 | "text": [ 224 | "| end of epoch 1 | time / epoch 7.81s | loss 54693.80\n", 225 | "| end of epoch 2 | time / epoch 7.75s | loss 2.34\n", 226 | "| end of epoch 3 | time / epoch 7.80s | loss 2.30\n", 227 | "| end of epoch 4 | time / epoch 7.72s | loss 2.30\n", 228 | "| end of epoch 5 | time / epoch 7.52s | loss 2.30\n", 229 | "| end of epoch 6 | time / epoch 7.74s | loss 2.28\n", 230 | "| end of epoch 7 | time / epoch 7.71s | loss 2.21\n", 231 | "| end of epoch 8 | time / epoch 7.65s | loss 2.08\n", 232 | "| end of epoch 9 | time / epoch 7.84s | loss 2.05\n", 233 | "| end of epoch 10 | time / epoch 7.67s | loss 2.02\n", 234 | "-------------------------------------------------------\n", 235 | "Mean sq singular value of io Jacobian: 0.829\n" 236 | ] 237 | }, 238 | { 239 | "data": { 240 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEKCAYAAADpfBXhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAZP0lEQVR4nO3de5RU5Z3u8e+DIETFS8TkIBjFiRcYGkUaxGiUGXGCORk0UYR2EhUvjEYhLshFR4fDqHMJkhhdXhLUrPZEBjBGHUKY4yVKHB1RQAVRAoejRDsw8RYZUBFIfuePvbtTFNXd1VDdhW8/n7VY1n73u9/921X207v3rnpLEYGZmaWlS7ULMDOzynO4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klqNVwl/RjSW9KWtHMekm6RdIaScslHVf5Ms3MrC3KOXOvB0a1sP504Ij83wTgjl0vy8zMdkWr4R4RTwLvttDlDOB/R2YRsL+k3pUq0MzM2q5rBcboA7xRsNyQt60v7ihpAtnZPXvvvfeQo48+ugK7NzPrPJYuXfp2RBzUWr9KhLtKtJWc0yAiZgIzAWpra2PJkiUV2L2ZWech6Tfl9KvEu2UagEMKlvsC6yowrpmZ7aRKhPs84Lz8XTPDgQ0RscMlGTMz6zitXpaRNBsYAfSS1AD8L6AbQET8EFgAfBFYA3wAjG+vYs3MrDythntE1LWyPoDLK1HM1q1baWhoYPPmzZUYzj4GevToQd++fenWrVu1SzFLSiVuqFZMQ0MDPXv25LDDDkMqdZ/WUhIRvPPOOzQ0NNCvX79ql2OWlN1q+oHNmzdz4IEHOtg7CUkceOCB/kvNrB3sVuEOONg7Gb/eZu1jtwt3MzPbdbvVNfdiF9Uvruh4d18wtNU+kpg8eTLf+973AJgxYwabNm1i2rRpTJs2jenTp7N27Vo+9alPAbDPPvuwadOmitZZbMSIEcyYMYPa2tqd7rNw4U
Kuvvpqnnnmmaa2bdu20adPH1588UV6927bjBHr1q1j0qRJ3H///SxcuJAZM2Ywf/586uvrWbJkCbfeemubxjOzyvKZe5Hu3bvzwAMP8Pbbb5dc36tXr6bg/zg5+eSTaWhoYO3atU1tjz32GAMHDmxzsAMcfPDB3H///RWs0MwqyeFepGvXrkyYMIGbbrqp5PoLL7yQuXPn8u67Lc2llp3Rf+c732HIkCGMHDmS5557jhEjRnD44Yczb948ILuBPH78eGpqahg8eDBPPPEEAB9++CHjxo1j0KBBjB07lg8//LBp3EceeYQTTjiB4447jjFjxpT9V0OXLl0YM2YMc+fObWqbM2cOdXXZO13vvPNOhg4dyjHHHMNZZ53FBx98AMAFF1zApEmT+NznPsfhhx/eFOhr165l4MCBLe7z5z//OccffzyDBw9m5MiR/O53vyurVjPbdQ73Ei6//HJmzZrFhg0bdli3zz77cOGFF3LzzTe3OMb777/PiBEjWLp0KT179uTaa6/l0Ucf5cEHH2Tq1KkA3HbbbQC89NJLzJ49m/PPP5/Nmzdzxx13sNdee7F8+XKuueYali5dCsDbb7/NDTfcwGOPPcbzzz9PbW0t3//+93fY98UXX0ypeXvq6uqYM2cOAB999BELFizgrLPOAuArX/kKixcvZtmyZfTv35+77767abv169fz1FNPMX/+fK666qpynkIATjrpJBYtWsQLL7zAuHHjmD59etnbmtmu2a2vuVfLvvvuy3nnncctt9zCJz7xiR3WT5o0iWOPPZYpU6Y0O8aee+7JqFHZNPg1NTV0796dbt26UVNT03Rp5KmnnmLixIkAHH300Rx66KGsXr2aJ598kkmTJgEwaNAgBg0aBMCiRYt45ZVXOPHEEwHYsmULJ5xwwg77vuuuu0rWNHToUDZt2sSqVatYuXIlw4cP54ADDgBgxYoVXHvttbz33nts2rSJL3zhC03bnXnmmXTp0oUBAwa06ey7oaGBsWPHsn79erZs2eL3spt1IId7M6688kqOO+44xo/fcTaF/fffn3PPPZfbb7+92e27devW9Da/Ll260L1796bH27ZtA7IP8TSn1FsEI4LTTjuN2bNnt+lYCo0bN445c+awcuXKpksykF1+eeihhzjmmGOor69n4cKFTesaa2+t5mITJ05k8uTJjB49moULFzJt2rSdrtvM2saXZZrxyU9+knPOOWe7yxOFJk+ezI9+9KOmoN4ZJ598MrNmzQJg9erVvP766xx11FHbta9YsYLly5cDMHz4cJ5++mnWrFkDwAcffMDq1avbtM+6ujruvfdeHn/8cUaPHt3UvnHjRnr37s3WrVub9r2rNmzYQJ8+fQC45557KjKmmZVntz5zL+eti+1pypQpzb6lr1evXnz5y19u9sZrOb7+9a9z6aWXUlNTQ9euXamvr6d79+5cdtlljB8/nkGDBnHssccybNgwAA466CDq6+upq6vjo48+AuCGG27gyCOP3G7ciy++mEsvvbTk2yIHDBjAXnvtxZAhQ9h7772b2q+//nqOP/54Dj30UGpqati4ceNOH1ejadOmMWbMGPr06cPw4cN57bXXdnlMMyuP2vJndiWV+rKOlStX0r9//6rUY9Xj192sfJKWRkTzH3rJ+bKMmVmCHO5mZgna7cK9WpeJrDr8epu1j90q3Hv06ME777zjH/hOonE+9x49elS7FLPk7Fbvlunbty8NDQ289dZb1S7FOkjjNzGZWWXtVuHerVs3f4rRzKwCdqvLMmZmVhkOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0tQWeEuaZSkVZLWSLqqxPrPSHpC0guSlkv6YuVLNTOzcrUa7pL2AG4DTgcGAHWSBhR1uxa4LyIGA+OA2ytdqJmZla+cM/dhwJqIeDUitgBzgDOK+g
Swb/54P2Bd5Uo0M7O2Kifc+wBvFCw35G2FpgFfldQALAAmlhpI0gRJSyQt8Zdgm5m1n3LCXSXaomi5DqiPiL7AF4GfSNph7IiYGRG1EVF70EEHtb1aMzMrSznh3gAcUrDclx0vu1wE3AcQEc8APYBelSjQzMzarmsZfRYDR0jqB/yW7IbpuUV9XgdOBeol9ScL9xavu6x9530uql/c5oLvvmBom7cxM+tsWj1zj4htwBXAw8BKsnfFvCzpOkmj825TgEskLQNmAxdERPGlGzMz6yDlnLkTEQvIbpQWtk0tePwKcGJlSzMzs53lT6iamSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mlqCywl3SKEmrJK2RdFUzfc6R9IqklyX9a2XLNDOztujaWgdJewC3AacBDcBiSfMi4pWCPkcAVwMnRsTvJX2qvQo2M7PWlXPmPgxYExGvRsQWYA5wRlGfS4DbIuL3ABHxZmXLNDOztign3PsAbxQsN+RthY4EjpT0tKRFkkaVGkjSBElLJC3ZvPG9navYzMxa1eplGUAl2qLEOEcAI4C+wH9IGhgR2yV4RMwEZgL06te/eAwzM6uQcs7cG4BDCpb7AutK9Pm3iNgaEa8Bq8jC3szMqqCccF8MHCGpn6Q9gXHAvKI+DwF/ASCpF9llmlcrWaiZmZWv1XCPiG3AFcDDwErgvoh4WdJ1kkbn3R4G3pH0CvAE8K2IeKe9ijYzs5aVc82diFgALChqm1rwOIDJ+T8zM6syf0LVzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MElRWuEsaJWmVpDWSrmqh39mSQlJt5Uo0M7O2ajXcJe0B3AacDgwA6iQNKNGvJzAJeLbSRZqZWduUc+Y+DFgTEa9GxBZgDnBGiX7XA9OBzRWsz8zMdkI54d4HeKNguSFvayJpMHBIRMxvaSBJEyQtkbRk88b32lysmZmVp5xwV4m2aFopdQFuAqa0NlBEzIyI2oio7dFz//KrNDOzNikn3BuAQwqW+wLrCpZ7AgOBhZLWAsOBeb6pamZWPeWE+2LgCEn9JO0JjAPmNa6MiA0R0SsiDouIw4BFwOiIWNIuFZuZWataDfeI2AZcATwMrATui4iXJV0naXR7F2hmZm3XtZxOEbEAWFDUNrWZviN2vSwzM9sV/oSqmVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZgkqK9wljZK0StIaSVeVWD9Z0iuSlkv6paRDK1+qmZmVq9Vwl7QHcBtwOjAAqJM0oKjbC0BtRAwC7gemV7pQMzMrXzln7sOANRHxakRsAeYAZxR2iIgnIuKDfHER0LeyZZqZWVuUE+59gDcKlhvytuZcBPx7qRWSJkhaImnJ5o3vlV+lmZm1Sdcy+qhEW5
TsKH0VqAVOKbU+ImYCMwF69etfcgwzM9t15YR7A3BIwXJfYF1xJ0kjgWuAUyLio8qUZ2ZmO6OcyzKLgSMk9ZO0JzAOmFfYQdJg4EfA6Ih4s/JlmplZW7Qa7hGxDbgCeBhYCdwXES9Luk7S6LzbjcA+wE8lvShpXjPDmZlZByjnsgwRsQBYUNQ2teDxyArXZWZmu8CfUDUzS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQWVN+bs7uah+cYfu7+4Lhnbo/szMKuFjF+5m1ryOPvkBnwDtrnxZxswsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBJUVrhLGiVplaQ1kq4qsb67pLn5+mclHVbpQs3MrHythrukPYDbgNOBAUCdpAFF3S4Cfh8RnwVuAr5b6ULNzKx85Zy5DwPWRMSrEbEFmAOcUdTnDOCe/PH9wKmSVLkyzcysLRQRLXeQzgZGRcTF+fLXgOMj4oqCPivyPg358v/L+7xdNNYEYEK+eBSwqlIH0ka9gLdb7ZWOzna84GPuLDrjMR8VET1b69S1jIFKnYEX/0Yopw8RMROYWcY+25WkJRFRW+06OkpnO17wMXcWnfWYy+lXzmWZBuCQguW+wLrm+kjqCuwHvFtOAWZmVnnlhPti4AhJ/STtCYwD5hX1mQecnz8+G3g8WrveY2Zm7abVyzIRsU3SFcDDwB7AjyPiZUnXAUsiYh5wN/ATSWvIztjHtWfRFVD1S0MdrLMdL/iYOwsfczNavaFqZmYfP/6EqplZghzuZmYJ6rThLmliPqXCy5KmV7uejiLpm5JCUq9q19LeJN0o6deSlkt6UNL+1a6pvbQ2RUhqJB0i6QlJK/Of4W9Uu6aOIGkPSS9Imt9a304Z7pL+guxTtYMi4s+BGVUuqUNIOgQ4DXi92rV0kEeBgRExCFgNXF3letpFmVOEpGYbMCUi+gPDgcs7wTEDfANYWU7HThnuwGXAv0TERwAR8WaV6+koNwHfpsQHzFIUEY9ExLZ8cRHZZzRSVM4UIUmJiPUR8Xz+eCNZ4PWpblXtS1Jf4H8Cd5XTv7OG+5HA5/MZLH8laWi1C2pvkkYDv42IZdWupUouBP692kW0kz7AGwXLDSQedIXyWWgHA89Wt5J29wOyk7M/ltO5nOkHPpYkPQb8jxKrriE77gPI/pwbCtwn6fCP+wevWjnmvwP+qmMran8tHXNE/Fve5xqyP+NndWRtHais6T9SJGkf4GfAlRHx39Wup71I+hLwZkQslTSinG2SDfeIGNncOkmXAQ/kYf6cpD+STUD0VkfV1x6aO2ZJNUA/YFk+WWdf4HlJwyLivzqwxIpr6XUGkHQ+8CXg1I/7L+8WlDNFSHIkdSML9lkR8UC162lnJwKjJX0R6AHsK+neiPhqcxt0yg8xSboUODgipko6Evgl8JmEf/i3I2ktUFs8a2dqJI0Cvg+cEhEf61/cLcnnc1oNnAr8lmzKkHMj4uWqFtaO8inF7wHejYgrq11PR8rP3L8ZEV9qqV9nveb+Y+DwfKriOcD5nSXYO5lbgZ7Ao5JelPTDahfUHvKbxo1ThKwE7ks52HMnAl8D/jJ/bV/Mz2ot1ynP3M3MUtdZz9zNzJLmcDczS5DD3cwsQQ53M7MEOdzNzBLkcE+UpE0VHOsHkk7OH9/VlgmaJI0oZwa7NtazttSslpIulXReJfdVNH6bjr0N4y6U1GFf8ixpmqRvVnjMxyQdUMkxbdck+wlVqwxJnwSGN35QJCIurnJJzYqIdn0f++5y7JL2iIg/VLuOIj8Bvg78Y7ULsYzP3BOnzI2SVkh6SdLYvL2LpNvzubDnS1og6ewSQ5wN/J+C8Z
rOMiXV5WOukPTdMmoZJuk/8/mo/1PSUXn7HpJm5GMtlzQxbz817/uSpB9L6l4w3LckPZf/+2zev+mMVNIlkhZLWibpZ5L2ytvrJd2S7//VUscsaW9Jv8i3XVHwnBUe+yZJ/5j3WSTp03n7n+XLiyVd1/gXVPFfMJJulXRBiX3fIWlJ/rr8Q0H7WklTJT0FjClo3y9f1yVf3kvSG5K6NfccFO2v8Jh65Z9ebnxNbsy3Xy7pb/P23pKezD80tELS5/Oh5gF1zb/61tEc7un7CnAscAwwErhRUu+8/TCgBrgYOKGZ7U8ElhY3SjoY+C7wl/n4QyWd2UotvwZOjojBwFTgn/L2CWRz3wzO516fJakHUA+MjYgasr8yLysY678jYhjZp1B/UGJfD0TE0Ig4huxTmxcVrOsNnEQ258y/lNh2FLAuIo6JiIEU/HIrsDewKB//SeCSvP1m4OaIGMrOze9yTUTUAoOAUyQNKli3OSJOiog5jQ0RsQFYBpySN/018HBEbKXl56A1FwEb8uMYClwiqR9wbj5+4/9TL+Z1/B7oLunAnThmawcO9/SdBMyOiD9ExO+AX5H9sJ4E/DQi/phPHvZEM9v3pvSEakOBhRHxVv7x91nAya3Ush/wU2XTPtwE/HnePhL4YePc6xHxLnAU8FpErM773FM0/uyC/5b6xTRQ0n9Iegn4m4J9ATyUH/crwKdLbPsSMFLSdyV9Pg/QYluAxjPxpWS/KMlr+Wn++F9LbNeacyQ9D7yQ11x4jX9uM9vMBcbmj8cV9GvpOWjNXwHnSXqRbCrdA4EjyOatGS9pGlCTz6Xe6E3g4Dbsw9qRwz19paaDbam92Idks9CVtb2kL+tPc30U3yS8HngiPxv+64JxxY5T1LZWXzTzuFE9cEV+1v8PbH8MH7W0n/wXyhCykP9nSVNLjL+1YD6iP9D6/attbP/ztsNzmp8Zf5NsBstBwC+K+r3fzNjzgNOV3R8ZAjyet9fT/HNQqq7C9QImRsSx+b9++ZefPEn2S/a3wE+0/Q3sHmT/v9huwOGevieBsfk11IPIfjCfA54CzsqvvX8aGNHM9iuBz5Zof5bsskEvZV/zVgf8KiIeLAiEJUXb7EcWCgAXFLQ/AlyqbHbDxpu4vwYOa7yeTjZJ1K8Kthlb8N9nStTXE1ivbFrYv2nm2ErKLzl9EBH3kn0F43Ft2HwRcFb+eFxB+2+AAZK6S9qPbAbHYvuSBfiG/DU5vZwdRsQmstf0ZmB+wc3Wcp6DtWS/ECC7v9LoYeCyfFskHZnfiziUbF7xO4G7yZ8bSSKbV39tOTVb+/O7ZdL3INmlgmVkZ7jfjoj/kvQzsoBZQTZd7LNAqcsPvwD+lqKv9oqI9ZKuJrucI2BB45djFOnKn86UpwP3SJrMn84uycc+ElguaStwZ0TcKmk82WWcrmSXAwrfDdNd0rNkJyilbuT9fX5MvyE7A+9Zok9zasjuTfwR2Mr21/pbcyVwr6QpZM/dBoCIeEPSfcBy4P+SXXbZTkQsk/QC8DLwKvB0G/Y7l+xy0IiCtnKegxlkX1bzNXZ8TQ4jm/dfZJfmzszH/1b+Om0CGs/ch5Ddg9iG7RY8K2QnJmmfiNiU3wR7Djix1Jd35O/Q+FJEvLcT+/gG0Ccivr3rFe/+8nekfBgRIWkcUBcRSX+fKYCkm4F5EfHLatdiGZ+5d27zJe0P7Alc38K3Mk0BPgO0Kdwl3Q0MBM7ZpSo/XoYAt+Znu++RfXdrZ7DCwb578Zm7mVmCfEPVzCxBDnczswQ53M3MEuRwNzNLkMPdzCxB/x8jWCSiiDq5pgAAAABJRU5ErkJggg==\n", 241 | "text/plain": [ 242 | "
" 243 | ] 244 | }, 245 | "metadata": { 246 | "needs_background": "light" 247 | }, 248 | "output_type": "display_data" 249 | } 250 | ], 251 | "source": [ 252 | "######################################################################\n", 253 | "# The model is set up with the hyperparameter below.\n", 254 | "\n", 255 | "version = 'Vanilla' # Architecture\n", 256 | "epochs = 10 # Number of epochs\n", 257 | "depth = 64 # Number of layers\n", 258 | "width = 256 # Width\n", 259 | "lr = 0.01 # Learning rate for Adagrad optimizer\n", 260 | "\n", 261 | "setup_and_train(epochs, lr, width, depth, version, plt_jacobian = True)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "### Residual network\n", 269 | "\n", 270 | "A network with a residual connection converges a bit better. There still is a large spread in the singular value spectrum of the input-output Jacobian." 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 4, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "name": "stdout", 280 | "output_type": "stream", 281 | "text": [ 282 | "| end of epoch 1 | time / epoch 8.61s | loss 131269.63\n", 283 | "| end of epoch 2 | time / epoch 9.10s | loss 26.92\n", 284 | "| end of epoch 3 | time / epoch 8.70s | loss 11.76\n", 285 | "| end of epoch 4 | time / epoch 8.47s | loss 4.55\n", 286 | "| end of epoch 5 | time / epoch 8.81s | loss 2.32\n", 287 | "| end of epoch 6 | time / epoch 9.79s | loss 2.10\n", 288 | "| end of epoch 7 | time / epoch 8.93s | loss 2.00\n", 289 | "| end of epoch 8 | time / epoch 8.42s | loss 1.95\n", 290 | "| end of epoch 9 | time / epoch 8.65s | loss 1.91\n", 291 | "| end of epoch 10 | time / epoch 8.41s | loss 1.87\n", 292 | "-------------------------------------------------------\n", 293 | "Mean sq singular value of io Jacobian: 225.593\n" 294 | ] 295 | }, 296 | { 297 | "data": { 298 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAEKCAYAAADpfBXhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAaC0lEQVR4nO3de5RU5Z3u8e8jIERBMdDxGFHBCd6G5tqALrww4yWYODjBqOBkFLygeIsK5nI0BNE5J1ESovEWURcmYVCZoLIUg5pRGXVQGgVEUYIMkY6ZI6IyEkUg/s4ftekUTXV3NVR34dvPZy0Wtfd+996/Xd399NvvrnpLEYGZmaVlt3IXYGZmpedwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLUKPhLuleSe9KWlbPdkm6RdJKSUsl9S99mWZm1hTF9NynA8Ma2H4y0DP7Nxa4Y+fLMjOzndFouEfEfOD9BpqcCvwychYAnSXtV6oCzcys6dqW4Bj7A2vylmuydX+q21DSWHK9e/bcc88Bhx12WAlOb9a6rV7354Lru3fZs4UrsZawaNGi9yKiorF2pQh3FVhXcE6DiLgLuAugqqoqqqurS3B6s9btvOkLC66/Z/TAFq7EWoKkPxTTrhSvlqkBDshb7ga8U4LjmpnZDipFuM8Bzs5eNXMksD4ithuSMTOzltPosIykmcBQoKukGuCHQDuAiLgTmAt8DVgJfAyMaa5izcysOI2Ge0SMamR7AJeUopjNmzdTU1PDxo0bS3E4S0CHDh3o1q0b7dq1K3cpZp8rpbihWjI1NTV06tSJ7t27IxW6T2utSUSwbt06ampq6NGjR7nLMftc2aWmH9i4cSNdunRxsBsAkujSpYv/kjPbAbtUzx1wsNs2/P2w4/wSydZtl+q5m5lZaexyPfd89fU8dlQxPRZJXHXVVfzkJz8BYMqUKWzYsIFJkyYxadIkbrzxRlavXs2XvvQlADp27MiGDRtKWmddQ4cOZcqUKVRVVe1Um0mTJjFt2jQqKirYtGkTP/jBDxg1qsH75QVVV1fzy1/+kltuuWW7bd27d6e6upquXbs2+biTJk2iY8eOTJgwocn7mtm23HOvo3379syePZv33nuv4PauXbvWBv/n0ZVXXsnixYt55JFHuPDCC9m8eXOTj1FVVVUw2M1s1+Fwr6Nt27aMHTuWqVOnFtx+7rnn8sADD/D++w3NpZbr0X/3u99lwIABnHDCCbz00ksMHTqUgw8+mDlz5gC5G8hjxoyhsrKSfv368fTTTwPwySefMHLkSHr37s2ZZ57JJ598UnvcJ554gqOOOor+/ftz+umn7/BfDT179mSPPfbggw8+AOCtt95i2LBhDBgwgGOOOYY33ngDgFmzZtGrVy/69OnDscceC8AzzzzDKaecAsC6des46aST6NevHxdeeCG5V8bC6tWr6dWrV+35pkyZwqRJkwCYNm0aAwcOpE+fPpx22ml8/PHHO3QNZlY/h3sBl1xyCTNmzGD9+vXbbevYsSPnnnsuN998c4PH+POf/8zQoUNZtGgRnTp14tprr+XJJ5/koYceYuLEiQDcdtttALz66qvMnDmTc845h40bN3LHHXewxx57sHTpUq655hoWLVoEwHvvvccNN9zAU089xcsvv0xVVRU//elPtzv3+eefT2Pz9rz88sv07Nmzdnhp7Nix/PznP2fRokVMmTKFiy++GIDJkyczb948lixZUvtLKd91113H0UcfzSuvvMLw4cN5++23GzwvwIgRI1i4cCFLlizh8MMP55577ml0HzNrml16zL1c9tprL84++2xuueUWvvCFL2y3/fLLL6dv376MHz++3mPsvvvuDBuWmwa/srKS9u3b065dOyorK1m9ejUAzz33HJdddhkAhx12GAcddBArVqxg/vz5XH755QD07t2b3r17A7BgwQJef/11hgwZAsCmTZs46qijtjv33XffXW9dU6dOZdq
0aaxatYrf/va3AGzYsIEXXniB008/vbbdp59+CsCQIUMYPXo0Z5xxBiNGjNjuePPnz2f27NkAfP3rX2efffap99xbLVu2jGuvvZYPP/yQDRs28NWvfrXRfcysaRzu9bjiiivo378/Y8ZsP5tC586dOeuss7j99tvr3b9du3a1L+PbbbfdaN++fe3jLVu2ANQOYRRS6CWAEcGJJ57IzJkzm3Qt+a688komTJjA7NmzOfvss3nrrbf47LPP6Ny5M4sXL96u/Z133smLL77IY489Rt++fQu2KVRr27Zt+eyzz2qX81+rPnr0aB5++GH69OnD9OnTeeaZZ3b4esysMA/L1OOLX/wiZ5xxRr1DBldddRW/+MUvaoN6Rxx77LHMmDEDgBUrVvD2229z6KGHbrN+2bJlLF26FIAjjzyS559/npUrVwLw8ccfs2LFih0694gRI6iqquK+++5jr732okePHsyaNQvI/RJZsmQJkBuLHzx4MJMnT6Zr166sWbNmm+Pk1/r444/XjuHvu+++vPvuu6xbt45PP/2URx99tHafjz76iP3224/NmzfX7mtmpbVL99zL/WaL8ePHc+uttxbc1rVrV77xjW/Ue+O1GBdffDEXXXQRlZWVtG3blunTp9O+fXvGjRvHmDFj6N27N3379mXQoEEAVFRUMH36dEaNGlU7bHLDDTdwyCGHbHPc888/n4suuqjBl0UCTJw4kbPOOosLLriAGTNmMG7cOG644QY2b97MyJEj6dOnD1dffTW///3viQiOP/54+vTpw7PPPlt7jB/+8IeMGjWK/v37c9xxx3HggQcCub9cJk6cyODBg+nRowf5H8xy/fXXM3jwYA466CAqKyv56KOPdvg5bC1K/bJgS58aGhpoToU+rGP58uUcfvjhZanHdl3+vihtuJe702Q7R9KiiGi454aHZczMkuRwNzNL0C4X7uUaJrJdk78fzHbMLhXuHTp0YN26df6BNuCv87l36NCh3KWYfe7sUq+W6datGzU1Naxdu7bcpdguYusnMZlZ0+xS4d6uXTt/4o6ZWQnsUsMyZmZWGg53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS1BR4S5pmKQ3Ja2U9L0C2w+U9LSkVyQtlfS10pdqZmbFajTcJbUBbgNOBo4ARkk6ok6za4EHI6IfMBK4vdSFmplZ8YrpuQ8CVkbEqojYBNwPnFqnTQB7ZY/3Bt4pXYlmZtZUxYT7/sCavOWabF2+ScC3JNUAc4HLCh1I0lhJ1ZKq/SHYZmbNp5hwV4F1UWd5FDA9IroBXwN+JWm7Y0fEXRFRFRFVFRUVTa/WzMyKUky41wAH5C13Y/thl/OABwEi4j+BDkDXUhRoZmZNV0y4LwR6SuohaXdyN0zn1GnzNnA8gKTDyYW7x13MzMqk0XCPiC3ApcA8YDm5V8W8JmmypOFZs/HABZKWADOB0RFRd+jGzMxaSNtiGkXEXHI3SvPXTcx7/DowpLSlmZnZjvI7VM3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLUFHhLmmYpDclrZT0vXranCHpdUmvSfrX0pZpZmZN0baxBpLaALcBJwI1wEJJcyLi9bw2PYHvA0Mi4gNJX2qugs3MrHHF9NwHASsjYlVEbALuB06t0+YC4LaI+AAgIt4tbZlmZtYUxYT7/sCavOWabF2+Q4BDJD0vaYGkYYUOJGmspGpJ1WvXrt2xis3MrFH
FhLsKrIs6y22BnsBQYBRwt6TO2+0UcVdEVEVEVUVFRVNrNTOzIhUT7jXAAXnL3YB3CrR5JCI2R8R/AW+SC3szMyuDYsJ9IdBTUg9JuwMjgTl12jwM/B2ApK7khmlWlbJQMzMrXqPhHhFbgEuBecBy4MGIeE3SZEnDs2bzgHWSXgeeBq6OiHXNVbSZmTWs0ZdCAkTEXGBunXUT8x4HcFX2z8zMyszvUDUzS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEtS13AWZmAOdNX1jvtntGD2zBStLgnruZWYIc7mZmCXK4m5klqKgxd0nDgJuBNsDdEfGjetp9E5gFDIyI6pJVaWYl47Ht1qHRnrukNsBtwMnAEcAoSUcUaNcJuBx4sdRFmplZ0xQzLDMIWBkRqyJiE3A/cGqBdtcDNwIbS1ifmZntgGLCfX9gTd5yTbaulqR+wAER8WhDB5I0VlK1pOq1a9c2uVgzMytOMeGuAuuidqO0GzAVGN/YgSLiroioioiqioqK4qs0M7MmKSbca4AD8pa7Ae/kLXcCegHPSFoNHAnMkVRVqiLNzKxpign3hUBPST0k7Q6MBOZs3RgR6yOia0R0j4juwAJguF8tY2ZWPo2Ge0RsAS4F5gHLgQcj4jVJkyUNb+4Czcys6Yp6nXtEzAXm1lk3sZ62Q3e+LDMz2xl+h6qZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCSoq3CUNk/SmpJWSvldg+1WSXpe0VNLvJB1U+lLNzKxYjYa7pDbAbcDJwBHAKElH1Gn2ClAVEb2BfwNuLHWhZmZWvGJ67oOAlRGxKiI2AfcDp+Y3iIinI+LjbHEB0K20ZZqZWVMUE+77A2vylmuydfU5D3i80AZJYyVVS6peu3Zt8VWamVmTFBPuKrAuCjaUvgVUATcV2h4Rd0VEVURUVVRUFF+lmZk1Sdsi2tQAB+QtdwPeqdtI0gnANcBxEfFpacozM7MdUUzPfSHQU1IPSbsDI4E5+Q0k9QN+AQyPiHdLX6aZmTVFo+EeEVuAS4F5wHLgwYh4TdJkScOzZjcBHYFZkhZLmlPP4czMrAUUMyxDRMwF5tZZNzHv8QklrsvMzHaC36FqZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCWpb7gLMrHU5b/rCcpfQKrjnbmaWIIe7mVmCPCxjZrXqGzK5Z/TAFq7EdpZ77mZmCXK4m5klyOFuZpYgh7uZWYKKCndJwyS9KWmlpO8V2N5e0gPZ9hcldS91oWZmVrxGw11SG+A24GTgCGCUpCPqNDsP+CAivgJMBX5c6kLNzKx4xfTcBwErI2JVRGwC7gdOrdPmVOC+7PG/AcdLUunKNDOzpijmde77A2vylmuAwfW1iYgtktYDXYD38htJGguMzRY3SHpzR4ouga7UqS1xre16wddcUveOaY6jluT8rfHrfGgxjYoJ90I98NiBNkTEXcB
(base64 PNG plot data omitted)\n", 299 | "text/plain": [ 300 | "
" 301 | ] 302 | }, 303 | "metadata": { 304 | "needs_background": "light" 305 | }, 306 | "output_type": "display_data" 307 | } 308 | ], 309 | "source": [ 310 | "######################################################################\n", 311 | "# The model is set up with the hyperparameter below.\n", 312 | "\n", 313 | "version = 'Residual' # Architecture\n", 314 | "epochs = 10 # Number of epochs\n", 315 | "depth = 64 # Number of layers\n", 316 | "width = 256 # Width\n", 317 | "lr = 0.01 # Learning rate for Adagrad optimizer\n", 318 | "\n", 319 | "setup_and_train(epochs, lr, width, depth, version, plt_jacobian = True)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "### LayerNorm network\n", 327 | "\n", 328 | "A LayerNorm network converges very slowly." 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 5, 334 | "metadata": {}, 335 | "outputs": [ 336 | { 337 | "name": "stdout", 338 | "output_type": "stream", 339 | "text": [ 340 | "| end of epoch 1 | time / epoch 9.97s | loss 2.23\n", 341 | "| end of epoch 2 | time / epoch 10.18s | loss 2.14\n", 342 | "| end of epoch 3 | time / epoch 9.88s | loss 2.14\n", 343 | "| end of epoch 4 | time / epoch 10.14s | loss 2.11\n", 344 | "| end of epoch 5 | time / epoch 10.80s | loss 2.09\n", 345 | "| end of epoch 6 | time / epoch 10.61s | loss 2.06\n", 346 | "| end of epoch 7 | time / epoch 10.23s | loss 2.04\n", 347 | "| end of epoch 8 | time / epoch 10.40s | loss 2.02\n", 348 | "| end of epoch 9 | time / epoch 9.90s | loss 2.01\n", 349 | "| end of epoch 10 | time / epoch 10.12s | loss 2.02\n", 350 | "-------------------------------------------------------\n", 351 | "Mean sq singular value of io Jacobian: 12.412\n" 352 | ] 353 | }, 354 | { 355 | "data": { 356 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAEKCAYAAADpfBXhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAaGklEQVR4nO3dfZRU1Z3u8e8jIIigqJBcI0YxERXtDkqDsDTKiLLQ8S3XiOBkVBRR40sUMDFXw2UMk4xoohBRg2LwKr7F6AzL6PUlSrw6oDaKgKIEHZSOmQtiwlwiKOjv/lGnO0VT3V0N1V3N7uezFos65+w69TvV3U/v3qfOPooIzMwsLTuVuwAzMys9h7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYKaDHdJd0taLWlpA9slabqkFZIWSzqi9GWamVlzFNNznw2MaGT7icCB2b9xwO3bX5aZmW2PJsM9Il4APm6kyWnA/4qcBUAPSXuXqkAzM2u+jiXYxz7Aqrzlmmzdn+o3lDSOXO+eXXfddcDBBx9cgpc3M2s/Fi5c+FFE9GqqXSnCXQXWFZzTICJmAjMBqqqqorq6ugQvb2bWfkh6v5h2pfi0TA2wb95yb+DDEuzXzMy2USnCfS5wTvapmcHAuojYakjGzMxaT5PDMpIeAIYCPSXVAP8T6AQQEXcATwAnASuAT4AxLVWsmZkVp8lwj4jRTWwP4NJSFLNp0yZqamrYuHFjKXZn7VSXLl3o3bs3nTp1KncpZmVTihOqJVNTU0P37t3Zf//9kQqdpzVrXESwdu1aampq6NOnT7nLMSubNjX9wMaNG9lrr70c7LbNJLHXXnv5rz9r99pUuAMOdttu/h4ya4PhbmZm269NjbnXd8HsV0u6v1nnDWyyjSTGjx/Pz372MwBuuukm1q9fz+TJk5k8eTJTp05l5cqVfOlLXwKgW7durF+/vqR11jd06FBuuukmqqqqtqvN5MmT6datGxMnTmyJMgvWtH79emovVquurmbixInMmzevVV7frD1zz72ezp078+ijj/LRRx8V3N6zZ8+64LeGff755wCsXr2aJ598cpv2sXnz5lKWZNauONzr6dixI+PGjePmm28uuP3888/noYce4uOPG5tLLdej/8EPfsCAAQM4/vjjeeWVVxg6dCgHHHAAc+fOBXInkMeMGUNFRQWHH344zz//PAAbNmxg1KhRVFZWctZZZ7Fhw4a6/T799NMMGTKEI444gjPPPLMkfzWcfvrpDBgwgEMPPZSZM2cCMGvWLK666qq6NnfeeSfjx48H4L777mPQoEH079+fiy66qC7Iu3XrxqRJkzjyyCOZP38+AFdffTVTpkzZ6jUbOvbZs2dz5plncsoppzB8+HDmzZvHsccey8iRI+nbty/XXHMNc+bMYdCgQVRUVPDuu+9u9/GbpcjhXsCll17KnDlzWLdu3VbbunXrxvnnn8+0adMa3cdf//pXhg4dysKFC+nevTvXXXcdzzzzDI899hiTJk0CYMaMGQAsWbKEBx54gHPPPZeNGzdy++2307VrVxYvXsy1117LwoULAfjoo4+YMmUKzz77LK+99hpVVVX8/Oc/3+q1x44dS3Pm7bn77rtZuHAh1dXVTJ8+nbVr1zJq1Cjmzp3Lpk2bAPjVr37FmDFjWLZsGQ899BAvvfQSixYtokOHDsyZM6fumA877DBefvlljj76aACGDBlC586d68K7VkPHDjB//nzuuecennvuOQDeeOMNpk2bxpIlS7j33ntZvnw5r7zyCmPHjuUXv/hF0cdp1p606TH3ctltt90455xzmD59OrvssstW26+44gr69+/PhAkTGtzHzjvvzIgRuWnwKyoq6Ny5M506daKiooKVK1cC8OKLL3L55ZcDcPDBB7PffvuxfPlyXnjhBa644goAKisrqaysBGDBggW89dZbHHXUUQB89tlnDBkyZKvXvuuuu5p1vNOnT+exxx4
DYNWqVfzhD39g8ODBHHfccTz++OMccsghbNq0iYqKCm699VYWLlzIwIG58xcbNmyoO//QoUMHzjjjjK32f9111zFlyhRuuOGGunUNHTvACSecwJ577lnXduDAgey9d24W6a997WsMHz687n2t/0vDzHIc7g248sorOeKIIxgzZuvZFHr06MHZZ5/Nbbfd1uDzO3XqVPeRvJ122onOnTvXPa4dS85d3FtYoY/zRQQnnHACDzzwQLOOpTHz5s3j2WefZf78+XTt2pWhQ4fW9aDHjh3LT37yEw4++OC69yEiOPfcc/npT3+61b66dOlChw4dtlp/3HHH8aMf/YgFCxZscSwN2XXXXbdYrn3voOH30sy25GGZBuy5556MHDmSWbNmFdw+fvx4fvnLX25XuBxzzDF1QxrLly/ngw8+4KCDDtpi/dKlS1m8eDEAgwcP5qWXXmLFihUAfPLJJ3W93W21bt069thjD7p27crbb7+9RQAfeeSRrFq1ivvvv5/Ro3OzUAwbNoxHHnmE1atXA/Dxxx/z/vtNz0B67bXXMnXq1CaP3cxKo0333Iv56GJLmjBhArfeemvBbT179uRb3/pWgydei/Hd736Xiy++mIqKCjp27Mjs2bPp3Lkzl1xyCWPGjKGyspL+/fszaNAgAHr16sXs2bMZPXo0n376KQBTpkyhb9++W+x37NixXHzxxQU/FjllyhRuueWWuuV3332XO+64g8rKSg466CAGDx68RfuRI0eyaNEi9thjDwD69evHlClTGD58OF988QWdOnVixowZ7Lfffo0e60knnUSvXn+7v0BDx25mpaHG/jxuSYVu1rFs2TIOOeSQstRjhZ188slcddVVDBs2rNylNIu/lyxVkhZGRMMXtGQ8LGMF/eUvf6Fv377ssssuO1ywm1kbH5ax8unRo8d2j+ebWfm0uZ57uYaJLB3+HjJrY+HepUsX1q5d6x9O22a187l36dKl3KWYlVWbGpbp3bs3NTU1rFmzptyl2A6s9k5MZu1Zmwr3Tp06+e45ZmYl0KaGZczMrDQc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpagosJd0ghJ70haIemaAtu/Kul5Sa9LWizppNKXamZmxWoy3CV1AGYAJwL9gNGS+tVrdh3wcEQcDowCbit1oWZmVrxieu6DgBUR8V5EfAY8CJxWr00Au2WPdwc+LF2JZmbWXMWE+z7AqrzlmmxdvsnAdyTVAE8AlxfakaRxkqolVfsm2GZmLaeYcFeBdVFveTQwOyJ6AycB90raat8RMTMiqiKiqlevXs2v1szMilJMuNcA++Yt92brYZcLgIcBImI+0AXoWYoCzcys+YoJ91eBAyX1kbQzuROmc+u1+QAYBiDpEHLh7nEXM7MyaTLcI2IzcBnwFLCM3Kdi3pR0vaRTs2YTgAslvQE8AJwXEfWHbszMrJV0LKZRRDxB7kRp/rpJeY/fAo4qbWlmZratfIWqmVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZgkqKtwljZD0jqQVkq5poM1ISW9JelPS/aUt08zMmqNjUw0kdQBmACcANcCrkuZGxFt5bQ4EfggcFRF/lvSllirYzMyaVkzPfRCwIiLei4jPgAeB0+q1uRCYERF/BoiI1aUt08zMmqOYcN8HWJW3XJOty9cX6CvpJUkLJI0otCNJ4yRVS6pes2bNtlVsZmZNKibcVWBd1FvuCBwIDAV
GA3dJ6rHVkyJmRkRVRFT16tWrubWamVmRign3GmDfvOXewIcF2vxbRGyKiP8A3iEX9mZmVgbFhPurwIGS+kjaGRgFzK3X5l+BvwOQ1JPcMM17pSzUzMyK12S4R8Rm4DLgKWAZ8HBEvCnpekmnZs2eAtZKegt4Hrg6Ita2VNFmZtY4RdQfPm8dVVVVUV1dXZbXNjPbUUlaGBFVTbXzFapmZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWoKLCXdIISe9IWiHpmkbafVtSSKoqXYlmZtZcTYa7pA7ADOBEoB8wWlK/Au26A1cAL5e6SDMza55ieu6DgBUR8V5EfAY8CJxWoN2PganAxhLWZ2Zm26CYcN8HWJW3XJOtqyPpcGDfiHi8sR1JGiepWlL1mjVrml2smZkVp5hwV4F1UbdR2gm4GZjQ1I4iYmZEVEVEVa9evYqv0szMmqWYcK8B9s1b7g18mLfcHTgMmCdpJTAYmOuTqmZm5VNMuL8KHCipj6SdgVHA3NqNEbEuInpGxP4RsT+wADg1IqpbpGIzM2tSk+EeEZuBy4CngGXAwxHxpqTrJZ3a0gWamVnzdSymUUQ8ATxRb92kBtoO3f6yzMxse/gKVTOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBBU1K6S1LxfMfrWk+5t13sCS7s/Mmuaeu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXIn3O3FufPzZu1PvfczcwS5HA3M0uQw93MLEEec09Aqce0zWzH5567mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWoKLCXdIISe9IWiHpmgLbx0t6S9JiSb+TtF/pSzUzs2I1Ge6SOgAzgBOBfsBoSf3qNXsdqIqISuARYGqpCzUzs+IV03MfBKyIiPci4jPgQeC0/AYR8XxEfJItLgB6l7ZMMzNrjmLCfR9gVd5yTbauIRcATxbaIGmcpGpJ1WvWrCm+SjMza5Ziwl0F1kXBhtJ3gCrgxkLbI2JmRFRFRFWvXr2Kr9LMzJqlmFkha4B985Z7Ax/WbyTpeOBa4NiI+LQ05ZmZ2bYopuf+KnCgpD6SdgZGAXPzG0g6HPglcGpErC59mWZm1hxNhntEbAYuA54ClgEPR8Sbkq6XdGrW7EagG/BrSYskzW1gd2Zm1gqKullHRDwBPFFv3aS8x8eXuC4zM9sOvkLVzCxBDnczswQ53M3MEuQbZNsOp9Q3BJ913sCS7s+sLXDP3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkz7mbtTO+TqB9cM/dzCxBDnczswQ53M3MEuRwNzNLkMPdzCxB/rSMtXv+9IilyOFu1saV+pePtQ8O91bmH1Qzaw0eczczS5B77mYl5r/OrC1wz93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQL2Iys+3iidfaJvfczcwS5HA3M0uQw93MLEEOdzOzBBUV7pJGSHpH0gpJ1xTY3lnSQ9n2lyXtX+pCzcyseE2Gu6QOwAzgRKAfMFpSv3rNLgD+HBFfB24Gbih1oWZmVrxieu6DgBUR8V5EfAY8CJxWr81pwD3Z40eAYZJUujLNzKw5ivmc+z7AqrzlGuDIhtpExGZJ64C9gI/yG0kaB4zLFtdLemdbii6BntSrLXHt7XjBx7zDuntMs5oncczNdFA
xjYoJ90I98NiGNkTETGBmEa/ZoiRVR0RVuetoLe3teMHH3F6012Mupl0xwzI1wL55y72BDxtqI6kjsDvwcTEFmJlZ6RUT7q8CB0rqI2lnYBQwt16bucC52eNvA89FxFY9dzMzax1NDstkY+iXAU8BHYC7I+JNSdcD1RExF5gF3CtpBbke+6iWLLoEyj401Mra2/GCj7m98DE3QO5gm5mlx1eompklyOFuZpagdhvuki7PplR4U9LUctfTWiRNlBSSepa7lpYm6UZJb0taLOkxST3KXVNLaWqKkNRI2lfS85KWZT/D3yt3Ta1BUgdJr0t6vKm27TLcJf0duatqKyPiUOCmMpfUKiTtC5wAfFDuWlrJM8BhEVEJLAd+WOZ6WkSRU4SkZjMwISIOAQYDl7aDYwb4HrCsmIbtMtyBS4B/iYhPASJidZnraS03A9+nwAVmKYqIpyNic7a4gNw1GikqZoqQpETEnyLitezx/yMXePuUt6qWJak38PfAXcW0b6/h3hf4ZjaD5e8lJX9fL0mnAn+MiDfKXUuZnA88We4iWkihKUKSDrp82Sy0hwMvl7eSFncLuc7ZF8U0TvYeqpKeBf5bgU3XkjvuPcj9OTcQeFjSATv6hVdNHPP/AIa3bkUtr7Fjjoh/y9pcS+7P+DmtWVsrKmr6jxRJ6gb8BrgyIv6r3PW0FEknA6sjYqGkocU8J9lwj4jjG9om6RLg0SzMX5H0BbkJiNa0Vn0toaFjllQB9AHeyCbr7A28JmlQRPxnK5ZYco19nQEknQucDAzb0X95N6KYKUKSI6kTuWCfExGPlrueFnYUcKqkk4AuwG6S7ouI7zT0hHZ5EZOki4GvRMQkSX2B3wFfTfiHfwuSVgJVEZH0bHqSRgA/B46NiB36F3djsvmclgPDgD+SmzLk7Ih4s6yFtaBsSvF7gI8j4spy19Oasp77xIg4ubF27XXM/W7gAElLyZ18Ore9BHs7cyvQHXhG0iJJd5S7oJaQnTSunSJkGfBwysGeOQr4R+C47Gu7KOvVWqZd9tzNzFLXXnvuZmZJc7ibmSXI4W5mliCHu5lZghzuZmYJcrgnStL6Eu7rFknHZI/vas4ETZKGFjODXTPrWVloVktJF0s6p5SvVW//zTr2Zux3nqRWu8mzpMmSJpZ4n89K2qOU+7Ttk+wVqlYakvYEBtdeKBIRY8tcUoMiokU/x95Wjl1Sh4j4vNx11HMv8F3gn8tdiOW455445dwoaamkJZLOytbvJOm2bC7sxyU9IenbBXbxbeB/5+2vrpcpaXS2z6WSbiiilkGS/j2bj/rfJR2Ure8g6aZsX4slXZ6tH5a1XSLpbkmd83Z3taRXsn9fz9rX9UglXSjpVUlvSPqNpK7Z+tmSpmev/16hY5a0q6TfZs9dmvee5R/7ekn/nLVZIOnL2fqvZcuvSrq+9i+o+n/BSLpV0nkFXvt2SdXZ1+Wf8tavlDRJ0ovAmXnrd8+27ZQtd5W0SlKnht6Deq+Xf0w9s6uXa78mN2bPXyzpomz93pJeyC4aWirpm9mu5gKjG/7qW2tzuKfvvwP9gW8AxwM3Sto7W78/UAGMBYY08PyjgIX1V0r6CnADcFy2/4GSTm+ilreBYyLicGAS8JNs/Thyc98cns29PkdSF2A2cFZEVJD7K/OSvH39V0QMIncV6i0FXuvRiBgYEd8gd9XmBXnb9gaOJjfnzL8UeO4I4MOI+EZEHEbeL7c8uwILsv2/AFyYrZ8GTIuIgWzb/C7XRkQVUAkcK6kyb9vGiDg6Ih6sXRER64A3gGOzVacAT0XEJhp/D5pyAbAuO46BwIWS+gBnZ/uv/Z5alNXxZ6CzpL224ZitBTjc03c08EBEfB4R/xf4Pbkf1qOBX0fEF9nkYc838Py9KTyh2kBgXkSsyS5/nwMc00QtuwO/Vm7ah5uBQ7P1xwN31M69HhEfAwcB/xERy7M299Tb/wN5/xf6xXSYpP8jaQnwD3mvBfCv2XG/BXy
5wHOXAMdLukHSN7MAre8zoLYnvpDcL0qyWn6dPb6/wPOaMlLSa8DrWc35Y/wPNfCch4Czssej8to19h40ZThwjqRF5KbS3Qs4kNy8NWMkTQYqsrnUa60GvtKM17AW5HBPX6HpYBtbX98GcrPQFfV8Sd/S3+b6qH+S8MfA81lv+JS8/Yqtp6htqr5o4HGt2cBlWa//n9jyGD5t7HWyXygDyIX8TyVNKrD/TXnzEX1O0+evNrPlz9tW72nWM55IbgbLSuC39dr9tYF9zwVOVO78yADguWz9bBp+DwrVlb9dwOUR0T/71ye7+ckL5H7J/hG4V1uewO5C7vvF2gCHe/peAM7KxlB7kfvBfAV4ETgjG3v/MjC0gecvA75eYP3L5IYNeip3m7fRwO8j4rG8QKiu95zdyYUCwHl5658GLlZudsPak7hvA/vXjqeTmyTq93nPOSvv//kF6usO/Em5aWH/oYFjKygbcvokIu4jdwvGI5rx9AXAGdnjUXnr3wf6SeosaXdyMzjWtxu5AF+XfU1OLOYFI2I9ua/pNODxvJOtxbwHK8n9QoDc+ZVaTwGXZM9FUt/sXMR+5OYVvxOYRfbeSBK5efVXFlOztTx/WiZ9j5EbKniDXA/3+xHxn5J+Qy5glpKbLvZloNDww2+Bi6h3a6+I+JOkH5IbzhHwRO3NMerpyN96ylOBeySN52+9S7J99wUWS9oE3BkRt0oaQ24YpyO54YD8T8N0lvQyuQ5KoRN5P8qO6X1yPfDuBdo0pILcuYkvgE1sOdbflCuB+yRNIPferQOIiFWSHgYWA38gN+yyhYh4Q9LrwJvAe8BLzXjdh8gNBw3NW1fMe3ATuZvV/CNbf032Jzfvv8gNzZ2e7f/q7Ou0HqjtuQ8gdw5iM9YmeFbIdkxSt4hYn50EewU4qtDNO7JPaJwcEX/Zhtf4HrBPRHx/+ytu+7JPpGyIiJA0ChgdEUnfzxRA0jRgbkT8rty1WI577u3b45J6ADsDP27krkwTgK8CzQp3SbOAw4CR21XljmUAcGvW2/0LuXu3tgdLHexti3vuZmYJ8glVM7MEOdzNzBLkcDczS5DD3cwsQQ53M7ME/X8Mho0j/tEeuwAAAABJRU5ErkJggg==\n", 357 | "text/plain": [ 358 | "
" 359 | ] 360 | }, 361 | "metadata": { 362 | "needs_background": "light" 363 | }, 364 | "output_type": "display_data" 365 | } 366 | ], 367 | "source": [ 368 | "######################################################################\n", 369 | "# The model is set up with the hyperparameter below.\n", 370 | "\n", 371 | "version = 'LayerNorm' # Architecture\n", 372 | "epochs = 10 # Number of epochs\n", 373 | "depth = 64 # Number of layers\n", 374 | "width = 256 # Width\n", 375 | "lr = 0.01 # Learning rate for Adagrad optimizer\n", 376 | "\n", 377 | "setup_and_train(epochs, lr, width, depth, version, plt_jacobian = True)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": {}, 383 | "source": [ 384 | "### ReZero network\n", 385 | "\n", 386 | "A deep neural network with ReZero connection quickly converges." 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 6, 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "name": "stdout", 396 | "output_type": "stream", 397 | "text": [ 398 | "| end of epoch 1 | time / epoch 11.33s | loss 1.68\n", 399 | "| end of epoch 2 | time / epoch 11.57s | loss 1.37\n", 400 | "| end of epoch 3 | time / epoch 10.92s | loss 1.24\n", 401 | "| end of epoch 4 | time / epoch 11.80s | loss 1.14\n", 402 | "| end of epoch 5 | time / epoch 11.87s | loss 1.04\n", 403 | "| end of epoch 6 | time / epoch 11.54s | loss 0.94\n", 404 | "| end of epoch 7 | time / epoch 10.90s | loss 0.84\n", 405 | "| end of epoch 8 | time / epoch 10.96s | loss 0.73\n", 406 | "| end of epoch 9 | time / epoch 10.92s | loss 0.63\n", 407 | "| end of epoch 10 | time / epoch 11.24s | loss 0.52\n", 408 | "-------------------------------------------------------\n", 409 | "Mean sq singular value of io Jacobian: 18.082\n" 410 | ] 411 | }, 412 | { 413 | "data": { 414 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAEKCAYAAADpfBXhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAZjElEQVR4nO3dfZRU1Z3u8e8jIAQFMUByjaDgBF8YaUEaBpfRoOIEvRk1maDgzKD4guK7YKK5GsIYJpkohmh8SVRccBMGX0YdWYYJagbjVQcFFBFFucgQbc1ERIe5qAjE3/2jDm1RVHefhuou2P181mJZ55x9Tv1Olf307n3q7FJEYGZmadmj2gWYmVnlOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLUZLhLukfSu5KWN7Bdkm6RtErSMklHVr5MMzNrjjw995nAyEa2nwT0y/6NB+7Y+bLMzGxnNBnuEfEU8H4jTU4F/ncULAS6SdqvUgWamVnzta/AMfYH3iparsvW/aG0oaTxFHr37LXXXoMPPfTQCjy9meW1Zt2Hudr16b5XC1diO2rJkiXvRUTPptpVItxVZl3ZOQ0i4k7gToDa2tpYvHhxBZ7ezPI6d+aiXO1mnD2khSuxHSXp93naVeLTMnVA76LlXsA7FTiumZntoEqE+1xgbPapmWHA+ojYbkjGzMxaT5PDMpLmAMOBHpLqgO8DHQAi4ufAPOBkYBXwETCupYo1M7N8mgz3iBjTxPYALq5EMZs3b6auro6NGzdW4nC2G+nUqRO9evWiQ4cO1S7FLAmVuKBaMXV1dXTp0oU+ffoglbtOaymKCNatW0ddXR19+/atdjlmSdilph/YuHEj3bt3d7C3MZLo3r27/2Izq6BdKtwBB3sb5ffdrLJ2uXA3M7Odt0uNuZfKe8NFXnluzJDExIkTuemmmwCYNm0aGzZsYMqUKUyZMoUbbriBNWvW8IUvfAGAvffemw0bNlS0zlLDhw9n2rRp1NbW7lSbKVOmcNddd9GzZ082bdrE9773PcaMafR6OV/72tf44x//WL/87rvv0rt3b5577rnmn4jtNkp/9nxT0+7HPfcSHTt25KGHHuK9994ru71Hjx71wb87uvLKK1m6dCmPPPIIF1xwAZs3b260/fz581m6dClLly7lmWeeoWvXrkydOjX3823ZsmVnSzazHeBwL9G+fXvGjx/P9OnTy24/55xzuO+++3j//cbmUiv06K+++moGDx7MiBEjeP755xk+fDgHHXQQc+fOBQoXkMeNG8eAAQMYNGgQCxYsAODjjz9m9OjR1NTUcMYZZ/Dxxx/XH/exxx7jqKOO4sgjj2TUqFE7/FdDv3796Ny5Mx988AEAb7zxBiNHjmTw4MEcc8wxvPbaa9vtc/nll3PyySdz4oknNrrP2WefzcSJEznuuOO4+uqref/99znttNOoqalh2LBhLFu2bIdqNrP8HO5lXHzxxcyePZv169dvt23vvffmnHPO4eabb270GB9++CHDhw9nyZIldOnSheuuu47HH3+chx9+mMmTJwNw2223AfDyyy8zZ84czjrrLDZu3Mgdd9xB586dWbZsGddeey1LliwB4L333mPq1Kk88cQTvPDCC9TW1vKTn/xku+c+77zzaGrenhdeeIF+/frVDy+NHz+en/3sZyxZsoRp06Zx0UUXbdP+4YcfZvHixfzoRz+qX9fYPitXruSJJ57gpptu4vvf/z6DBg1i2bJl/PCHP2Ts2LGN1mZmO2+XHnOvlq5duzJ27FhuueUWPve5z223/bLLLmPgwIFMmjSpwWPsueeejBxZmAZ/wIABdOzYkQ4dOjBgwADWrFkDwNNPP82ll14KwKGHHsqBBx7IypUreeqpp7jssssAqKmpoaamBoCFCxfy6quvcvTRRwOwadMmjjrqqO2e++67726wrunTp3PXXXexevVqfvOb3wCwYcMGnn32WUaNGlXf7pNPPql
//Pbbb3PZZZcxf/58OnbsmGufUaNG0a5du/rzfPDBBwE4/vjjWbduHevXr2efffZpsE4z2zkO9wZcccUVHHnkkYwbt/1sCt26dePMM8/k9ttvb3D/Dh061H+8b4899qgPxT322KN+HLpwc2955T4aGBGceOKJzJkzp1nnUuzKK6/kqquu4qGHHmLs2LG88cYbfPrpp3Tr1o2lS5eWfc6zzjqLa665hv79+9evb2wfgL32+mzK2HLn6Y8+mrUsD8s04POf/zynn346M2bMKLt94sSJ/OIXv9ipC4bHHnsss2fPBgrDGG+++SaHHHLINuuXL19eP0Y9bNgwnnnmGVatWgXARx99xMqVK3foub/5zW9SW1vLrFmz6Nq1K3379uWBBx4ACmH80ksvAYVPC3Xq1ImLL952honG9mnsPJ988kl69OhB165dd6huM8tnl+65V/vjV5MmTeLWW28tu61Hjx584xvfaPDCax4XXXQRF154IQMGDKB9+/bMnDmTjh07MmHCBMaNG0dNTQ0DBw5k6NChAPTs2ZOZM2cyZsyY+iGQqVOncvDBB29z3PPOO48LL7yw0Y9FAkyePJkzzzyT888/n9mzZzNhwgSmTp3K5s2bGT16NEcccQTXXXcdvXr1YuDAgfX77bvvvixYsKDBfUpNmTKl/nw6d+7MrFmzdvg1M7N81NjQQEsq92UdK1as4LDDDqtKPVZ9fv9b3o7eO1LtjpZ9RtKSiGi854aHZczMkuRwNzNL0C4X7tUaJrLq8vtuVlm7VLh36tSJdevW+Qe9jdk6n3unTp2qXYpZMnapT8v06tWLuro61q5dW+1SrJVt/SYmM6uMXSrcO3To4G/iMTOrgF1qWMbMzCrD4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJyhXukkZKel3SKknXlNl+gKQFkl6UtEzSyZUv1czM8moy3CW1A24DTgL6A2Mk9S9pdh1wf0QMAkYDt1e6UDMzyy9Pz30osCoiVkfEJuBe4NSSNgF0zR7vA7xTuRLNzKy58nyH6v7AW0XLdcBflLSZAjwm6VJgL2BEuQNJGg+MBzjggAOaW6uZlXHuzEXbLM84e0iVKrFdSZ6eu8qsi5LlMcDMiOgFnAz8UtJ2x46IOyOiNiJqe/bs2fxqzcwslzzhXgf0LlruxfbDLucC9wNExL8DnYAelSjQzMyaL0+4LwL6SeoraU8KF0znlrR5EzgBQNJhFMJ9bSULNTOz/JoM94jYAlwCzAdWUPhUzCuSrpd0StZsEnC+pJeAOcDZEVE6dGNmZq0kzwVVImIeMK9k3eSix68CR1e2NDPbVRRftPUF292D71A1M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEG5wl3SSEmvS1ol6ZoG2pwu6VVJr0j6p8qWaWZmzdG+qQaS2gG3AScCdcAiSXMj4tWiNv2A7wJHR8QHkr7QUgWbWXWdO3PRNsszzh5SpUqsMXl67kOBVRGxOiI2AfcCp5a0OR+4LSI+AIiIdytbppmZNUeecN8feKtouS5bV+xg4GBJz0haKGlkuQNJGi9psaTFa9eu3bGKzcysSXnCXWXWRclye6AfMBwYA9wtqdt2O0XcGRG1EVHbs2fP5tZqZmY55Qn3OqB30XIv4J0ybR6JiM0R8R/A6xTC3szMqiBPuC8C+knqK2lPYDQwt6TNvwDHAUjqQWGYZnUlCzUzs/yaDPeI2AJcAsw
HVgD3R8Qrkq6XdErWbD6wTtKrwALg2xGxrqWKNjOzxjX5UUiAiJgHzCtZN7nocQATs39mZlZlvkPVzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEpRrbhkz2z2VfiWetR3uuZuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCcoW7pJGSXpe0StI1jbT7lqSQVFu5Es3MrLmaDHdJ7YDbgJOA/sAYSf3LtOsCXAY8V+kizcysefL03IcCqyJidURsAu4FTi3T7gfADcDGCtZnZmY7IE+47w+8VbRcl62rJ2kQ0DsiHm3sQJLGS1osafHatWubXayZmeWTJ9xVZl3Ub5T2AKYDk5o6UETcGRG1EVHbs2fP/FWamVmz5An3OqB30XIv4J2i5S7A4cCTktYAw4C5vqhqZlY9ecJ9EdBPUl9JewKjgblbN0bE+ojoERF9IqIPsBA4JSIWt0jFZmbWpCbDPSK2AJcA84EVwP0R8Yqk6yWd0tIFmplZ87XP0ygi5gHzStZNbqDt8J0vy8zMdobvUDUzS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQbnCXdJISa9LWiXpmjLbJ0p6VdIySb+VdGDlSzUzs7yaDHdJ7YDbgJOA/sAYSf1Lmr0I1EZEDfDPwA2VLtTMzPLL03MfCqyKiNURsQm4Fzi1uEFELIiIj7LFhUCvypZpZmbNkSfc9wfeKlquy9Y15FzgX8ttkDRe0mJJi9euXZu/SjMza5Y84a4y66JsQ+lvgVrgxnLbI+LOiKiNiNqePXvmr9LMzJqlfY42dUDvouVewDuljSSNAK4FvhoRn1SmPDMz2xF5eu6LgH6S+kraExgNzC1uIGkQ8AvglIh4t/JlmplZczQZ7hGxBbgEmA+sAO6PiFckXS/plKzZjcDewAOSlkqa28DhzMysFeQZliEi5gHzStZNLno8osJ1mZnZTvAdqmZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klKNf0A2ZmDTl35qL6xzPOHlLFSqyYe+5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCPHGYWWKKJ/Kytss9dzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0tQrnCXNFLS65JWSbqmzPaOku7Ltj8nqU+lCzUzs/yaDHdJ7YDbgJOA/sAYSf1Lmp0LfBARXwamAz+udKFmZpZfnp77UGBVRKyOiE3AvcCpJW1OBWZlj/8ZOEGSKlemmZk1R54pf/cH3ipargP+oqE2EbFF0nqgO/BecSNJ44Hx2eIGSa/vSNEV0IOS2hLX1s4XfM5Vcc+4Vn/Kqp9zFRySp1GecC/XA48daENE3AncmeM5W5SkxRFRW+06WktbO1/wObcVbfWc87TLMyxTB/QuWu4FvNNQG0ntgX2A9/MUYGZmlZcn3BcB/ST1lbQnMBqYW9JmLnBW9vhbwL9FxHY9dzMzax1NDstkY+iXAPOBdsA9EfGKpOuBxRExF5gB/FLSKgo99tEtWXQFVH1oqJW1tfM
Fn3Nb4XNugNzBNjNLj+9QNTNLkMPdzCxBbTbcJV2aTanwiqQbql1Pa5F0laSQ1KPatbQ0STdKek3SMkkPS+pW7ZpaSlNThKRGUm9JCyStyH6GL692Ta1BUjtJL0p6tKm2bTLcJR1H4a7amoj4c2BalUtqFZJ6AycCb1a7llbyOHB4RNQAK4HvVrmeFpFzipDUbAEmRcRhwDDg4jZwzgCXAyvyNGyT4Q5MAP4xIj4BiIh3q1xPa5kOfIcyN5ilKCIei4gt2eJCCvdopCjPFCFJiYg/RMQL2eP/RyHw9q9uVS1LUi/gfwJ352nfVsP9YOCYbAbL30kaUu2CWpqkU4C3I+KlatdSJecA/1rtIlpIuSlCkg66YtkstIOA56pbSYv7KYXO2ad5GueZfmC3JOkJ4H+U2XQthfPel8Kfc0OA+yUdtLvfeNXEOf8v4C9bt6KW19g5R8QjWZtrKfwZP7s1a2tFuab/SJGkvYEHgSsi4r+rXU9LkfR14N2IWCJpeJ59kg33iBjR0DZJE4CHsjB/XtKnFCYgWtta9bWEhs5Z0gCgL/BSNllnL+AFSUMj4j9bscSKa+x9BpB0FvB14ITd/Zd3I/JMEZIcSR0oBPvsiHio2vW0sKOBUySdDHQCukr6VUT8bUM7tMmbmCRdCHwpIiZLOhj4LXBAwj/825C0BqiNiKRn05M0EvgJ8NWI2K1/cTcmm89pJXAC8DaFKUPOjIhXqlpYC8qmFJ8FvB8RV1S7ntaU9dyvioivN9aurY653wMcJGk5hYtPZ7WVYG9jbgW6AI9LWirp59UuqCVkF423ThGyArg/5WDPHA38HXB89t4uzXq1lmmTPXczs9S11Z67mVnSHO5mZglyuJuZJcjhbmaWIIe7mVmCHO6JkrShgsf6qaRjs8d3N2eCJknD88xg18x61pSb1VLShZLGVvK5So7frHNvxnGflNRqX/IsaYqkqyp8zCck7VvJY9rOSfYOVasMSZ8Hhm29USQizqtySQ2KiBb9HPuucu6S2kXEn6pdR4lfAhcB/1DtQqzAPffEqeBGScslvSzpjGz9HpJuz+bCflTSPEnfKnOIbwG/KTpefS9T0pjsmMsl/ThHLUMlPZvNR/2spEOy9e0kTcuOtUzSpdn6E7K2L0u6R1LHosN9W9Lz2b8vZ+3re6SSzpe0SNJLkh6U1DlbP1PSLdnzry53zpL2kvTrbN/lRa9Z8blvkPQPWZuFkr6Yrf+zbHmRpOu3/gVV+heMpFslnV3mue+QtDh7X/6+aP0aSZMlPQ2MKlq/T7Ztj2y5s6S3JHVo6DUoeb7ic+qR3b289T25Mdt/maQLsvX7SXoqu2louaRjskPNBcY0/O5ba3O4p++bwEDgCGAEcKOk/bL1fYABwHnAUQ3sfzSwpHSlpC8BPwaOz44/RNJpTdTyGnBsRAwCJgM/zNaPpzD3zaBs7vXZkjoBM4EzImIAhb8yJxQd678jYiiFu1B/Wua5HoqIIRFxBIW7Ns8t2rYf8BUKc878Y5l9RwLvRMQREXE4Rb/ciuwFLMyO/xRwfrb+ZuDmiBjCjs3vcm1E1AI1wFcl1RRt2xgRX4mIe7euiIj1wEvAV7NVfwXMj4jNNP4aNOVcYH12HkOA8yX1Bc7Mjr/1/6mlWR0fAB0ldd+Bc7YW4HBP31eAORHxp4j4I/A7Cj+sXwEeiIhPs8nDFjSw/36Un1BtCPBkRKzNbn+fDRzbRC37AA+oMO3DdODPs/UjgJ9vnXs9It4HDgH+IyJWZm1mlRx/TtF/y/1iOlzS/5H0MvA3Rc8F8C/Zeb8KfLHMvi8DIyT9WNIxWYCW2gRs7YkvofCLkqyWB7LH/1Rmv6acLukF4MWs5uIx/vsa2Oc+4Izs8eiido29Bk35S2CspKUUptLtDvSjMG/NOElTgAHZXOpbvQt8qRnPYS3I4Z6+ctPBNra+1McUZqHLtb+kb+izuT5KLxL+AFiQ9Yb/qui4YvspapuqLxp4vNVM4JK
s1//3bHsOnzT2PNkvlMEUQv5HkiaXOf7movmI/kTT16+2sO3P23avadYzvorCDJY1wK9L2n3YwLHnAiepcH1kMPBv2fqZNPwalKureLuASyNiYPavb/blJ09R+CX7NvBLbXsBuxOF/19sF+BwT99TwBnZGGpPCj+YzwNPA3+djb1/ERjewP4rgC+XWf8chWGDHip8zdsY4HcR8XBRICwu2WcfCqEAcHbR+seAC1WY3XDrRdzXgD5bx9MpTBL1u6J9zij677+Xqa8L8AcVpoX9mwbOraxsyOmjiPgVha9gPLIZuy8E/jp7PLpo/e+B/pI6StqHwgyOpbpSCPD12XtyUp4njIgNFN7Tm4FHiy625nkN1lD4hQCF6ytbzQcmZPsi6eDsWsSBFOYVvwuYQfbaSBKFefXX5KnZWp4/LZO+hykMFbxEoYf7nYj4T0kPUgiY5RSmi30OKDf88GvgAkq+2isi/iDpuxSGcwTM2/rlGCXa81lP+QZglqSJfNa7JDv2wcAySZuBuyLiVknjKAzjtKcwHFD8aZiOkp6j0EEpdyHve9k5/Z5CD7xLmTYNGUDh2sSnwGa2HetvyhXAryRNovDarQeIiLck3Q8sA/4vhWGXbUTES5JeBF4BVgPPNON576MwHDS8aF2e12AahS+r+Tu2f0/6UJj3XxSG5k7Ljv/t7H3aAGztuQ+mcA1iC7ZL8KyQbZikvSNiQ3YR7Hng6HJf3pF9QuPrEfFfO/AclwP7R8R3dr7iXV/2iZSPIyIkjQbGRETS32cKIOlmYG5E/LbatViBe+5t26OSugF7Aj9o5FuZJgEHAM0Kd0kzgMOB03eqyt3LYODWrLf7XxS+u7UtWO5g37W4525mliBfUDUzS5DD3cwsQQ53M7MEOdzNzBLkcDczS9D/B7Y9W+kSKqHTAAAAAElFTkSuQmCC\n", 415 | "text/plain": [ 416 | "
" 417 | ] 418 | }, 419 | "metadata": { 420 | "needs_background": "light" 421 | }, 422 | "output_type": "display_data" 423 | } 424 | ], 425 | "source": [ 426 | "######################################################################\n", 427 | "# The model is set up with the hyperparameter below.\n", 428 | "\n", 429 | "version = 'ReZero' # Architecture\n", 430 | "epochs = 10 # Number of epochs\n", 431 | "depth = 64 # Number of layers\n", 432 | "width = 256 # Width\n", 433 | "lr = 0.01 # Learning rate for Adagrad optimizer\n", 434 | "\n", 435 | "setup_and_train(epochs, lr, width, depth, version, plt_jacobian = True)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": {}, 441 | "source": [ 442 | "### 10000 layer ReZero network\n", 443 | "\n", 444 | "We can train a 10000 layer neural network with ReZero. This takes several hours, so consider your actions." 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": null, 450 | "metadata": {}, 451 | "outputs": [], 452 | "source": [ 453 | "######################################################################\n", 454 | "# The model is set up with the hyperparameter below.\n", 455 | "\n", 456 | "version = 'ReZero' # Architecture\n", 457 | "epochs = 25 # Number of epochs\n", 458 | "depth = 10000 # Number of layers\n", 459 | "width = 150 # Width\n", 460 | "lr = 0.003 # Learning rate for Adagrad optimizer\n", 461 | "\n", 462 | "setup_and_train(epochs, lr, width, depth, version, plt_jacobian = False)" 463 | ] 464 | }, 465 | { 466 | "cell_type": "code", 467 | "execution_count": null, 468 | "metadata": {}, 469 | "outputs": [], 470 | "source": [] 471 | } 472 | ], 473 | "metadata": { 474 | "kernelspec": { 475 | "display_name": "Python 3", 476 | "language": "python", 477 | "name": "python3" 478 | }, 479 | "language_info": { 480 | "codemirror_mode": { 481 | "name": "ipython", 482 | "version": 3 483 | }, 484 | "file_extension": ".py", 485 | "mimetype": "text/x-python", 486 | "name": "python", 
487 | "nbconvert_exporter": "python", 488 | "pygments_lexer": "ipython3", 489 | "version": "3.7.6" 490 | } 491 | }, 492 | "nbformat": 4, 493 | "nbformat_minor": 4 494 | } 495 | -------------------------------------------------------------------------------- /ReZero-Deep_Fast_Transformer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Training 128 layer ReZero Transformer on WikiText-2 language modeling\n", 8 | "\n", 9 | "In this notebook we will examine how the [ReZero](https://arxiv.org/abs/2003.04887) architecture addition enables or accelerates training in deep [Transformer](https://arxiv.org/pdf/1706.03762.pdf) networks or fully connected networks.\n", 10 | "\n", 11 | "The official ReZero repo is [here](https://github.com/majumderb/rezero). Although it is not required for this notebook, you can install ReZero for PyTorch Transformers via `pip install rezero`.\n", 12 | "\n", 13 | "Running time of the notebook: 7 minutes on laptop with single RTX 2060 GPU (+ 21 minutes for training 128 layer transformer at the end)." 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "As a simple illustration of the power of ReZero we will now train a 128 layer ReZero-Transformer network on a sequence-to-sequence model, and explore why this is not possible with the vanilla Transformer architecture. This discussion is based to the basic PyTorch tutorial [Sequence-to-Sequence Modeling with nn.Transformer and TorchText\n", 21 | "](https://pytorch.org/tutorials/beginner/transformer_tutorial.html).\n", 22 | "\n", 23 | "\n", 24 | "Before we build the model, let us define the `ReZeroEncoderLayer` following the default PyTorch implementation, but adding a residual weight (default = 0) that by default initializes this layer as the identity map. 
\n", 25 | "\n", 26 | "We also add arguments for pre-residual and post-residual LayerNorm applications. Using post-norm and seting the residual weight to 1 reproduces the standard encoder layer. We will explore the impact of various configurations on signal propagation." 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "######################################################################\n", 36 | "# Import and set manual seed\n", 37 | "\n", 38 | "\n", 39 | "import numpy as np\n", 40 | "import time\n", 41 | "import math\n", 42 | "import matplotlib.pyplot as plt\n", 43 | "import torch\n", 44 | "from torch.nn import functional as F\n", 45 | "from torch.nn.modules.module import Module\n", 46 | "from torch.nn.modules.activation import MultiheadAttention\n", 47 | "from torch.nn.modules.container import ModuleList\n", 48 | "from torch.nn.modules.dropout import Dropout\n", 49 | "from torch.nn.modules.linear import Linear\n", 50 | "from torch.nn.modules.normalization import LayerNorm\n", 51 | "from torch.nn.init import xavier_uniform_\n", 52 | "\n", 53 | "torch.manual_seed(0)\n", 54 | "\n", 55 | "######################################################################\n", 56 | "# Define the ReZero Transformer\n", 57 | "\n", 58 | "\n", 59 | "class ReZeroEncoderLayer(Module):\n", 60 | " r\"\"\"ReZero-TransformerEncoderLayer is made up of self-attn and feedforward network.\n", 61 | "\n", 62 | " Args:\n", 63 | " d_model: the number of expected features in the input (required).\n", 64 | " nhead: the number of heads in the multiheadattention models (required).\n", 65 | " dim_feedforward: the dimension of the feedforward network model (default=2048).\n", 66 | " dropout: the dropout value (default=0.1).\n", 67 | " activation: the activation function of intermediate layer, relu or gelu (default=relu).\n", 68 | " use_LayerNorm: using either no LayerNorm (dafault=False), or use LayerNorm \"pre\", 
or \"post\"\n", 69 | "\n", 70 | " Examples::\n", 71 | " >>> encoder_layer = ReZeroEncoderLayer(d_model=512, nhead=8)\n", 72 | " >>> src = torch.rand(10, 32, 512)\n", 73 | " >>> out = encoder_layer(src)\n", 74 | " \"\"\"\n", 75 | "\n", 76 | " def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation = \"relu\", \n", 77 | " use_LayerNorm = False, init_resweight = 0, resweight_trainable = True):\n", 78 | " super(ReZeroEncoderLayer, self).__init__()\n", 79 | " self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)\n", 80 | " \n", 81 | " # Define the Resisdual Weight for ReZero\n", 82 | " self.resweight = torch.nn.Parameter(torch.Tensor([init_resweight]), requires_grad = resweight_trainable)\n", 83 | "\n", 84 | " # Implementation of Feedforward model\n", 85 | " self.linear1 = Linear(d_model, dim_feedforward)\n", 86 | " self.dropout = Dropout(dropout)\n", 87 | " self.linear2 = Linear(dim_feedforward, d_model)\n", 88 | " self.use_LayerNorm = use_LayerNorm\n", 89 | " if self.use_LayerNorm != False:\n", 90 | " self.norm1 = LayerNorm(d_model)\n", 91 | " self.norm2 = LayerNorm(d_model)\n", 92 | " self.dropout1 = Dropout(dropout)\n", 93 | " self.dropout2 = Dropout(dropout)\n", 94 | " if activation == \"relu\":\n", 95 | " self.activation = F.relu\n", 96 | " elif activation == \"gelu\":\n", 97 | " self.activation = F.gelu\n", 98 | " elif activation == \"tanh\":\n", 99 | " self.activation = torch.tanh\n", 100 | "\n", 101 | " def __setstate__(self, state):\n", 102 | " if 'activation' not in state:\n", 103 | " state['activation'] = F.relu\n", 104 | " super(ReZeroEncoderLayer, self).__setstate__(state)\n", 105 | "\n", 106 | " def forward(self, src, src_mask=None, src_key_padding_mask=None):\n", 107 | " # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor\n", 108 | " r\"\"\"Pass the input through the encoder layer.\n", 109 | " Args:\n", 110 | " src: the sequence to the encoder layer (required).\n", 111 | " src_mask: the mask for the src 
sequence (optional).\n", 112 | " src_key_padding_mask: the mask for the src keys per batch (optional).\n", 113 | " Shape:\n", 114 | " see the docs in Transformer class.\n", 115 | " \"\"\"\n", 116 | " src2 = src\n", 117 | " if self.use_LayerNorm == \"pre\":\n", 118 | " src2 = self.norm1(src2)\n", 119 | " src2 = self.self_attn(src2, src2, src2, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]\n", 120 | " # Apply the residual weight to the residual connection. This enables ReZero.\n", 121 | " src2 = self.resweight * src2\n", 122 | " src2 = self.dropout1(src2)\n", 123 | " if self.use_LayerNorm == False:\n", 124 | " src = src + src2\n", 125 | " elif self.use_LayerNorm == \"pre\":\n", 126 | " src = src + src2\n", 127 | " elif self.use_LayerNorm == \"post\":\n", 128 | " src = self.norm1(src + src2)\n", 129 | " src2 = src\n", 130 | " if self.use_LayerNorm == \"pre\":\n", 131 | " src2 = self.norm2(src2)\n", 132 | " src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))\n", 133 | " src2 = self.resweight * src2\n", 134 | " src2 = self.dropout2(src2)\n", 135 | " if self.use_LayerNorm == False:\n", 136 | " src = src + src2\n", 137 | " elif self.use_LayerNorm == \"pre\":\n", 138 | " src = src + src2\n", 139 | " elif self.use_LayerNorm == \"post\":\n", 140 | " src = self.norm2(src + src2)\n", 141 | " return src\n" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "## Signal propagation\n", 149 | "\n", 150 | "Let us pause and examine the signal propagation properties in a toy deep network `DeepEncoder` by evaluating the singular values of the Transformer input-output Jacobian `io_jacobian_TF`.\n", 151 | "\n", 152 | "The entries of the input-output Jacobian matrix reflect the change of each output with respect to each input. 
The singular value decomposition of this matrix shows by how much in magnitude (the singular value) an input signal (the corresponding singular vector) changes as it propagates through the network; see the [Wikipedia page](https://en.wikipedia.org/wiki/Singular_value_decomposition). A vanishing singular value means that the corresponding singular vector is mapped to zero (poor signal propagation), while a large singular value means that the corresponding singular vector is amplified in magnitude (chaotic signal propagation). These properties make the singular value decomposition a useful tool for studying signal propagation in neural networks. As we will see in this notebook, singular values close to unity (i.e. dynamical isometry) often coincide with strong training performance.\n", 153 | "\n", 154 | "We will compare both the pre-norm and the vanilla (post-norm) variants of the Transformer with the ReZero proposal, which eliminates LayerNorm. Since ReZero initializes each layer to perform the identity map by setting all residual weights to zero, we instead set all residual weights to 0.1 here, in order to see a non-trivial distribution of the input-output Jacobian singular values. We define `plot_jacobians` to plot the singular value distributions for each architecture."
155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 2, 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "name": "stdout", 164 | "output_type": "stream", 165 | "text": [ 166 | "Jacobian mean squared singular values for 12 layer Transformers:\n", 167 | "0.810 Vanilla Transformer (post norm)\n", 168 | "10.242 Transformer with pre-norm\n", 169 | "1.059 ReZero Transformer (resweight = 0.1)\n" 170 | ] 171 | }, 172 | { 173 | "data": { 174 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEKCAYAAADpfBXhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAbxklEQVR4nO3de5RU5Z3u8e9Dg5IRIyMyOY6okES8torcdHkBj1HUY2RIjIAmGmNi1IBxEhN1nBg1ayYXyYkmkskxiQtFRYi3YSlzdEZF4wzKRVsB8YJKpNUEgoY5JDGh8Xf+2JtO0VR37YbqruqX57MWi9p7v3vXr6q7n3773bverYjAzMzS0qvWBZiZWfU53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MElQx3CXdImmNpGXtbJekH0paKel5SUdUv0wzM+uMIj33GcDJHWw/Bdgv/3cB8C/bX5aZmW2PiuEeEU8A73TQZDxwW2SeAvpL2rNaBZqZWef1rsIx9gJWlyw35+vebttQ0gVkvXt22WWX4QcccEAVnt7MamXVut+3u23wgF26sZIdx5IlS34bEQMrtatGuKvMurJzGkTEzcDNACNGjIjFixdX4enNrFbOn7Go3W0//+zIbqxkxyHpV0XaVeNqmWZg75LlQcBbVTiumZlto2qE+1zgnPyqmSOB9RGx1ZCMmZl1n4rDMpJmAWOBPSQ1A98E+gBExE+AecCpwErgD8B5XVWsmZkVUzHcI2Jyhe0BfKkaxWzcuJHm5mbee++9ahzOCujbty+DBg2iT58+tS7FzKqoGidUq6a5uZldd92VwYMHI5U7T2vVFBGsW7eO5uZmhgwZUutyzKyK6mr6gffee48BAwY42LuJJAYMGOC/lMwSVFfhDjjYu5nfb7M01V24m5nZ9qurMfe2OvqAxLYo8qGKhoYGGhsbaWlpYciQIcycOZP+/fu32/6+++7j2muv3WLd888/z4MPPsgpp5yy3TWbmW0L99zb+MAHPkBTUxPLli1j9913Z/r06R22nzBhAk1NTa3/Lr74Yo499ljGjRtX6Pkigvfff78apZuZtXK4d+Coo47izTffbF2+/vrrGTlyJIceeijf/OY3t2r/8ssvc9111zFz5kx69erV7j6rVq3iwAMP5OKLL+aII45g9erVzJo1i8bGRg455BAuv/zy7nmBZpYsh3s7Nm3axCOPPMLpp58OwMMPP8wrr7zCwoULaWpqYsmSJTzxxBOt7Tdu3MhZZ53FtGnT2GeffSru89JLL3HOOefw7LPP0qdPHy6//HIeffRRmpqaWLRoEffff3/3v2gzS4bDvY0//vGPHH744QwYMIB33nmHE088EciC+uGHH2bYsGEcccQRvPjii7zyyiut+33jG9/g4IMPZtKkSa3rOtpn33335cgjjwRg0aJFjB07loEDB9K7d2/OPvvsLX5xmJl1Vl2fUK2FzWPu69ev57TTTmP69OlccsklRARXXnklX/ziF
7faZ/78+dxzzz0888wzW6xvb59Vq1axyy67bNHOzKya3HNvx2677cYPf/hDpk2bxsaNGxk3bhy33HILGzZsAODNN99kzZo1vPvuu5x33nncdttt7Lrrrlsco7192ho9ejSPP/44v/3tb9m0aROzZs1izJgxXf8izSxZdd1zr/V80MOGDeOwww7jrrvu4jOf+QwrVqzgqKOOAqBfv37cfvvtzJkzhzVr1nDRRRdtse+VV17JxIkTy+7T0NCwRds999yTb3/72xx//PFEBKeeeirjx4/vnhdpZklSrYYEyt2sY8WKFRx44IE1qWdH5vfdtpVv1tH9JC2JiBGV2nlYxswsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7ME1fV17tw5sbrHO2t2h5vXrVvHCSecAMCvf/1rGhoaGDhwIADPPfcchx12WGvb+++/n8GDB1e3PjOzKqnvcO9mAwYMoKmpCYBrrrmGfv36cdlllwHZB5A2b9semzZt2upDTNXQ0tJC797+cppZxsMyVTJ//nyOO+44JkyYwEEHHcSFF17YOk97v379uPrqqxk9ejQLFixgyZIljBkzhuHDhzNu3DjefvttAMaOHcvll1/OqFGjGDp0KL/85S+B7N6y5513Ho2NjQwbNozHHnsMgBkzZvCpT32Kj3/845x00knMnz+fMWPGcOaZZzJ06FCuuOIK7rjjDkaNGkVjYyOvvvpqbd4cM+t2DveCNs8WefjhhzNhwoSybRYuXMj3v/99li5dyquvvsq9994LwO9//3sOOeQQnn76aUaPHs3UqVO5++67WbJkCZ/73Oe46qqrWo/R0tLCwoULueGGG1rv8LT5hiFLly5l1qxZnHvuua03tV6wYAG33norjz76KJANH914440sXbqUmTNn8vLLL7Nw4UI+//nP86Mf/ajL3h8zqy/+O76gzbNFdmTUqFF8+MMfBmDy5Mk8+eSTnHHGGTQ0NPDJT34SyOZxX7ZsWetUwps2bWLPPfdsPcYnPvEJAIYPH86qVasAePLJJ5k6dSoABxxwAPvuuy8vv/wyACeeeCK777576/4jR45sPd5HPvIRTjrpJAAaGxtbe/xmlj6HexVJKrvct2/f1nH2iODggw9mwYIFZY+x8847A9m9XFtaWlr3aU/p1MGl+wP06tWrdblXr16txzOz9HlYpooWLlzI66+/zvvvv8/s2bM55phjtmqz//77s3bt2tZw37hxI8uXL+/wuMcddxx33HEHkN3K74033mD//fev/gsws2TUd8+9wqWL9eaoo47iiiuuYOnSpa0nV9vaaaeduPvuu7nkkktYv349LS0tXHrppRx88MHtHvfiiy/mwgsvpLGxkd69ezNjxowteuhmZm15yt8qmT9/PtOmTeOBBx6odSmd1pPfd6stT/nb/Tzlr5nZDqy+h2V6kLFjxzJ27Nhal2FmBrjnbmaWJIe7mVmCHO5mZglyuJuZJaiuT6hOeWRKVY930wk3dbjdU/6aWSoKhbukk4EbgQbgZxHxnTbb9wFuBfrnba6IiHlVrrXLdceUv2Zm3aHisIykBmA6cApwEDBZ0kFtmv0jMCcihgGTgB9Xu1AzMyuuSM99FLAyIl4DkHQXMB54oaRNAB/MH+8GvFXNIuvB5il/AYYMGcJ9991X44rMzNpXJNz3AlaXLDcDo9u0uQZ4WNJUYBfgY+UOJOkC4AKAffbZp7O11lSRKX/NzOpFkatlVGZd2wlpJgMzImIQcCowU9JWx46ImyNiRESM2Hyi0szMqq9IuDcDe5csD2LrYZfzgTkAEbEA6AvsUY0Czcys84oMyywC9pM0BHiT7ITpWW3avAGcAMyQdCBZuK/d3uIqXbpoZmblVQz3iGiRNAV4iOwyx1siYrmk64DFETEX+CrwU0l/TzZk89mo1VzCVXLNNddssbxhw4baFGJmtg0KXeeeX7M+r826q0sevwAcXd3SzMxsW3n6ATOzBNVduPfw0Zwex++3W
ZrqKtz79u3LunXrHDjdJCJYt24dffv2rXUpZlZldTVx2KBBg2hubmbt2u2+0MYK6tu3L4MGDap1GWZWZXUV7n369GHIkCG1LsPMrMerq2EZMzOrDoe7mVmCHO5mZglyuJuZJcjhbmaWoLq6WsbMeq6pv/nHLVfc2X/L5bNmd18x5p67mVmKHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjzuZtZMXdO3GrV1N/8rgaFWBHuuZuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCCoW7pJMlvSRppaQr2mlzpqQXJC2XdGd1yzQzs86o+CEmSQ3AdOBEoBlYJGluRLxQ0mY/4Erg6Ih4V9LfdFXBZmZWWZGe+yhgZUS8FhF/Bu4Cxrdp8wVgekS8CxARa6pbppmZdUaRcN8LWF2y3JyvKzUUGCrpPyU9JenkcgeSdIGkxZIWr127dtsqNjOzioqEu8qsizbLvYH9gLHAZOBnkvpvtVPEzRExIiJGDBw4sLO1mplZQUXCvRnYu2R5EPBWmTb/GhEbI+J14CWysDczsxooEu6LgP0kDZG0EzAJmNumzf3A8QCS9iAbpnmtmoWamVlxFcM9IlqAKcBDwApgTkQsl3SdpNPzZg8B6yS9ADwGfC0i1nVV0WZm1rFC87lHxDxgXpt1V5c8DuAr+T8zM6sxf0LVzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLUKGbdZiZbbc7J3a8/azZ3VPHDsI9dzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwSVCjcJZ0s6SVJKyVd0UG7MySFpBHVK9HMzDqr4j1UJTUA04ETgWZgkaS5EfFCm3a7ApcAT3dFoWbWhSrd39R6nCI991HAyoh4LSL+DNwFjC/T7lvA94D3qlifmZltgyLhvhewumS5OV/XStIwYO+IeKCjA0m6QNJiSYvXrl3b6WLNzKyYIuGuMuuidaPUC/gB8NVKB4qImyNiRESMGDhwYPEqzcysU4qEezOwd8nyIOCtkuVdgUOA+ZJWAUcCc31S1cysdoqE+yJgP0lDJO0ETALmbt4YEesjYo+IGBwRg4GngNMjYnGXVGxmZhVVDPeIaAGmAA8BK4A5EbFc0nWSTu/qAs3MrPMqXgoJEBHzgHlt1l3dTtux21+WmZltD39C1cwsQQ53M+sSTat/V+sSdmgOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwSVGjiMDOrb1MembJ9B2h5o/XhTb332c5qrB64525mliCHu5lZghzuZmYJcribmSXI4W5m28Tztdc3h7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYI8/YBZDW33tAFm7XDP3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEFQp3SSdLeknSSklXlNn+FUkvSHpe0iOS9q1+qWZmVlTFcJfUAEwHTgEOAiZLOqhNs2eBERFxKHA38L1qF2pmZsUVmThsFLAyIl4DkHQXMB54YXODiHispP1TwKerWaSZdZ8pLW8Uare+/8aKbWbSf3vLsW1UZFhmL2B1yXJzvq495wP/Vm6DpAskLZa0eO3atcWrNDOzTikS7iqzLso2lD4NjACuL7c9Im6OiBERMWLgwIHFqzQzs04pMizTDOxdsjwIeKttI0kfA64CxkTEn6pTnpmZbYsiPfdFwH6ShkjaCZgEz
C1tIGkY8H+A0yNiTfXLNDOzzqgY7hHRAkwBHgJWAHMiYrmk6ySdnje7HugH/EJSk6S57RzOzMy6QaHb7EXEPGBem3VXlzz+WJXrMjOz7eBPqJqZJcjhbmaWoELDMmZmXe7OiZXbnDW76+tIhHvuZmYJcribmSXIwzJmnTTlkSm1LsGsIvfczcwS5J672Y7gzSW1rsC6mXvuZmYJcribmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZghzuZmYJ8idUe7BqznFy0wk3Ve1YZlZ77rmbmSXI4W5mliCHu5lZghzuZmYJcribmSXI4W5mliCHu5lZgnyduwG+Zt4sNe65m5klyD13sxT4HqnWhnvuZmYJcs/ddgjVPKdgxU1peaMqx7mp9z5VOc6OxOFuZj3HnRM73n7W7O6powfwsIyZWYIc7mZmCXK4m5klyGPuVrd8EtRs2xUKd0knAzcCDcDPIuI7bbbvDNwGDAfWARMjYlV1SzVLlK9Rty5QcVhGUgMwHTgFOAiYLOmgNs3OB96NiI8CPwC+W+1CzcysuCJj7qOAlRHxWkT8GbgLGN+mzXjg1vzx3cAJklS9Ms3MrDOKDMvsBawuWW4GRrfXJiJaJK0HBgC/LW0k6QLggnxxg6SXtqXoCvZo+7w9QFI1T2d6N5dSSFLvcR3rkpqn81SxhmfP6eyhe+J7vH+RRkXCvVwPPLahDRFxM3BzgefcZpIWR8SIrnyOanPNXa+n1QuuuTv0tHohq7lIuyLDMs3A3iXLg4C32msjqTewG/BOkQLMzKz6ioT7ImA/SUMk7QRMAua2aTMXODd/fAbwaERs1XM3M7PuUXFYJh9DnwI8RHYp5C0RsVzSdcDiiJgL/ByYKWklWY99UlcWXUGXDvt0Edfc9XpaveCau0NPqxcK1ix3sM3M0uPpB8zMEuRwNzNLUDLhLulTkpZLel/SiDbbrpS0UtJLksbVqsaOSDpc0lOSmiQtljSq1jVVImlq/p4ul/S9WtdTlKTLJIWkPWpdSyWSrpf0oqTnJd0nqX+taypH0sn598JKSVfUup5KJO0t6TFJK/Lv3y/XuqYiJDVIelbSA5XaJhPuwDLgE8ATpSvzqRImAQcDJwM/zqdUqDffA66NiMOBq/PluiXpeLJPJh8aEQcD02pcUiGS9gZOBKpzi6Cu9+/AIRFxKPAycGWN69lKwSlK6k0L8NWIOBA4EvhSD6gZ4MvAiiINkwn3iFgREeU+8ToeuCsi/hQRrwMryaZUqDcBfDB/vBtbf5ag3lwEfCci/gQQEWtqXE9RPwC+TpkP2dWjiHg4IlryxafIPmdSb4pMUVJXIuLtiHgmf/z/yAJzr9pW1TFJg4D/BfysSPtkwr0D5aZPqMcv4qXA9ZJWk/WC666H1sZQ4FhJT0t6XNLIWhdUiaTTgTcj4rla17KNPgf8W62LKKOn/IyVJWkwMAx4uraVVHQDWcfk/SKNe9R87pL+A/gfZTZdFRH/2t5uZdbVpNfWUf3ACcDfR8Q9ks4k++zAx7qzvrYq1Nsb+GuyP2lHAnMkfbjWH16rUPM/ACd1b0WVFfm+lnQV2VDCHd1ZW0F18zPWWZL6AfcAl0bEf9e6nvZIOg1YExFLJI0tsk+PCveI2JawKzJ9QrfoqH5Jt5GNpwH8goJ/enWlCvVeBNybh/lCSe+TTcK0trvqK6e9miU1AkOA5/IJSwcBz0gaFRG/7sYSt1Lp+1rSucBpwAm1/uXZjrr5GesMSX3Igv2OiLi31vVUcDRwuqRTgb7AByXdHhGfbm+HHWFYZi4wSdLOkoYA+wELa1xTOW8BY/LH/xN4pYa1FHE/WZ1IGgrsRB3PrhcRSyPibyJicEQMJgukI2od7JXkN8q5HDg9Iv5Q63raUWSKkrqST0n+c2BFRPzvWtdTSURcGRGD8u/dSWRTvLQb7NDDeu4dkTQB+BEwEHhQUlNEjMunSpgDvED2Z+2XImJTLWttxxeAG/OJ1
97jL1Mj16tbgFskLQP+DJxbp73Knu4mYGfg3/O/OJ6KiAtrW9KW2puipMZlVXI08BlgqaSmfN0/RMS8GtZUVZ5+wMwsQTvCsIyZ2Q7H4W5mliCHu5lZghzuZmYJcribmSXI4Z4oSRuqeKwbJB2XP/5ZZyZYkjS2yAx2naxnVbkZHSVdKOmcaj5Xm+N36rV34rjz285k2pUkXSPpsiof8z8k/XU1j2nbJ5nr3K1rSNodODIiLgWIiM/XuKR2RcRPuvj4dfHaJTXU4Wc1ZgIXA/9U60Is45574pS5XtIySUslTczX95L043wu6wckzZN0RplDnAH835LjtfYyJU3Oj7lM0ncL1DJK0n/l81H/l6T98/UNkqblx3pe0tR8/Ql526WSbpG0c8nhviZpYf7vo3n71h6ppC9IWiTpOUn3SPqrfP0MST/Mn/+1cq9Z0i6SHsz3XVbynpW+9g2S/ilv85SkD+XrP5IvL5J03ea/oNr+BSPpJkmfLfPc/6JsPv/lkq4tWb9K0tWSngQ+VbJ+t3xbr3z5ryStltSnvfegzfOVvqY9JK0q+Zpcn+//vKQv5uv3lPSEsvsOLJN0bH6oucDk9r/61t0c7un7BHA4cBjZRGTXS9ozXz8YaAQ+DxzVzv5HA0varpT0t8B3yaYgOBwYKenvKtTyInBcRAwjm7P+n/P1F5DN+zIsn7f8Dkl9gRnAxIhoJPsr86KSY/13RIwi+wTnDWWe696IGBkRh5FN53p+ybY9gWPI5mv5Tpl9TwbeiojDIuIQSn65ldiF7NOih5HdQ+AL+fobgRsjYiTbNr/KVRExAjgUGCPp0JJt70XEMRFx1+YVEbEeeI6/TF3xceChiNhIx+9BJecD6/PXMRL4grLpO87Kj7/5e6opr+NdYGdJA7bhNVsXcLin7xhgVkRsiojfAI+T/bAeA/wiIt7P51d5rJ3996T8ZGAjgfkRsTafb/wO4LgKtewG/ELZlAU/ILuBCmS/dH6yed7yiHgH2B94PSJeztvc2ub4s0r+L/eL6RBJv5S0FDi75LkA7s9f9wvAh8rsuxT4mKTvSjo2D9C2/gxs7okvIftFSV7LL/LHd5bZr5IzJT0DPJvXXDrGP7udfWYDE/PHk0radfQeVHIScI6yj+Y/DQwgm5dpEXCepGuAxnwu9M3WAH/bieewLuRwT1+56Vg7Wt/WH8lmoSu0v6QJ+Z/sTdr6JOG3gMfy3vDHS44rtp4itlJ90c7jzWYAU/Je/7Vs+Rr+1NHz5L9QhpOF/LclXV3m+BtL5tLZROXzVy1s+fO21Xua94wvI5v98VDgwTbtft/OsecCpyg7PzIceDRfP4P234NydZVuFzA1Ig7P/w3JbxzyBNkv2TeBmdryBHZfsu8XqwMO9/Q9AUzMx1AHkv1gLgSeBD6Zj71/CBjbzv4rgI+WWf802bDBHspuszYZeDwi7isJhMVt9tmNLBQAPluy/mHgQmWTpm0+ifsiMHjzeDrZJE+Pl+wzseT/BWXq2xV4W9m0rme389rKyoec/hARt5PdOOWITuz+FPDJ/PGkkvW/Ag5SNjvpbmTz97f1QbIAX59/TU4p8oQRsYHsa3oj8EDJydYi78Eqsl8IkJ1f2ewh4KJ8XyQNzc9F7Es2r/hPyWZVPCLfLrI56VcVqdm6nq+WSd99ZEMFz5H1cL8eEb+WdA9ZwCwjuzfn00C54YcHgS/SZn75iHhb0pVkwzkC5rVzw5Te/KWn/D3gVklf4S+9S/JjDwWel7QR+GlE3CTpPLJhnN5kwwGlV8PsLOlpsg5KuRN538hf06/IeuC7lmnTnkaycxPvAxvZcqy/kkuB2yV9ley9Ww8QEauVzU76PNl0zs+23TEinpP0LLAceA34z04872yy4aCxJeuKvAfTyG608hm2/poMJpvzXmRDc3+XH/9r+ddpA7C55z6c7BxEC1YXPCvkDkxSv4jYkJ8EWwgcXW5+8/wKjdMi4nfb8BxfBvaKiK9vf8X1L
78i5Y8REZImAZMjoq7vJ1oNkm4E5kbEI7WuxTLuue/YHpDUn+xGG9/q4MYVXwX2AToV7pJ+DhwCnLldVfYsw4Gb8t7u78jue7ojWOZgry/uuZuZJcgnVM3MEuRwNzNLkMPdzCxBDnczswQ53M3MEvT/AQA3YutEngp3AAAAAElFTkSuQmCC\n", 175 | "text/plain": [ 176 | "
" 177 | ] 178 | }, 179 | "metadata": { 180 | "needs_background": "light" 181 | }, 182 | "output_type": "display_data" 183 | }, 184 | { 185 | "name": "stdout", 186 | "output_type": "stream", 187 | "text": [ 188 | "Jacobian mean squared singular values for 128 layer Transformers:\n", 189 | "0.000 Vanilla Transformer (post norm)\n", 190 | "66.281 Transformer with pre-norm\n", 191 | "2.142 ReZero Transformer (resweight = 0.1)\n" 192 | ] 193 | }, 194 | { 195 | "data": { 196 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEKCAYAAADpfBXhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAb5ElEQVR4nO3de5RU9Znu8e9Dg5IRIyMyGUZUSCIq0gpy0+UFPERBj9EhMQo60RgTowaME02EMTFq1pxcJCeaQCbHJC4MKmK8DUuZozMqGmdQLtoKiCIqkVYTCCpzSGKk8T1/7E2naKq7qqG6q/rH81mLRe1dv73rrerup3+9d9W7FRGYmVlaulW7ADMzqzyHu5lZghzuZmYJcribmSXI4W5mliCHu5lZgkqGu6RbJK2XtKKV+yXpR5LWSHpe0lGVL9PMzNqjnJn7bGBCG/efAhyc/7sI+JddL8vMzHZFyXCPiCeAt9sYcgbwy8g8BfSW1K9SBZqZWft1r8A+9gfWFSw35uveajlQ0kVks3v22muv4YceemgFHt7MOtPajX9o9zYD+uzVAZXsnpYtW/b7iOhbalwlwl1F1hXtaRARNwM3A4wYMSKWLl1agYc3s8504ewl7d7mF58b2QGV7J4k/aaccZV4t0wjcEDBcn/gzQrs18zMdlIlwn0+cF7+rpmjgU0RscMhGTMz6zwlD8tImguMBfaT1Ah8C+gBEBE/BRYApwJrgD8CF3RUsWZmVp6S4R4Rk0vcH8CXK1HMli1baGxs5L333qvE7qwMPXv2pH///vTo0aPapZhZBVXihGrFNDY2svfeezNgwACkYudprZIigo0bN9LY2MjAgQOrXY6ZVVBNtR9477336NOnj4O9k0iiT58+/kvJLEE1Fe6Ag72T+fU2S1PNhbuZme26mjrm3tLOfFiiLeV8kKKuro76+nqampoYOHAgc+bMoXfv3q2Ov++++7juuuu2W/f888/z4IMPcsopp+xyzWZmO8Mz9xY+9KEP0dDQwIoVK9h3332ZNWtWm+MnTpxIQ0ND879LL72U448/nvHjx5f1eBHBBx98UInSzcyaOdzbcMwxx/DGG280L99www2MHDmSI444gm9961s7jF+9ejXXX389c+bMoVu3bq1us3btWg477DAuvfRSjjrqKNatW8fcuXOpr69nyJAhXHXVVZ3zBM0sWQ73VmzdupVHHnmE008/HYCHH36Yl19+mcWLF9PQ0MCyZct44oknmsdv2bKFc845hxkzZnDggQeW3Oall17ivPPO49lnn6VHjx5cddVVPProozQ0NLBkyRLuv//+zn/SZpYMh3sLf/rTnxg6dCh9+vTh7bff5qSTTgKyoH744YcZNmwYRx11FC+++CIvv/xy83bf/OY3Ofzww5k0aVLzura2Oeiggzj66KMBWLJkCWPHjqVv3750796dc889d7tfHGZm7VXTJ1SrYdsx902bNnHaaacxa9YsLrvsMiKC6d
On86UvfWmHbRYuXMg999zDM888s9361rZZu3Yte+2113bjzMwqyTP3Vuyzzz786Ec/YsaMGWzZsoXx48dzyy23sHnzZgDeeOMN1q9fzzvvvMMFF1zAL3/5S/bee+/t9tHaNi2NHj2axx9/nN///vds3bqVuXPnMmbMmI5/kmaWrJqeuVe7B/SwYcM48sgjufPOO/nsZz/LqlWrOOaYYwDo1asXt912G3fddRfr16/nkksu2W7b6dOnc/bZZxfdpq6ubrux/fr14zvf+Q4nnngiEcGpp57KGWec0TlP0sySpGodEih2sY5Vq1Zx2GGHVaWe3Zlfd2sPX6yjuiQti4gRpcb5sIyZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCarp97lzx9mV3d8589q8e+PGjYwbNw6A3/72t9TV1dG3b18AnnvuOY488sjmsffffz8DBgyobH1mZhVS2+Heyfr06UNDQwMA1157Lb169eLKK68Esg8gbbtvV2zdunWHDzFVQlNTE927+8tpZhkflqmQhQsXcsIJJzBx4kQGDx7MxRdf3NynvVevXlxzzTWMHj2aRYsWsWzZMsaMGcPw4cMZP348b731FgBjx47lqquuYtSoUQwaNIhf//rXQHZt2QsuuID6+nqGDRvGY489BsDs2bP5zGc+wyc/+UlOPvlkFi5cyJgxYzjrrLMYNGgQ06ZN4/bbb2fUqFHU19fzyiuvVOfFMbNO53Av07ZukUOHDmXixIlFxyxevJgf/OAHLF++nFdeeYV7770XgD/84Q8MGTKEp59+mtGjRzN16lTuvvtuli1bxuc//3muvvrq5n00NTWxePFibrzxxuYrPG27YMjy5cuZO3cu559/fvNFrRctWsStt97Ko48+CmSHj2666SaWL1/OnDlzWL16NYsXL+YLX/gCP/7xjzvs9TGz2uK/48u0rVtkW0aNGsVHP/pRACZPnsyTTz7JmWeeSV1dHZ/+9KeBrI/7ihUrmlsJb926lX79+jXv41Of+hQAw4cPZ+3atQA8+eSTTJ06FYBDDz2Ugw46iNWrVwNw0kknse+++zZvP3LkyOb9fexjH+Pkk08GoL6+vnnGb9Zelb7kpXU8h3sFSSq63LNnz+bj7BHB4YcfzqJFi4ruY8899wSya7k2NTU1b9OawtbBhdsDdOvWrXm5W7duzfszs/T5sEwFLV68mNdee40PPviAefPmcdxxx+0w5pBDDmHDhg3N4b5lyxZWrlzZ5n5POOEEbr/9diC7lN/rr7/OIYccUvknYGbJqO2Ze4m3LtaaY445hmnTprF8+fLmk6st7bHHHtx9991cdtllbNq0iaamJi6//HIOP/zwVvd76aWXcvHFF1NfX0/37t2ZPXv2djN0M7OW3PK3QhYuXMiMGTN44IEHql1Ku3Xl1906x64ec3fL38pxy18zs91YbR+W6ULGjh3L2LFjq12GmRngmbuZWZIc7mZmCXK4m5klyOFuZpagmj6hOuWRKRXd38xxM9u83y1/zSwVZYW7pAnATUAd8POI+G6L+w8EbgV652OmRcSCCtfa4Tqj5a+ZWWcoeVhGUh0wCzgFGAxMljS4xbBvAHdFxDBgEvCTShdqZmblK2fmPgpYExGvAki6EzgDeKFgTAAfzm/vA7xZySJrwbaWvwADBw7kvvvuq3JFZmatKyfc9wfWFSw3AqNbjLkWeFjSVGAv4BPFdiTpIuAigAMPPLC9tVZVOS1/zcxqRTnvllGRdS0b0kwGZkdEf+BUYI6kHfYdETdHxIiIGLHtRKWZmVVeOTP3RuCAguX+7HjY5UJgAkBELJLUE9gPWF+JIs2saytsPOYmYp2jnHBfAhwsaSDwBtkJ03NajHkdGAfMlnQY0BPYsKvFlXrropmZFVcy3COiSdIU4CGytzneEhErJV0PLI2I+cAVwM8k/SPZIZvPRbV6CVfItddeu93y5s2bq1OImdlOKOt97vl71he0WHdNwe0XgGMrW5
qZme0stx8wM0tQzYV7Fz+a0+X49TZLU02Fe8+ePdm4caMDp5NEBBs3bqRnz57VLsXMKqymGof179+fxsZGNmzY5TfaWJl69uxJ//79q12GmVVYTYV7jx49GDhwYLXLMDPr8mrqsIyZmVWGw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0tQWeEuaYKklyStkTStlTFnSXpB0kpJd1S2TDMza4/upQZIqgNmAScBjcASSfMj4oWCMQcD04FjI+IdSX/TUQWbmVlp5czcRwFrIuLViHgfuBM4o8WYLwKzIuIdgIhYX9kyzcysPcoJ9/2BdQXLjfm6QoOAQZL+U9JTkiYU25GkiyQtlbR0w4YNO1exmZmVVPKwDKAi66LIfg4GxgL9gV9LGhIR7263UcTNwM0AI0aMaLmPmjLlkSkV3+fMcTMrvk8zs2LKmbk3AgcULPcH3iwy5l8jYktEvAa8RBb2ZmZWBeWE+xLgYEkDJe0BTALmtxhzP3AigKT9yA7TvFrJQs3MrHwlwz0imoApwEPAKuCuiFgp6XpJp+fDHgI2SnoBeAz4WkRs7KiizcysbeUccyciFgALWqy7puB2AF/N/5mZWZX5E6pmZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJaislr9mZsVM/d032r/RHb2z/8+ZV9libDueuZuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCfA1Vs93dHWeXHDL1d+92QiFWSWXN3CVNkPSSpDWSprUx7kxJIWlE5Uo0M7P2KhnukuqAWcApwGBgsqTBRcbtDVwGPF3pIs3MrH3KmbmPAtZExKsR8T5wJ3BGkXHfBr4PvFfB+szMbCeUE+77A+sKlhvzdc0kDQMOiIgH2tqRpIskLZW0dMOGDe0u1sy6voZ179Kw7l0unL2k2qUkrZxwV5F10Xyn1A34IXBFqR1FxM0RMSIiRvTt27f8Ks3MrF3KebdMI3BAwXJ/4M2C5b2BIcBCSQB/C8yXdHpELK1UoWbWuRrW+R0yXVk5M/clwMGSBkraA5gEzN92Z0Rsioj9ImJARAwAngIc7GZmVVQy3COiCZgCPASsAu6KiJWSrpd0ekcXaGZm7VfWh5giYgGwoMW6a1oZO3bXyzIzs13h9gNmZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZgkq60pMZmaVNvV334A7erd/w3PmVb6YBHnmbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyP/dONOWRKR2y35njZnbIfs2s6ypr5i5pgqSXJK2RNK3I/V+V9IKk5yU9IumgypdqZmblKhnukuqAWcApwGBgsqTBLYY9C4yIiCOAu4HvV7pQMzMrXzkz91HAmoh4NSLeB+4EzigcEBGPRcQf88WngP6VLdPMzNqjnHDfH1hXsNyYr2vNhcC/FbtD0kWSlkpaumHDhvKrNDOzdikn3FVkXRQdKP0DMAK4odj9EXFzRIyIiBF9+/Ytv0ozM2uXct4t0wgcULDcH3iz5SBJnwCuBsZExJ8rU56Zme
2McmbuS4CDJQ2UtAcwCZhfOEDSMOD/AKdHxPrKl2lmZu1RcuYeEU2SpgAPAXXALRGxUtL1wNKImE92GKYX8CtJAK9HxOkdWPd2Our942a7m4Z171a7BKuQsj7EFBELgAUt1l1TcPsTFa7LzMx2gdsPmJklyO0HzFJxx9nVrsBqiGfuZmYJ8szdzKqm8ATu0AN6V7GS9HjmbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7mVmCfJk9M+taduVC4OfMq1wdNc4zdzOzBDnczcwS5HA3M0uQj7mb1ZJdOZ5sVsAzdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEFlhbukCZJekrRG0rQi9+8paV5+/9OSBlS6UDMzK1/JT6hKqgNmAScBjcASSfMj4oWCYRcC70TExyVNAr4H+KN2XdiUR6ZUu4Sqmjlu5s5v3MqnTKc0vb7z++wkm3pvAeCKd/tUuRLbVeW0HxgFrImIVwEk3QmcARSG+xnAtfntu4GZkhQRUcFarRW7RRC/sWznttt/+E5ttkuvaRcIcUtfOeG+P7CuYLkRGN3amIhokrQJ6AP8vnCQpIuAi/LFzZJe2pmiS9iv5eN2Aa65wzy17UYXqXc7Vav5tp3ftLZf53Pvarmmtust7pByBpUT7iqyruWMvJwxRMTNwM1lPOZOk7Q0IkZ05GNUmmvueF2tXnDNnaGr1QtZzeWMK+eEaiNwQMFyf+DN1sZI6g7sA7xdTgFmZlZ55YT7EuBgSQMl7QFMAua3GDMfOD+/fSbwqI+3m5lVT8nDMvkx9CnAQ0AdcEtErJR0PbA0IuYDvwDmSFpDNmOf1JFFl9Chh306iGvueF2tXnDNnaGr1Qtl1ixPsM3M0uNPqJqZJcjhbmaWoGTCXdJnJK2U9IGkES3um563RnhJ0vhq1dgWSUMlPSWpQdJSSaOqXVMpkqbmr+lKSd+vdj3lknSlpJC0X7VrKUXSDZJelPS8pPsk9a52TcWUalFSayQdIOkxSavy79+vVLumckiqk/SspAdKjU0m3IEVwKeAJwpXShpMdoL3cGAC8JO8pUKt+T5wXUQMBa7Jl2uWpBPJPpl8REQcDsyockllkXQAWSuNrvIx0n8HhkTEEcBqYHqV69lBQYuSU4DBwOT8566WNQFXRMRhwNHAl7tAzQBfAVaVMzCZcI+IVRFR7BOvZwB3RsSfI+I1YA1ZS4VaE8CH89v7sONnCWrNJcB3I+LPABGxvsr1lOuHwNcp8iG7WhQRD0dEU774FNnnTGpNc4uSiHgf2NaipGZFxFsR8Ux++/+RBeb+1a2qbZL6A/8T+Hk545MJ9zYUa59Qi1/Ey4EbJK0jmwXX3AythUHA8XkX0Mcljax2QaVIOh14IyKeq3YtO+nzwL9Vu4giusrPWFF5F9thwNPVraSkG8kmJh+UM7ic9gM1Q9J/AH9b5K6rI+JfW9usyLqqzNraqh8YB/xjRNwj6Syyzw58ojPra6lEvd2Bvyb7k3YkcJekj1b7w2slav4n4OTOrai0cr6vJV1Ndijh9s6srUw18zPWXpJ6AfcAl0fEf1e7ntZIOg1YHxHLJI0tZ5suFe4RsTNhV077hE7RVv2Sfkl2PA3gV5T5p1dHKlHvJcC9eZgvlvQBWROmDZ1VXzGt1SypHhgIPCcJsu+DZySNiojfdmKJOyj1fS3pfOA0YFy1f3m2omZ+xtpDUg+yYL89Iu6tdj0lHAucLulUoCfwYUm3RcQ/tLbB7nBYZj4wKb+gyEDgYGBxlWsq5k1gTH77fwAvV7GWctxPVieSBgF7UMPd9SJieUT8TUQMiIgBZIF0VLWDvRRJE4CrgNMj4o/VrqcV5bQoqSnKfsP/AlgVEf+72vWUEhHTI6J//r07iazFS6vBDl
1s5t4WSROBHwN9gQclNUTE+LxVwl1k/eebgC9HxNZq1tqKLwI35Y3X3uMvrZFr1S3ALZJWAO8D59forLKrmwnsCfx7/hfHUxFxcXVL2l5rLUqqXFYpxwKfBZZLasjX/VNELKhiTRXl9gNmZgnaHQ7LmJntdhzuZmYJcribmSXI4W5mliCHu5lZghzuiZK0uYL7ulHSCfntn7enwZKkseV0sGtnPWuLdXSUdLGk8yr5WC32367n3o79LmzZybQjSbpW0pUV3ud/SPrrSu7Tdk0y73O3jiFpX+DoiLgcICK+UOWSWhURP+3g/dfEc5dUV4Of1ZgDXAr8c7ULsYxn7olT5gZJKyQtl3R2vr6bpJ/kvawfkLRA0plFdnEm8H8L9tc8y5Q0Od/nCknfK6OWUZL+K+9H/V+SDsnX10make/reUlT8/Xj8rHLJd0iac+C3X1N0uL838fz8c0zUklflLRE0nOS7pH0V/n62ZJ+lD/+q8Wes6S9JD2Yb7ui4DUrfO6bJf1zPuYpSR/J138sX14i6fptf0G1/AtG0kxJnyvy2P+irJ//SknXFaxfK+kaSU8CnylYv09+X7d8+a8krZPUo7XXoMXjFT6n/SStLfia3JBv/7ykL+Xr+0l6Qtl1B1ZIOj7f1XxgcutffetsDvf0fQoYChxJ1ojsBkn98vUDgHrgC8AxrWx/LLCs5UpJfwd8j6wFwVBgpKS/L1HLi8AJETGMrGf9/8rXX0TW92VY3rf8dkk9gdnA2RFRT/ZX5iUF+/rviBhF9gnOG4s81r0RMTIijiRr53phwX39gOPI+rV8t8i2E4A3I+LIiBhCwS+3AnuRfVr0SLJrCHwxX38TcFNEjGTn+qtcHREjgCOAMZKOKLjvvYg4LiLu3LYiIjYBz/GX1hWfBB6KiC20/RqUciGwKX8eI4EvKmvfcU6+/23fUw15He8Ae0rqsxPP2TqAwz19xwFzI2JrRPwOeJzsh/U44FcR8UHeX+WxVrbvR/FmYCOBhRGxIe83fjtwQola9gF+paxlwQ/JLqAC2S+dn27rWx4RbwOHAK9FxOp8zK0t9j+34P9iv5iGSPq1pOXAuQWPBXB//rxfAD5SZNvlwCckfU/S8XmAtvQ+sG0mvozsFyV5Lb/Kb99RZLtSzpL0DPBsXnPhMf55rWwzDzg7vz2pYFxbr0EpJwPnKfto/tNAH7K+TEuACyRdC9TnvdC3WQ/8XTsewzqQwz19xdqxtrW+pT+RdaEra3tJE/M/2Ru040nCbwOP5bPhTxbsV+zYIrZUfdHK7W1mA1PyWf91bP8c/tzW4+S/UIaThfx3JF1TZP9bCnrpbKX0+asmtv952+E1zWfGV5J1fzwCeLDFuD+0su/5wCnKzo8MBx7N18+m9degWF2F9wuYGhFD838D8wuHPEH2S/YNYI62P4Hdk+z7xWqAwz19TwBn58dQ+5L9YC4GngQ+nR97/wgwtpXtVwEfL7L+abLDBvspu8zaZODxiLivIBCWtthmH7JQAPhcwfqHgYuVNU3bdhL3RWDAtuPpZE2eHi/Y5uyC/xcVqW9v4C1lbV3PbeW5FZUfcvpjRNxGduGUo9qx+VPAp/PbkwrW/wYYrKw76T5k/ftb+jBZgG/KvyanlPOAEbGZ7Gt6E/BAwcnWcl6DtWS/ECA7v7LNQ8Al+bZIGpSfiziIrK/4z8i6Kh6V3y+ynvRry6nZOp7fLZO++8gOFTxHNsP9ekT8VtI9ZAGzguzanE8DxQ4/PAh8iRb95SPiLUnTyQ7nCFjQygVTuvOXmfL3gVslfZW/zC7J9z0IeF7SFuBnETFT0gVkh3G6kx0OKHw3zJ6SniaboBQ7kffN/Dn9hmwGvneRMa2pJzs38QGwhe2P9ZdyOXCbpCvIXrtNABGxTll30ufJ2jk/23LDiHhO0rPASuBV4D/b8bjzyA4HjS1YV85rMIPsQiufZcevyQCynvciOzT39/n+v5Z/nTYD22buw8nOQTRhNcFdIX
djknpFxOb8JNhi4Nhi/c3zd2icFhHv7sRjfAXYPyK+vusV1778HSl/ioiQNAmYHBE1fT3RSpB0EzA/Ih6pdi2W8cx99/aApN5kF9r4dhsXrrgCOBBoV7hL+gUwBDhrl6rsWoYDM/PZ7rtk1z3dHaxwsNcWz9zNzBLkE6pmZglyuJuZJcjhbmaWIIe7mVmCHO5mZgn6/zRRdu+VLP4eAAAAAElFTkSuQmCC\n", 197 | "text/plain": [ 198 | "
" 199 | ] 200 | }, 201 | "metadata": { 202 | "needs_background": "light" 203 | }, 204 | "output_type": "display_data" 205 | } 206 | ], 207 | "source": [ 208 | "########################################################################\n", 209 | "#Compare io-Jacobian singular values during training of deep Transformer and Deepformer encoders\n", 210 | "\n", 211 | "\n", 212 | "########################################################################\n", 213 | "#Define input-output Jacobian\n", 214 | "\n", 215 | "def io_jacobian_TF(model, x):\n", 216 | " le = x.size()[0]\n", 217 | " emb = x.size()[2]\n", 218 | " noutputs = emb * le\n", 219 | " x = x.reshape(noutputs)\n", 220 | " x = x.repeat(noutputs,1)\n", 221 | " x.requires_grad_(True)\n", 222 | " y = model(x.reshape(emb*le,le,emb).transpose(0,1)).reshape(noutputs,-1)\n", 223 | " y.backward(torch.eye(noutputs).to(device))\n", 224 | " return x.grad.data\n", 225 | "\n", 226 | "\n", 227 | "########################################################################\n", 228 | "#Define an example deep transformer encoder network\n", 229 | "\n", 230 | "class DeepEncoder(torch.nn.Module):\n", 231 | "\n", 232 | " def __init__(self, ninp, nhead, nhid, nlayers, dropout = 0, variant = 'ReZero', \n", 233 | " use_LayerNorm = False, init_resweight = 0):\n", 234 | " super(DeepEncoder, self).__init__()\n", 235 | " from torch.nn import TransformerEncoder\n", 236 | " if variant == 'ReZero':\n", 237 | " encoder_layers = ReZeroEncoderLayer(ninp, nhead, nhid, dropout, \n", 238 | " use_LayerNorm = use_LayerNorm, init_resweight = init_resweight)\n", 239 | " else:\n", 240 | " encoder_layers = torch.nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout)\n", 241 | " self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n", 242 | " self._reset_parameters()\n", 243 | " \n", 244 | " def forward(self, src):\n", 245 | " src = self.transformer_encoder(src)\n", 246 | " return src\n", 247 | " \n", 248 | " def _reset_parameters(self):\n", 
249 | " r\"\"\"Initiate parameters in the transformer model.\"\"\"\n", 250 | " for p in self.parameters():\n", 251 | " #print(p.dim()>1)\n", 252 | " if p.dim() > 1:\n", 253 | " xavier_uniform_(p)\n", 254 | " \n", 255 | " \n", 256 | "########################################################################\n", 257 | "#Define a way to plot histograms of the singular value distributions\n", 258 | "\n", 259 | "def plot_jacobians(layers = 128):\n", 260 | " d_TF = list()\n", 261 | " d_ReZero = list()\n", 262 | " d_TF_prenorm = list()\n", 263 | "\n", 264 | " torch.manual_seed(0)\n", 265 | " for i in range(10):\n", 266 | " src = torch.randn(length,1, emb).to(device)\n", 267 | " \n", 268 | " model = DeepEncoder(emb, nhead, nhid, layers, dropout, use_LayerNorm='post', init_resweight = 1\n", 269 | " ).to(device)\n", 270 | " J = io_jacobian_TF(model,src)\n", 271 | " v, d, u = torch.svd(J)\n", 272 | " d_TF.append(d.cpu().numpy().tolist())\n", 273 | "\n", 274 | " model = DeepEncoder(emb, nhead, nhid, layers, dropout, use_LayerNorm='pre', init_resweight = 1\n", 275 | " ).to(device)\n", 276 | " J = io_jacobian_TF(model,src)\n", 277 | " v, d, u = torch.svd(J)\n", 278 | " d_TF_prenorm.append(d.cpu().numpy().tolist())\n", 279 | "\n", 280 | " model = DeepEncoder(emb, nhead, nhid, layers, dropout, use_LayerNorm = False, init_resweight = .1\n", 281 | " ).to(device)\n", 282 | " J = io_jacobian_TF(model,src)\n", 283 | " v, d, u = torch.svd(J)\n", 284 | " d_ReZero.append(d.cpu().numpy().tolist())\n", 285 | "\n", 286 | " d_TF = np.asarray(d_TF).flatten()\n", 287 | " d_TF_prenorm = np.asarray(d_TF_prenorm).flatten()\n", 288 | " d_ReZero = np.asarray(d_ReZero).flatten()\n", 289 | "\n", 290 | " print(\"%0.3f\" % np.mean(d_TF**2),'Vanilla Transformer (post norm)')\n", 291 | " print(\"%3.3f\" % np.mean(d_TF_prenorm**2), 'Transformer with pre-norm')\n", 292 | " print(\"%0.3f\" % np.mean(d_ReZero**2),'ReZero Transformer (resweight = 0.1)')\n", 293 | " \n", 294 | " fig, ax = plt.subplots()\n", 295 | 
" opacity=.7\n", 296 | " plt.ylim((0,1))\n", 297 | " plt.xlim((-11,4))\n", 298 | " ax.hist(np.log(d_ReZero)/np.log(10), bins = 10, alpha = opacity, label='ReZero', density = True)\n", 299 | " ax.hist(np.log(d_TF_prenorm)/np.log(10), bins = 10, alpha = opacity, label='TF prenorm', density = True)\n", 300 | " ax.hist(np.log(d_TF)/np.log(10), bins = 10, alpha = opacity, label='TF', density = True)\n", 301 | " ax.legend(loc='upper left')\n", 302 | " ax.set_xlabel('log (io-Jacobian singular values)')\n", 303 | " plt.show()\n", 304 | "\n", 305 | " \n", 306 | "########################################################################\n", 307 | "#Use GPU if available\n", 308 | "\n", 309 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 310 | "\n", 311 | "\n", 312 | "########################################################################\n", 313 | "#Plot singular value distributions for 12 and 128 layer networks\n", 314 | "\n", 315 | "emb = 16\n", 316 | "length = 4\n", 317 | "nhead = 4\n", 318 | "nhid = 64\n", 319 | "dropout = 0\n", 320 | "\n", 321 | "layers = 12\n", 322 | "print('Jacobian mean squared singular values for ', layers, ' layer Transformers:')\n", 323 | "plot_jacobians(layers = layers)\n", 324 | "\n", 325 | "layers = 128\n", 326 | "print('Jacobian mean squared singular values for ', layers, ' layer Transformers:')\n", 327 | "plot_jacobians(layers = layers)" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": {}, 333 | "source": [ 334 | "The blue/orange/green histograms correspond to the ReZero/pre-norm/vanilla Transformer architectures, respectively. We observe that while for shallow (12 layer) networks all variants have many singular values close to one (i.e. log(1) = 0 in the figure), for deep networks both the vanilla and the pre-norm variants have many large/small singular values. The ReZero Transformer maintains singular values much closer to one.
We will see below how this affects the training dynamics for deep Transformer networks." 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "metadata": {}, 340 | "source": [ 341 | "## Language modeling\n", 342 | "\n", 343 | "We now use each of the three Transformer architectures defined above to model the WikiText-2 dataset, following the basic PyTorch tutorial [Sequence-to-Sequence Modeling with nn.Transformer and TorchText\n", 344 | "](https://pytorch.org/tutorials/beginner/transformer_tutorial.html). The model is tasked with predicting which word will follow a sequence of words; we refer to the tutorial for details.\n" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "### Define the model\n", 352 | "\n", 353 | "We now define the `TransformerModel` and several functions that load and prepare the data. Finally, we arrive at the function `setup_and_train`, which defines, trains, and evaluates the model, and takes the following parameters as input:\n", 354 | "\n", 355 | "`encoder_version` : Defines the Transformer architecture: `'ReZero'`, `'pre'`, or `'post'`.\n", 356 | "\n", 357 | "`epochs` : Number of epochs to train\n", 358 | " \n", 359 | "`lr` : Learning rate\n", 360 | "\n", 361 | "`emsize` : Embedding size\n", 362 | "\n", 363 | "`nhid` : Width of the feed-forward layers\n", 364 | "\n", 365 | "`nlayers` : Number of TransformerEncoder layers\n", 366 | "\n", 367 | "`nhead` : Number of self-attention heads\n", 368 | "\n", 369 | "`dropout` : Dropout rate."
370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 3, 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "######################################################################\n", 379 | "# Define the model\n", 380 | "\n", 381 | "class TransformerModel(torch.nn.Module):\n", 382 | "\n", 383 | " def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.1, \n", 384 | " encoder_version = 'ReZero'):\n", 385 | " super(TransformerModel, self).__init__()\n", 386 | " from torch.nn import TransformerEncoder, TransformerEncoderLayer\n", 387 | " self.model_type = 'Transformer'\n", 388 | " self.src_mask = None\n", 389 | " self.pos_encoder = PositionalEncoding(ninp, dropout)\n", 390 | " if encoder_version == 'ReZero':\n", 391 | " encoder_layers = ReZeroEncoderLayer(ninp, nhead, nhid, dropout, \n", 392 | " activation = \"relu\", use_LayerNorm = False, init_resweight = 0, \n", 393 | " resweight_trainable = True)\n", 394 | " elif encoder_version == 'pre':\n", 395 | " encoder_layers = ReZeroEncoderLayer(ninp, nhead, nhid, dropout, \n", 396 | " activation = \"relu\", use_LayerNorm = 'pre', init_resweight = 1, \n", 397 | " resweight_trainable = False)\n", 398 | " elif encoder_version == 'post':\n", 399 | " encoder_layers = ReZeroEncoderLayer(ninp, nhead, nhid, dropout, \n", 400 | " activation = \"relu\", use_LayerNorm = 'post', init_resweight = 1, \n", 401 | " resweight_trainable = False)\n", 402 | " self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n", 403 | " self.encoder = torch.nn.Embedding(ntoken, ninp)\n", 404 | " self.ninp = ninp\n", 405 | " self.decoder = torch.nn.Linear(ninp, ntoken)\n", 406 | " self._reset_parameters()\n", 407 | " self.init_weights()\n", 408 | " \n", 409 | " def _generate_square_subsequent_mask(self, sz):\n", 410 | " mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)\n", 411 | " mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n", 412 | 
" return mask\n", 413 | "\n", 414 | " def init_weights(self):\n", 415 | " initrange = 0.1\n", 416 | " self.encoder.weight.data.uniform_(-initrange, initrange)\n", 417 | " self.decoder.bias.data.zero_()\n", 418 | " self.decoder.weight.data.uniform_(-initrange, initrange)\n", 419 | " \n", 420 | " def _reset_parameters(self):\n", 421 | " for p in self.parameters():\n", 422 | " if p.dim() > 1:\n", 423 | " xavier_uniform_(p)\n", 424 | "\n", 425 | " def forward(self, src):\n", 426 | " if self.src_mask is None or self.src_mask.size(0) != len(src):\n", 427 | " device = src.device\n", 428 | " mask = self._generate_square_subsequent_mask(len(src)).to(device)\n", 429 | " self.src_mask = mask\n", 430 | "\n", 431 | " src = self.encoder(src) * math.sqrt(self.ninp)\n", 432 | " src = self.pos_encoder(src)\n", 433 | " output = self.transformer_encoder(src, self.src_mask)\n", 434 | " output = self.decoder(output)\n", 435 | " return output\n", 436 | " \n", 437 | "\n", 438 | "######################################################################\n", 439 | "# Positional Encoding\n", 440 | "\n", 441 | "class PositionalEncoding(torch.nn.Module):\n", 442 | "\n", 443 | " def __init__(self, d_model, dropout=0.1, max_len=5000):\n", 444 | " super(PositionalEncoding, self).__init__()\n", 445 | " self.dropout = torch.nn.Dropout(p=dropout)\n", 446 | "\n", 447 | " pe = torch.zeros(max_len, d_model)\n", 448 | " position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n", 449 | " div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n", 450 | " pe[:, 0::2] = torch.sin(position * div_term)\n", 451 | " pe[:, 1::2] = torch.cos(position * div_term)\n", 452 | " pe = pe.unsqueeze(0).transpose(0, 1)\n", 453 | " self.register_buffer('pe', pe)\n", 454 | "\n", 455 | " def forward(self, x):\n", 456 | " x = x + self.pe[:x.size(0), :]\n", 457 | " return self.dropout(x)\n", 458 | "\n", 459 | "\n", 460 | 
"######################################################################\n", 461 | "# Load and batch data\n", 462 | "\n", 463 | "import torchtext\n", 464 | "from torchtext.data.utils import get_tokenizer\n", 465 | "TEXT = torchtext.data.Field(tokenize=get_tokenizer(\"basic_english\"),\n", 466 | " init_token='',\n", 467 | " eos_token='',\n", 468 | " lower=True)\n", 469 | "train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)\n", 470 | "TEXT.build_vocab(train_txt)\n", 471 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 472 | "\n", 473 | "def batchify(data, bsz):\n", 474 | " \n", 475 | " data = TEXT.numericalize([data.examples[0].text])\n", 476 | " nbatch = data.size(0) // bsz\n", 477 | " data = data.narrow(0, 0, nbatch * bsz)\n", 478 | " data = data.view(bsz, -1).t().contiguous()\n", 479 | " return data.to(device)\n", 480 | "\n", 481 | "batch_size = 50\n", 482 | "eval_batch_size = 10\n", 483 | "train_data = batchify(train_txt, batch_size)\n", 484 | "val_data = batchify(val_txt, eval_batch_size)\n", 485 | "test_data = batchify(test_txt, eval_batch_size)\n", 486 | "\n", 487 | "\n", 488 | "######################################################################\n", 489 | "# get_batch() function generates the input and target sequence for\n", 490 | "\n", 491 | "bptt = 35\n", 492 | "def get_batch(source, i):\n", 493 | " seq_len = min(bptt, len(source) - 1 - i)\n", 494 | " data = source[i:i+seq_len]\n", 495 | " target = source[i+1:i+1+seq_len].view(-1)\n", 496 | " return data, target\n", 497 | "\n", 498 | "\n", 499 | "######################################################################\n", 500 | "# setup_and_train function calls, trains and evaluates the model\n", 501 | "\n", 502 | "def setup_and_train(epochs, lr, emsize, nhid, nlayers, nhead, dropout, encoder_version, plt_jacobian = True):\n", 503 | " \n", 504 | " ntokens = len(TEXT.vocab.stoi) # the size of vocabulary\n", 505 | " \n", 506 | " 
######################################################################\n", 507 | " # Model setup\n", 508 | "\n", 509 | " model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, \n", 510 | " dropout, encoder_version = encoder_version).to(device)\n", 511 | " model.to(device)\n", 512 | "\n", 513 | "\n", 514 | " ######################################################################\n", 515 | " # Define criterion and optimizer\n", 516 | "\n", 517 | " criterion = torch.nn.CrossEntropyLoss()\n", 518 | " optimizer = torch.optim.Adagrad(model.parameters(), lr = lr)\n", 519 | " scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.9)\n", 520 | "\n", 521 | " ######################################################################\n", 522 | " # Define the training\n", 523 | "\n", 524 | " def train():\n", 525 | " model.train() # Turn on the train mode\n", 526 | " total_loss = 0.\n", 527 | " start_time = time.time()\n", 528 | " ntokens = len(TEXT.vocab.stoi)\n", 529 | " for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n", 530 | " data, targets = get_batch(train_data, i)\n", 531 | " optimizer.zero_grad()\n", 532 | " output = model(data)\n", 533 | " loss = criterion(output.view(-1, ntokens), targets)\n", 534 | " loss.backward()\n", 535 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n", 536 | " optimizer.step()\n", 537 | " total_loss += loss.item()\n", 538 | " log_interval = 200\n", 539 | " if batch % log_interval == 0 and batch > 0:\n", 540 | " cur_loss = total_loss / log_interval\n", 541 | " elapsed = time.time() - start_time\n", 542 | " print('| epoch {:3d} | {:5d}/{:5d} batches | '\n", 543 | " 'lr {:02.2f} | ms/batch {:5.0f} | '\n", 544 | " 'loss {:5.2f} | ppl {:6.0f}'.format(\n", 545 | " epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],\n", 546 | " elapsed * 1000 / log_interval,cur_loss, math.exp(cur_loss)))\n", 547 | " total_loss = 0\n", 548 | " start_time = time.time()\n", 549 | "\n", 550 | " 
######################################################################\n", 551 | " # Define the evaluation\n", 552 | "\n", 553 | " def evaluate(eval_model, data_source):\n", 554 | " eval_model.eval() # Turn on the evaluation mode\n", 555 | " total_loss = 0.\n", 556 | " ntokens = len(TEXT.vocab.stoi)\n", 557 | " with torch.no_grad():\n", 558 | " for i in range(0, data_source.size(0) - 1, bptt):\n", 559 | " data, targets = get_batch(data_source, i)\n", 560 | " output = eval_model(data)\n", 561 | " output_flat = output.view(-1, ntokens)\n", 562 | " total_loss += len(data) * criterion(output_flat, targets).item()\n", 563 | " return total_loss / (len(data_source) - 1)\n", 564 | "\n", 565 | " ######################################################################\n", 566 | " # Train the model\n", 567 | "\n", 568 | " for epoch in range(1, epochs + 1):\n", 569 | " epoch_start_time = time.time()\n", 570 | " train()\n", 571 | " val_loss = evaluate(model, val_data)\n", 572 | " print('-' * 88)\n", 573 | " print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '\n", 574 | " 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),\n", 575 | " val_loss, math.exp(val_loss)))\n", 576 | " print('-' * 88)\n", 577 | " scheduler.step()\n", 578 | "\n", 579 | " ######################################################################\n", 580 | " #Plot the spectrum of the io-Jacobian singular values after training\n", 581 | " \n", 582 | " if plt_jacobian == True:\n", 583 | " d_ = list()\n", 584 | " for i in range(10):\n", 585 | " src = torch.randn(16,1, emsize).to(device)\n", 586 | " J = io_jacobian_TF(model.transformer_encoder,src)\n", 587 | " v, d, u = torch.svd(J)\n", 588 | " d_.append(d.cpu().numpy().tolist())\n", 589 | " d_ = np.asarray(d_).flatten()\n", 590 | " print('Mean sq singular value of io Jacobian:', \"%0.3f\" % np.mean(d_**2))\n", 591 | " fig, ax = plt.subplots()\n", 592 | " opacity=.7\n", 593 | " plt.ylim((0,1))\n", 594 | " plt.xlim((-11,4))\n", 595 | " 
ax.hist(np.log(d_)/np.log(10), bins = 10, alpha = opacity, label='Model.transformer_encoder',\n", 596 | " density = True)\n", 597 | " ax.legend(loc='upper left')\n", 598 | " ax.set_xlabel('log (io-Jacobian singular values)')\n", 599 | " plt.show()" 600 | ] 601 | }, 602 | { 603 | "cell_type": "markdown", 604 | "metadata": {}, 605 | "source": [ 606 | "### Train and compare three Transformer architectures:\n", 607 | "\n", 608 | "We can now easily use the function `setup_and_train` to run experiments by changing between Transformer architectures and modifying hyperparameters.\n", 609 | "\n", 610 | "### Vanilla, or post-norm Transformer\n", 611 | "\n", 612 | "First, let us use the `'post'` architecture that corresponds to a vanilla Transformer (i.e. we set `resweight = 1` and it is not trainable). Our experiment uses the Adagrad optimizer and no learning-rate warmup. For a 12 layer transformer network we observe slow training.\n", 613 | "\n", 614 | "Although the mean squared singular values of the Jacobian remain close to one, the histogram shows a large spread. This indicates that some signals get amplified while others are attenuated, which is associated with poor training performance."
615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": 14, 620 | "metadata": {}, 621 | "outputs": [ 622 | { 623 | "name": "stdout", 624 | "output_type": "stream", 625 | "text": [ 626 | "| epoch 1 | 200/ 1192 batches | lr 0.01 | ms/batch 49 | loss 7.01 | ppl 1103\n", 627 | "| epoch 1 | 400/ 1192 batches | lr 0.01 | ms/batch 49 | loss 6.89 | ppl 981\n", 628 | "| epoch 1 | 600/ 1192 batches | lr 0.01 | ms/batch 49 | loss 6.89 | ppl 979\n", 629 | "| epoch 1 | 800/ 1192 batches | lr 0.01 | ms/batch 49 | loss 6.89 | ppl 987\n", 630 | "| epoch 1 | 1000/ 1192 batches | lr 0.01 | ms/batch 49 | loss 6.89 | ppl 978\n", 631 | "----------------------------------------------------------------------------------------\n", 632 | "| end of epoch 1 | time: 63.66s | valid loss 6.72 | valid ppl 829.18\n", 633 | "----------------------------------------------------------------------------------------\n", 634 | "| epoch 2 | 200/ 1192 batches | lr 0.01 | ms/batch 49 | loss 6.80 | ppl 897\n", 635 | "| epoch 2 | 400/ 1192 batches | lr 0.01 | ms/batch 48 | loss 6.80 | ppl 900\n", 636 | "| epoch 2 | 600/ 1192 batches | lr 0.01 | ms/batch 48 | loss 6.82 | ppl 915\n", 637 | "| epoch 2 | 800/ 1192 batches | lr 0.01 | ms/batch 48 | loss 6.84 | ppl 931\n", 638 | "| epoch 2 | 1000/ 1192 batches | lr 0.01 | ms/batch 48 | loss 6.83 | ppl 928\n", 639 | "----------------------------------------------------------------------------------------\n", 640 | "| end of epoch 2 | time: 63.01s | valid loss 6.74 | valid ppl 845.14\n", 641 | "----------------------------------------------------------------------------------------\n", 642 | "| epoch 3 | 200/ 1192 batches | lr 0.01 | ms/batch 48 | loss 6.77 | ppl 873\n", 643 | "| epoch 3 | 400/ 1192 batches | lr 0.01 | ms/batch 48 | loss 6.78 | ppl 878\n", 644 | "| epoch 3 | 600/ 1192 batches | lr 0.01 | ms/batch 48 | loss 6.79 | ppl 893\n", 645 | "| epoch 3 | 800/ 1192 batches | lr 0.01 | ms/batch 48 | loss 6.81 | ppl 911\n", 646 | "| epoch 3 
| 1000/ 1192 batches | lr 0.01 | ms/batch 48 | loss 6.81 | ppl 910\n", 647 | "----------------------------------------------------------------------------------------\n", 648 | "| end of epoch 3 | time: 63.03s | valid loss 6.76 | valid ppl 864.26\n", 649 | "----------------------------------------------------------------------------------------\n" 650 | ] 651 | } 652 | ], 653 | "source": [ 654 | "######################################################################\n", 655 | "# The model is set up with the hyperparameters below.\n", 656 | "\n", 657 | "encoder_version = 'post' # architecture: 'ReZero', 'pre', or 'post' (vanilla)\n", 658 | "nlayers = 12 # the number of layers\n", 659 | "lr = .01 # Initial learning rate\n", 660 | "epochs = 3 # The number of epochs\n", 661 | "emsize = 128 # embedding dimension\n", 662 | "nhid = 256 # the dimension of the feedforward network model\n", 663 | "nhead = 8 # the number of heads in self attention\n", 664 | "dropout = 0.1 # the dropout value\n", 665 | "\n", 666 | "setup_and_train(epochs, lr, emsize, nhid, nlayers, nhead, dropout, encoder_version, plt_jacobian = False)" 667 | ] 668 | }, 669 | { 670 | "cell_type": "markdown", 671 | "metadata": {}, 672 | "source": [ 673 | "### Pre-norm Transformer\n", 674 | "\n", 675 | "Next, let us use the `'pre'` architecture that applies the `LayerNorm` before the residual connection. For the 12 layer Transformer network with otherwise identical hyperparameters we observe faster training.\n", 676 | "\n", 677 | "Again, the mean squared singular values of the Jacobian remain close to one, and compared to the 'post' architecture, the histogram shows a smaller spread in the singular values. This coincides with somewhat better training performance."
678 | ] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": 11, 683 | "metadata": {}, 684 | "outputs": [ 685 | { 686 | "name": "stdout", 687 | "output_type": "stream", 688 | "text": [ 689 | "| epoch 1 | 200/ 1192 batches | lr 0.01 | ms/batch 50 | loss 7.84 | ppl 2536\n", 690 | "| epoch 1 | 400/ 1192 batches | lr 0.01 | ms/batch 50 | loss 6.04 | ppl 419\n", 691 | "| epoch 1 | 600/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.90 | ppl 364\n", 692 | "| epoch 1 | 800/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.85 | ppl 348\n", 693 | "| epoch 1 | 1000/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.77 | ppl 321\n", 694 | "----------------------------------------------------------------------------------------\n", 695 | "| end of epoch 1 | time: 65.10s | valid loss 5.49 | valid ppl 243.25\n", 696 | "----------------------------------------------------------------------------------------\n", 697 | "| epoch 2 | 200/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.46 | ppl 236\n", 698 | "| epoch 2 | 400/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.36 | ppl 214\n", 699 | "| epoch 2 | 600/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.36 | ppl 213\n", 700 | "| epoch 2 | 800/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.39 | ppl 220\n", 701 | "| epoch 2 | 1000/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.37 | ppl 215\n", 702 | "----------------------------------------------------------------------------------------\n", 703 | "| end of epoch 2 | time: 65.22s | valid loss 5.34 | valid ppl 208.18\n", 704 | "----------------------------------------------------------------------------------------\n", 705 | "| epoch 3 | 200/ 1192 batches | lr 0.01 | ms/batch 51 | loss 5.19 | ppl 180\n", 706 | "| epoch 3 | 400/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.13 | ppl 170\n", 707 | "| epoch 3 | 600/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.14 | ppl 171\n", 708 | "| epoch 3 | 800/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.20 | ppl 181\n", 709 | "| epoch 3 
| 1000/ 1192 batches | lr 0.01 | ms/batch 50 | loss 5.19 | ppl 179\n", 710 | "----------------------------------------------------------------------------------------\n", 711 | "| end of epoch 3 | time: 65.33s | valid loss 5.27 | valid ppl 193.95\n", 712 | "----------------------------------------------------------------------------------------\n" 713 | ] 714 | } 715 | ], 716 | "source": [ 717 | "######################################################################\n", 718 | "# The model is set up with the hyperparameters below.\n", 719 | "\n", 720 | "encoder_version = 'pre' # architecture: 'ReZero', 'pre', or 'post' (vanilla)\n", 721 | "nlayers = 12 # the number of layers\n", 722 | "lr = .01 # Initial learning rate\n", 723 | "epochs = 3 # The number of epochs\n", 724 | "emsize = 128 # embedding dimension\n", 725 | "nhid = 256 # the dimension of the feedforward network model\n", 726 | "nhead = 8 # the number of heads in self attention\n", 727 | "dropout = 0.1 # the dropout value\n", 728 | "\n", 729 | "setup_and_train(epochs, lr, emsize, nhid, nlayers, nhead, dropout, encoder_version, \n", 730 | " plt_jacobian = False)" 731 | ] 732 | }, 733 | { 734 | "cell_type": "markdown", 735 | "metadata": {}, 736 | "source": [ 737 | "### ReZero Transformer\n", 738 | "\n", 739 | "Finally, let us use the `'ReZero'` architecture, which eliminates the `LayerNorm`, initializes the residual weight to zero, and registers it as a trainable parameter. `ReZero` enables the use of a higher learning rate compared to the other architectures. For the 12 layer Transformer network with otherwise identical hyperparameters we observe the fastest training.\n", 740 | "\n", 741 | "\n", 742 | "The mean squared singular values of the Jacobian are very close to one and the histogram shows a very small spread in the singular values. This coincides with the best training performance observed in this comparison."
743 | ] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "execution_count": 12, 748 | "metadata": {}, 749 | "outputs": [ 750 | { 751 | "name": "stdout", 752 | "output_type": "stream", 753 | "text": [ 754 | "| epoch 1 | 200/ 1192 batches | lr 0.01 | ms/batch 47 | loss 6.51 | ppl 671\n", 755 | "| epoch 1 | 400/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.90 | ppl 365\n", 756 | "| epoch 1 | 600/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.77 | ppl 322\n", 757 | "| epoch 1 | 800/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.74 | ppl 312\n", 758 | "| epoch 1 | 1000/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.67 | ppl 291\n", 759 | "----------------------------------------------------------------------------------------\n", 760 | "| end of epoch 1 | time: 60.97s | valid loss 5.43 | valid ppl 228.46\n", 761 | "----------------------------------------------------------------------------------------\n", 762 | "| epoch 2 | 200/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.29 | ppl 199\n", 763 | "| epoch 2 | 400/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.24 | ppl 189\n", 764 | "| epoch 2 | 600/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.25 | ppl 190\n", 765 | "| epoch 2 | 800/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.29 | ppl 199\n", 766 | "| epoch 2 | 1000/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.28 | ppl 196\n", 767 | "----------------------------------------------------------------------------------------\n", 768 | "| end of epoch 2 | time: 61.27s | valid loss 5.30 | valid ppl 199.76\n", 769 | "----------------------------------------------------------------------------------------\n", 770 | "| epoch 3 | 200/ 1192 batches | lr 0.01 | ms/batch 48 | loss 5.03 | ppl 153\n", 771 | "| epoch 3 | 400/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.02 | ppl 151\n", 772 | "| epoch 3 | 600/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.04 | ppl 154\n", 773 | "| epoch 3 | 800/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.10 | ppl 165\n", 774 | "| epoch 3 | 
1000/ 1192 batches | lr 0.01 | ms/batch 47 | loss 5.10 | ppl 164\n", 775 | "----------------------------------------------------------------------------------------\n", 776 | "| end of epoch 3 | time: 61.36s | valid loss 5.23 | valid ppl 186.93\n", 777 | "----------------------------------------------------------------------------------------\n" 778 | ] 779 | } 780 | ], 781 | "source": [ 782 | "######################################################################\n", 783 | "# The model is set up with the hyperparameters below.\n", 784 | "\n", 785 | "encoder_version = 'ReZero' # architecture: 'ReZero', 'pre', or 'post' (vanilla)\n", 786 | "nlayers = 12 # the number of layers\n", 787 | "lr = .01 # Initial learning rate\n", 788 | "epochs = 3 # The number of epochs\n", 789 | "emsize = 128 # embedding dimension\n", 790 | "nhid = 256 # the dimension of the feedforward network model\n", 791 | "nhead = 8 # the number of heads in self attention\n", 792 | "dropout = 0.1 # the dropout value\n", 793 | "\n", 794 | "setup_and_train(epochs, lr, emsize, nhid, nlayers, nhead, dropout, encoder_version, \n", 795 | " plt_jacobian = False)" 796 | ] 797 | }, 798 | { 799 | "cell_type": "markdown", 800 | "metadata": {}, 801 | "source": [ 802 | "### 128 layer ReZero Transformer\n", 803 | "\n", 804 | "As promised in the title, we can use the `'ReZero'` architecture to train extremely deep Transformer networks. To render a `128` layer transformer trainable, we again reduce the learning rate (to the Adagrad default value of `lr = 0.01`).\n", 805 | "\n", 806 | "Training this 128 layer network takes about 20 minutes and after three epochs achieves the best validation ppl of around `168`. \n", 807 | "\n", 808 | "Unfortunately, it would require too much memory to quickly evaluate the input-output Jacobian for this deep network."
809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": 7, 814 | "metadata": {}, 815 | "outputs": [ 816 | { 817 | "name": "stdout", 818 | "output_type": "stream", 819 | "text": [ 820 | "| epoch 1 | 200/ 1192 batches | lr 0.01 | ms/batch 322 | loss 6.89 | ppl 985\n", 821 | "| epoch 1 | 400/ 1192 batches | lr 0.01 | ms/batch 319 | loss 5.84 | ppl 344\n", 822 | "| epoch 1 | 600/ 1192 batches | lr 0.01 | ms/batch 328 | loss 5.72 | ppl 304\n", 823 | "| epoch 1 | 800/ 1192 batches | lr 0.01 | ms/batch 318 | loss 5.68 | ppl 292\n", 824 | "| epoch 1 | 1000/ 1192 batches | lr 0.01 | ms/batch 324 | loss 5.59 | ppl 269\n", 825 | "----------------------------------------------------------------------------------------\n", 826 | "| end of epoch 1 | time: 427.25s | valid loss 5.35 | valid ppl 211.38\n", 827 | "----------------------------------------------------------------------------------------\n", 828 | "| epoch 2 | 200/ 1192 batches | lr 0.01 | ms/batch 320 | loss 5.16 | ppl 174\n", 829 | "| epoch 2 | 400/ 1192 batches | lr 0.01 | ms/batch 322 | loss 5.12 | ppl 167\n", 830 | "| epoch 2 | 600/ 1192 batches | lr 0.01 | ms/batch 321 | loss 5.11 | ppl 166\n", 831 | "| epoch 2 | 800/ 1192 batches | lr 0.01 | ms/batch 319 | loss 5.15 | ppl 173\n", 832 | "| epoch 2 | 1000/ 1192 batches | lr 0.01 | ms/batch 322 | loss 5.13 | ppl 170\n", 833 | "----------------------------------------------------------------------------------------\n", 834 | "| end of epoch 2 | time: 422.47s | valid loss 5.19 | valid ppl 179.56\n", 835 | "----------------------------------------------------------------------------------------\n", 836 | "| epoch 3 | 200/ 1192 batches | lr 0.01 | ms/batch 326 | loss 4.84 | ppl 126\n", 837 | "| epoch 3 | 400/ 1192 batches | lr 0.01 | ms/batch 320 | loss 4.83 | ppl 126\n", 838 | "| epoch 3 | 600/ 1192 batches | lr 0.01 | ms/batch 323 | loss 4.85 | ppl 128\n", 839 | "| epoch 3 | 800/ 1192 batches | lr 0.01 | ms/batch 321 | loss 4.92 | ppl 136\n", 840 
| "| epoch 3 | 1000/ 1192 batches | lr 0.01 | ms/batch 320 | loss 4.91 | ppl 135\n", 841 | "----------------------------------------------------------------------------------------\n", 842 | "| end of epoch 3 | time: 424.36s | valid loss 5.12 | valid ppl 167.97\n", 843 | "----------------------------------------------------------------------------------------\n" 844 | ] 845 | } 846 | ], 847 | "source": [ 848 | "######################################################################\n", 849 | "# The model is set up with the hyperparameters below.\n", 850 | "\n", 851 | "encoder_version = 'ReZero' # architecture: 'ReZero', 'pre', or 'post' (vanilla)\n", 852 | "nlayers = 128 # the number of layers\n", 853 | "lr = .01 # Initial learning rate\n", 854 | "epochs = 3 # The number of epochs\n", 855 | "emsize = 128 # embedding dimension\n", 856 | "nhid = 256 # the dimension of the feedforward network model\n", 857 | "nhead = 8 # the number of heads in self attention\n", 858 | "dropout = 0.1 # the dropout value\n", 859 | "\n", 860 | "setup_and_train(epochs, lr, emsize, nhid, nlayers, nhead, dropout, encoder_version, plt_jacobian = False)" 861 | ] 862 | } 863 | ], 864 | "metadata": { 865 | "kernelspec": { 866 | "display_name": "Python 3", 867 | "language": "python", 868 | "name": "python3" 869 | }, 870 | "language_info": { 871 | "codemirror_mode": { 872 | "name": "ipython", 873 | "version": 3 874 | }, 875 | "file_extension": ".py", 876 | "mimetype": "text/x-python", 877 | "name": "python", 878 | "nbconvert_exporter": "python", 879 | "pygments_lexer": "ipython3", 880 | "version": "3.7.7" 881 | } 882 | }, 883 | "nbformat": 4, 884 | "nbformat_minor": 2 885 | } 886 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.5 2 | matplotlib 3 | torch==1.6.0 4 | torchvision==0.7.0 5 | torchtext==0.6.0 6 |
--------------------------------------------------------------------------------