├── README.md
├── agd.py
├── architecture
│   ├── __init__.py
│   ├── fcn.py
│   ├── resnet.py
│   └── vgg.py
├── data
│   ├── __init__.py
│   ├── cifar10.py
│   ├── cifar100.py
│   ├── imagenet.py
│   └── mnist.py
├── latex
│   ├── algorithm
│   │   ├── agd.py
│   │   └── agd.tex
│   ├── figures
│   │   ├── apbs.tex
│   │   ├── comparison-table.tex
│   │   ├── exp1.tex
│   │   ├── exp2.tex
│   │   ├── exp3.tex
│   │   ├── exp4.tex
│   │   ├── pdf
│   │   │   ├── maj-min.pdf
│   │   │   ├── plot0.pdf
│   │   │   ├── plot1_full.pdf
│   │   │   ├── plot2.pdf
│   │   │   ├── plot4.pdf
│   │   │   └── plots3_0.pdf
│   │   ├── schematic.tex
│   │   ├── showcase.tex
│   │   └── theory-table.tex
│   ├── format.tex
│   ├── macros.tex
│   ├── main.tex
│   ├── minted-output
│   │   ├── default-pyg-prefix.pygstyle
│   │   ├── default.pygstyle
│   │   └── listing1.pygtex
│   ├── packages.tex
│   ├── refs.bib
│   ├── section
│   │   ├── 00-abstract.tex
│   │   ├── 01-intro.tex
│   │   ├── 03-sketch.tex
│   │   ├── 04-bounds.tex
│   │   ├── 07-experiments.tex
│   │   ├── 08-discuss.tex
│   │   ├── 98-ack.tex
│   │   └── 99-appendix.tex
│   └── tmlr
│       ├── LICENSE
│       ├── fancyhdr.sty
│       ├── tmlr.bst
│       └── tmlr.sty
├── main.py
└── supercloud
    ├── multi-cifar.sh
    ├── multi-imagenet.sh
    ├── single-cifar.sh
    └── single-imagenet.sh

/README.md:
--------------------------------------------------------------------------------
 1 | 
 6 | 
 7 | 
10 | Jeremy Bernstein* ·
11 | Chris Mingard* ·
12 | Kevin Huang ·
13 | Navid Azizan ·
14 | Yisong Yue
15 | 
16 | 
17 | ## Getting started
18 | 
19 | Install PyTorch and a GPU, and run:
20 | ```bash
21 | python main.py
22 | ```
23 | Command line arguments are:
24 | ```bash
25 | --arch           # options: fcn, vgg, resnet18, resnet50
26 | --dataset        # options: cifar10, cifar100, mnist, imagenet
27 | --train_bs       # training batch size
28 | --test_bs        # testing batch size
29 | --epochs         # number of training epochs
30 | --depth          # number of layers for fcn
31 | --width          # hidden layer width for fcn
32 | --distribute     # train over multiple gpus (for imagenet)
33 | --gain           # experimental acceleration of training
34 | ```
35 | No training hyperparameters are necessary. Optionally, you can try `--gain 10.0`, which we have found can accelerate training. Chris is maintaining a [separate repository](https://github.com/C1510/agd_exp) with some more experimental features. A worked invocation and a minimal training loop are sketched in the [Example usage](#example-usage) section below.
36 | 
37 | ## Repository structure
38 |     .
39 |     ├── architecture/           # network architectures
40 |     ├── data/                   # datasets and preprocessing
41 |     ├── latex/                  # source code for the paper
42 |     ├── supercloud/             # MIT supercloud run files
43 |     ├── agd.py                  # automatic gradient descent
44 |     └── main.py                 # entrypoint to training
45 | 
46 | ## Description of the method
47 | 
48 | For the $k\text{th}$ weight matrix $W_k$ in $\mathbb{R}^{d_k \times d_{k-1}}$ and square or cross-entropy loss $\mathcal{L}$:
49 | - initial weights are drawn from the uniform measure over orthogonal matrices, and then scaled by $\sqrt{d_k / d_{k-1}}$.
50 | - weights are updated according to:
51 | ```math
52 | W_k \gets W_k - \frac{\eta}{L} \cdot \sqrt{\tfrac{d_k}{d_{k-1}}} \cdot \frac{ \nabla_{W_k} \mathcal{L}}{\Vert{ \nabla_{W_k}\mathcal{L}}\Vert _F}.
53 | ```
54 | $L$ denotes the depth of the network, and the learning rate $\eta$ is set automatically via:
55 | 
56 | - $G \gets \frac{1}{L} \sum_{k\in\{1...L\}} \sqrt{\tfrac{d_k}{d_{k-1}}}\cdot \Vert\nabla_{W_k} \mathcal{L}\Vert_F$;
57 | - $\eta \gets \log\Big( \tfrac{1+\sqrt{1+4G}}{2}\Big)$.
58 | 
59 | This procedure is slightly modified for convolutional layers.
60 | 
61 | ## Citation
62 | 
63 | If you find AGD helpful and you'd like to cite the paper, we'd appreciate it:
64 | 
65 | ```bibtex
66 | @article{agd-2023,
67 |   author  = {Jeremy Bernstein and Chris Mingard and Kevin Huang and Navid Azizan and Yisong Yue},
68 |   title   = {{A}utomatic {G}radient {D}escent: {D}eep {L}earning without {H}yperparameters},
69 |   journal = {arXiv:2304.05187},
70 |   year    = 2023
71 | }
72 | ```
73 | 
74 | ## References
75 | 
76 | Our paper, titled `Automatic Gradient Descent: Deep Learning without Hyperparameters`, is available [at this link](https://arxiv.org/abs/2304.05187). The derivation of AGD is a refined version of the majorise-minimise analysis given in my [PhD thesis](https://arxiv.org/abs/2210.10101) `Optimisation & Generalisation in Networks of Neurons`, and was worked out in close collaboration with Chris and Kevin. In turn, this develops the perturbation analysis from [our earlier paper](https://arxiv.org/abs/2002.03432) `On the Distance between two Neural Networks and the Stability of Learning`, with a couple of insights from [Greg Yang and Edward Hu's](https://arxiv.org/abs/2011.14522) `Feature Learning in Infinite-Width Neural Networks` thrown in for good measure.
77 | 
78 | ## Acknowledgements
79 | 
80 | Some architecture definitions were adapted from [kuangliu/pytorch-cifar](https://github.com/kuangliu/pytorch-cifar).
81 | 
82 | ## License
83 | 
84 | We are making AGD available under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.
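
## Example usage

For concreteness, here is one possible invocation of the entrypoint; the flag values are purely illustrative, not recommendations:

```bash
python main.py --arch fcn --dataset cifar10 --train_bs 128 --test_bs 128 --epochs 10 --depth 8 --width 256
```

And here is a minimal training-loop sketch built directly on the `AGD` class from `agd.py`. It assumes a bias-free network (AGD raises an exception on 1D parameters) and uses square loss on random stand-in data, so it runs without downloading a dataset; `main.py` remains the canonical entrypoint:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from agd import AGD

# a small bias-free network: AGD does not support 1D parameters (biases)
net = nn.Sequential(
    nn.Linear(784, 256, bias=False), nn.ReLU(),
    nn.Linear(256, 10, bias=False),
)

opt = AGD(net, gain=1.0)
opt.init_weights()   # semi-orthogonal init, rescaled by sqrt(d_k / d_{k-1})

x = torch.randn(128, 784)                                 # stand-in mini-batch
y = F.one_hot(torch.randint(0, 10, (128,)), 10).float()   # one-hot targets for square loss

for _ in range(10):
    opt.zero_grad()
    loss = F.mse_loss(net(x), y)   # square loss, as in the paper's experiments
    loss.backward()
    eta = opt.step()               # the learning rate is set automatically
    print(f"loss {loss.item():.4f}  eta {eta:.4f}")
```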
85 | 
--------------------------------------------------------------------------------
/agd.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | import torch
 3 | 
 4 | from torch.optim.optimizer import Optimizer
 5 | from torch.nn.init import orthogonal_
 6 | 
 7 | def singular_value(p):
 8 |     sv = math.sqrt(p.shape[0] / p.shape[1])          # sqrt(d_out / d_in)
 9 |     if p.dim() == 4:
10 |         sv /= math.sqrt(p.shape[2] * p.shape[3])     # extra 1/sqrt(kx * ky) for conv kernels
11 |     return sv
12 | 
13 | class AGD(Optimizer):
14 | 
15 |     def __init__(self, net, gain=1.0):
16 | 
17 |         self.net = net
18 |         self.depth = len(list(net.parameters()))
19 |         self.gain = gain
20 | 
21 |         for p in self.net.parameters():
22 |             if p.dim() == 1: raise Exception("Biases are not supported.")
23 | 
24 |         super().__init__(net.parameters(), defaults=dict())
25 | 
26 |     @torch.no_grad()
27 |     def init_weights(self):
28 | 
29 |         for p in self.net.parameters():
30 |             if p.dim() == 2: orthogonal_(p)              # semi-orthogonal init
31 |             if p.dim() == 4:                             # conv: orthogonalise each kernel slice
32 |                 for kx in range(p.shape[2]):
33 |                     for ky in range(p.shape[3]):
34 |                         orthogonal_(p[:,:,kx,ky])
35 |             p *= singular_value(p)                       # rescale singular values
36 | 
37 |     @torch.no_grad()
38 |     def step(self):
39 | 
40 |         G = 0
41 |         for p in self.net.parameters():
42 |             G += singular_value(p) * p.grad.norm(dim=(0,1)).sum()
43 |         G /= self.depth                                  # gradient summary, averaged over layers
44 | 
45 |         log = math.log(0.5 * (1 + math.sqrt(1 + 4*G)))   # automatic learning rate eta
46 | 
47 |         for p in self.net.parameters():
48 |             factor = singular_value(p) / p.grad.norm(dim=(0,1), keepdim=True)
49 |             p -= self.gain * log / self.depth * torch.nan_to_num(factor) * p.grad   # nan_to_num guards zero gradients
50 | 
51 |         return log
52 | 
--------------------------------------------------------------------------------
/architecture/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxbz/agd/1f8454d24335d55aeffa1249d34d84d00d37d5d4/architecture/__init__.py
--------------------------------------------------------------------------------
/architecture/fcn.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | class FCN(nn.Module):
 6 |     def __init__(self, depth, width, input_dim, output_dim, bias=False):
 7 |         super(FCN, self).__init__()
 8 | 
 9 |         self.initial = nn.Linear(input_dim, width, bias=bias)
10 |         self.layers = nn.ModuleList([nn.Linear(width, width, bias=bias) for _ in range(depth-2)])
11 |         self.final = nn.Linear(width, output_dim, bias=bias)
12 | 
13 |     def forward(self, x):
14 |         x = x.view(x.shape[0],-1)
15 | 
16 |         x = self.initial(x)
17 |         x = F.relu(x) * math.sqrt(2)                     # sqrt(2) keeps activations at unit scale
18 | 
19 |         for layer in self.layers:
20 |             x = layer(x)
21 |             x = F.relu(x) * math.sqrt(2)
22 | 
23 |         return self.final(x)
24 | 
--------------------------------------------------------------------------------
/architecture/resnet.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | import torch
 3 | import torch.nn as nn
 4 | import torch.nn.functional as F
 5 | 
 6 | from functools import partial
 7 | from typing import Any, Callable, List, Optional, Type, Union
 8 | 
 9 | import torch
10 | import torch.nn as nn
11 | from torch import Tensor
12 | 
13 | ### For CIFAR-10
14 | 
15 | def PreActResNet18(output_dim): return PreActResNet(PreActBlock, [2,2,2,2], output_dim)
16 | def PreActResNet34(output_dim): return PreActResNet(PreActBlock, [3,4,6,3], output_dim)
17 | def PreActResNet50(output_dim): return PreActResNet(PreActBottleneck, [3,4,6,3], output_dim)
18 | def PreActResNet101(output_dim): return PreActResNet(PreActBottleneck, 
[3,4,23,3], output_dim) 19 | def PreActResNet152(output_dim): return PreActResNet(PreActBottleneck, [3,8,36,3], output_dim) 20 | 21 | class PreActBlock(nn.Module): 22 | '''Pre-activation version of the BasicBlock.''' 23 | expansion = 1 24 | 25 | def __init__(self, in_planes, planes, stride=1): 26 | super(PreActBlock, self).__init__() 27 | self.bn1 = nn.BatchNorm2d(in_planes, affine=False) 28 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 29 | self.bn2 = nn.BatchNorm2d(planes, affine=False) 30 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 31 | 32 | if stride != 1 or in_planes != self.expansion*planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 35 | ) 36 | 37 | def forward(self, x): 38 | out = F.relu(self.bn1(x)) 39 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 40 | out = self.conv1(out) 41 | out = self.conv2(F.relu(self.bn2(out))) 42 | out += shortcut 43 | return out 44 | 45 | 46 | class PreActBottleneck(nn.Module): 47 | '''Pre-activation version of the original Bottleneck module.''' 48 | expansion = 4 49 | 50 | def __init__(self, in_planes, planes, stride=1): 51 | super(PreActBottleneck, self).__init__() 52 | self.bn1 = nn.BatchNorm2d(in_planes, affine=False) 53 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes, affine=False) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes, affine=False) 57 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 58 | 59 | if stride != 1 or in_planes != self.expansion*planes: 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(x)) 66 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 67 | out = self.conv1(out) 68 | out = self.conv2(F.relu(self.bn2(out))) 69 | out = self.conv3(F.relu(self.bn3(out))) 70 | out += shortcut 71 | return out 72 | 73 | 74 | class PreActResNet(nn.Module): 75 | def __init__(self, block, num_blocks, num_classes=10): 76 | super(PreActResNet, self).__init__() 77 | self.in_planes = 64 78 | 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.linear = nn.Linear(512*block.expansion, num_classes, bias=False) 85 | 86 | def _make_layer(self, block, planes, num_blocks, stride): 87 | strides = [stride] + [1]*(num_blocks-1) 88 | layers = [] 89 | for stride in strides: 90 | layers.append(block(self.in_planes, planes, stride)) 91 | self.in_planes = planes * block.expansion 92 | return nn.Sequential(*layers) 93 | 94 | def forward(self, x): 95 | out = self.conv1(x) 96 | out = self.layer1(out) 97 | out = self.layer2(out) 98 | out = self.layer3(out) 99 | out = self.layer4(out) 100 | out = F.avg_pool2d(out, 4) 101 | out = out.view(out.size(0), -1) 102 | out = self.linear(out) 103 | return out 104 | 105 | ### For ImageNet 106 | 107 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 
1, dilation: int = 1) -> nn.Conv2d: 108 | """3x3 convolution with padding""" 109 | return nn.Conv2d( 110 | in_planes, 111 | out_planes, 112 | kernel_size=3, 113 | stride=stride, 114 | padding=dilation, 115 | groups=groups, 116 | bias=False, 117 | dilation=dilation, 118 | ) 119 | 120 | 121 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 122 | """1x1 convolution""" 123 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 124 | 125 | 126 | class BasicBlock(nn.Module): 127 | expansion: int = 1 128 | 129 | def __init__( 130 | self, 131 | inplanes: int, 132 | planes: int, 133 | stride: int = 1, 134 | downsample: Optional[nn.Module] = None, 135 | groups: int = 1, 136 | base_width: int = 64, 137 | dilation: int = 1, 138 | norm_layer: Optional[Callable[..., nn.Module]] = None, 139 | affine=False 140 | ) -> None: 141 | super().__init__() 142 | self.affine = affine 143 | if norm_layer is None: 144 | norm_layer = nn.BatchNorm2d 145 | if groups != 1 or base_width != 64: 146 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 147 | if dilation > 1: 148 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 149 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 150 | self.conv1 = conv3x3(inplanes, planes, stride) 151 | self.bn1 = norm_layer(planes, affine=self.affine) 152 | self.relu = nn.ReLU(inplace=True) 153 | self.conv2 = conv3x3(planes, planes) 154 | self.bn2 = norm_layer(planes, affine=affine) 155 | self.downsample = downsample 156 | self.stride = stride 157 | 158 | def forward(self, x: Tensor) -> Tensor: 159 | identity = x 160 | 161 | out = self.conv1(x) 162 | out = self.bn1(out) 163 | out = self.relu(out) 164 | 165 | out = self.conv2(out) 166 | out = self.bn2(out) 167 | 168 | if self.downsample is not None: 169 | identity = self.downsample(x) 170 | 171 | out += identity 172 | out = self.relu(out) 173 | 174 | return out 175 | 176 | 177 | class Bottleneck(nn.Module): 178 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 179 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 180 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. 181 | # This variant is also known as ResNet V1.5 and improves accuracy according to 182 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
183 | 184 | expansion: int = 4 185 | 186 | def __init__( 187 | self, 188 | inplanes: int, 189 | planes: int, 190 | stride: int = 1, 191 | downsample: Optional[nn.Module] = None, 192 | groups: int = 1, 193 | base_width: int = 64, 194 | dilation: int = 1, 195 | norm_layer: Optional[Callable[..., nn.Module]] = None, 196 | affine=False 197 | ) -> None: 198 | super().__init__() 199 | if norm_layer is None: 200 | norm_layer = nn.BatchNorm2d 201 | width = int(planes * (base_width / 64.0)) * groups 202 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 203 | self.conv1 = conv1x1(inplanes, width) 204 | self.bn1 = norm_layer(width, affine=affine) 205 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 206 | self.bn2 = norm_layer(width, affine=affine) 207 | self.conv3 = conv1x1(width, planes * self.expansion) 208 | self.bn3 = norm_layer(planes * self.expansion, affine=affine) 209 | self.relu = nn.ReLU(inplace=True) 210 | self.downsample = downsample 211 | self.stride = stride 212 | 213 | def forward(self, x: Tensor) -> Tensor: 214 | identity = x 215 | 216 | out = self.conv1(x) 217 | out = self.bn1(out) 218 | out = self.relu(out) 219 | 220 | out = self.conv2(out) 221 | out = self.bn2(out) 222 | out = self.relu(out) 223 | 224 | out = self.conv3(out) 225 | out = self.bn3(out) 226 | 227 | if self.downsample is not None: 228 | identity = self.downsample(x) 229 | 230 | out += identity 231 | out = self.relu(out) 232 | 233 | return out 234 | 235 | 236 | class ResNet(nn.Module): 237 | def __init__( 238 | self, 239 | block: Type[Union[BasicBlock, Bottleneck]], 240 | layers: List[int], 241 | num_classes: int = 1000, 242 | zero_init_residual: bool = False, 243 | groups: int = 1, 244 | width_per_group: int = 64, 245 | replace_stride_with_dilation: Optional[List[bool]] = None, 246 | norm_layer: Optional[Callable[..., nn.Module]] = None, 247 | bias=False, 248 | affine=False 249 | ) -> None: 250 | 251 | self.bias=bias 252 | self.affine=affine 253 | super().__init__() 254 | if norm_layer is None: 255 | norm_layer = nn.BatchNorm2d 256 | self._norm_layer = norm_layer 257 | 258 | self.inplanes = 64 259 | self.dilation = 1 260 | if replace_stride_with_dilation is None: 261 | # each element in the tuple indicates if we should replace 262 | # the 2x2 stride with a dilated convolution instead 263 | replace_stride_with_dilation = [False, False, False] 264 | if len(replace_stride_with_dilation) != 3: 265 | raise ValueError( 266 | "replace_stride_with_dilation should be None " 267 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 268 | ) 269 | self.groups = groups 270 | self.base_width = width_per_group 271 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 272 | self.bn1 = norm_layer(self.inplanes, affine=self.affine) 273 | self.relu = nn.ReLU(inplace=True) 274 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 275 | self.layer1 = self._make_layer(block, 64, layers[0]) 276 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 277 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 278 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 279 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 280 | self.fc = nn.Linear(512 * block.expansion, num_classes, bias=self.bias) 281 | self.out_dim = num_classes 282 | 283 | for m in self.modules(): 284 | if isinstance(m, 
nn.Conv2d): 285 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 286 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 287 | if self.affine: 288 | nn.init.constant_(m.weight, 1) 289 | nn.init.constant_(m.bias, 0) 290 | 291 | # Zero-initialize the last BN in each residual branch, 292 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 293 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 294 | if zero_init_residual: 295 | for m in self.modules(): 296 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 297 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 298 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 299 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 300 | 301 | def _make_layer( 302 | self, 303 | block: Type[Union[BasicBlock, Bottleneck]], 304 | planes: int, 305 | blocks: int, 306 | stride: int = 1, 307 | dilate: bool = False, 308 | ) -> nn.Sequential: 309 | norm_layer = self._norm_layer 310 | downsample = None 311 | previous_dilation = self.dilation 312 | if dilate: 313 | self.dilation *= stride 314 | stride = 1 315 | if stride != 1 or self.inplanes != planes * block.expansion: 316 | downsample = nn.Sequential( 317 | conv1x1(self.inplanes, planes * block.expansion, stride), 318 | norm_layer(planes * block.expansion, affine=False), 319 | ) 320 | 321 | layers = [] 322 | layers.append( 323 | block( 324 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 325 | ) 326 | ) 327 | self.inplanes = planes * block.expansion 328 | for _ in range(1, blocks): 329 | layers.append( 330 | block( 331 | self.inplanes, 332 | planes, 333 | groups=self.groups, 334 | base_width=self.base_width, 335 | dilation=self.dilation, 336 | norm_layer=norm_layer, 337 | affine=self.affine 338 | ) 339 | ) 340 | 341 | return nn.Sequential(*layers) 342 | 343 | def _forward_impl(self, x: Tensor) -> Tensor: 344 | # See note [TorchScript super()] 345 | x = self.conv1(x) 346 | x = self.bn1(x) 347 | x = self.relu(x) 348 | x = self.maxpool(x) 349 | 350 | x = self.layer1(x) 351 | x = self.layer2(x) 352 | x = self.layer3(x) 353 | x = self.layer4(x) 354 | 355 | x = self.avgpool(x) 356 | x = torch.flatten(x, 1) 357 | x = self.fc(x) 358 | 359 | return x 360 | 361 | def forward(self, x: Tensor) -> Tensor: 362 | return self._forward_impl(x) 363 | 364 | 365 | def _resnet( 366 | block: Type[Union[BasicBlock, Bottleneck]], 367 | layers: List[int], 368 | weights: Optional, 369 | progress: bool, 370 | bias=False, 371 | affine=False, 372 | num_classes=10, 373 | ) -> ResNet: 374 | 375 | model = ResNet(block, layers, bias=bias, affine=affine, num_classes=num_classes) 376 | 377 | return model 378 | 379 | 380 | def resnet18(num_classes, weights: Optional = None, progress: bool = True, bias=False, affine=False) -> ResNet: 381 | return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, bias=bias, affine=affine, num_classes=num_classes) 382 | 383 | 384 | def resnet34(weights: Optional = None, progress: bool = True, bias=False, affine=False) -> ResNet: 385 | return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, bias=bias, affine=affine) 386 | 387 | 388 | def resnet50(num_classes,weights: Optional = None, progress: bool = True, bias=False, affine=False) -> ResNet: 389 | return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, bias=bias, affine=affine, num_classes=num_classes) 390 | 391 | 392 | def resnet101(weights: Optional = 
None, progress: bool = True, bias=False, affine=False) -> ResNet: 393 | return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, bias=bias, affine=affine) 394 | 395 | 396 | def resnet152(weights: Optional = None, progress: bool = True, bias=False, affine=False) -> ResNet: 397 | return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, bias=bias, affine=affine) 398 | -------------------------------------------------------------------------------- /architecture/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def VGG11(output_dim): return VGG_CIFAR([64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], output_dim) 4 | def VGG13(output_dim): return VGG_CIFAR([64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], output_dim) 5 | def VGG16(output_dim): return VGG_CIFAR([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], output_dim) 6 | def VGG19(output_dim): return VGG_CIFAR([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], output_dim) 7 | 8 | class VGG_CIFAR(nn.Module): 9 | def __init__(self, vgg_cfg, output_dim=10, bias=False, affine=False): 10 | super(VGG_CIFAR, self).__init__() 11 | self.bias = bias 12 | self.affine = affine 13 | self.features = self._make_layers(vgg_cfg) 14 | self.classifier = nn.Linear(512, output_dim, bias=self.bias) 15 | 16 | def forward(self, x): 17 | out = self.features(x) 18 | out = out.view(out.size(0), -1) 19 | out = self.classifier(out) 20 | return out 21 | 22 | def _make_layers(self, cfg): 23 | layers = [] 24 | in_channels = 3 25 | for x in cfg: 26 | if x == 'M': 27 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 28 | else: 29 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1, bias=self.bias), 30 | nn.BatchNorm2d(x, affine=self.affine), 31 | nn.ReLU(inplace=True)] 32 | in_channels = x 33 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 34 | return nn.Sequential(*layers) 35 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxbz/agd/1f8454d24335d55aeffa1249d34d84d00d37d5d4/data/__init__.py -------------------------------------------------------------------------------- /data/cifar10.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | 3 | def getData(): 4 | 5 | mean = (0.4914, 0.4822, 0.4465) 6 | std = (0.2023, 0.1994, 0.2010) 7 | 8 | transform_train = transforms.Compose([ 9 | transforms.RandomCrop(32, padding=4), 10 | transforms.RandomHorizontalFlip(), 11 | transforms.ToTensor(), 12 | transforms.Normalize(mean, std), 13 | ]) 14 | 15 | transform_test = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize(mean, std), 18 | ]) 19 | 20 | trainset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train) 21 | testset = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test) 22 | 23 | input_dim = 3*32*32 24 | output_dim = 10 25 | 26 | return trainset, testset, input_dim, output_dim 27 | -------------------------------------------------------------------------------- /data/cifar100.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | 3 | def 
getData(): 4 | 5 | mean = (0.5071, 0.4867, 0.4408) 6 | std = (0.2675, 0.2565, 0.2761) 7 | 8 | transform_train = transforms.Compose([ 9 | transforms.RandomCrop(32, padding=4), 10 | transforms.RandomHorizontalFlip(), 11 | transforms.ToTensor(), 12 | transforms.Normalize(mean, std), 13 | ]) 14 | 15 | transform_test = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize(mean, std), 18 | ]) 19 | 20 | trainset = datasets.CIFAR100('./data', train=True, download=True, transform=transform_train) 21 | testset = datasets.CIFAR100('./data', train=False, download=True, transform=transform_test) 22 | 23 | input_dim = 3*32*32 24 | output_dim = 100 25 | 26 | return trainset, testset, input_dim, output_dim 27 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision import datasets, transforms 3 | 4 | def getData(): 5 | 6 | mean = (0.485, 0.456, 0.406) 7 | std = (0.229, 0.224, 0.225) 8 | 9 | traindir = os.path.join(os.getenv('IMAGENET_PATH'), "train") 10 | valdir = os.path.join(os.getenv('IMAGENET_PATH'), "val") 11 | 12 | trainset = datasets.ImageFolder( 13 | traindir, 14 | transforms.Compose([ 15 | transforms.RandomResizedCrop(224), 16 | transforms.RandomHorizontalFlip(), 17 | transforms.ToTensor(), 18 | transforms.Normalize(mean, std), 19 | ])) 20 | 21 | testset = datasets.ImageFolder( 22 | valdir, 23 | transforms.Compose([ 24 | transforms.Resize(256), 25 | transforms.CenterCrop(224), 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean, std), 28 | ])) 29 | 30 | input_dim = 3*224*224 31 | output_dim = 1000 32 | 33 | return trainset, testset, input_dim, output_dim 34 | -------------------------------------------------------------------------------- /data/mnist.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | 3 | def getData(): 4 | mean = (0.1307,) 5 | std = (0.3081,) 6 | 7 | transform = transforms.Compose([ 8 | transforms.ToTensor(), 9 | transforms.Normalize(mean, std) 10 | ]) 11 | 12 | trainset = datasets.MNIST('./data', train=True, download=True, transform=transform) 13 | testset = datasets.MNIST('./data', train=False, download=True, transform=transform) 14 | 15 | input_dim = 1*28*28 16 | output_dim = 10 17 | 18 | return trainset, testset, input_dim, output_dim 19 | -------------------------------------------------------------------------------- /latex/algorithm/agd.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from torch.nn.init import orthogonal_ 5 | 6 | def singular_value(p): 7 | sv = math.sqrt(p.shape[0] / p.shape[1]) 8 | if p.dim() == 4: 9 | sv /= math.sqrt(p.shape[2] * p.shape[3]) 10 | return sv 11 | 12 | class AGD: 13 | @torch.no_grad() 14 | def __init__(self, net, gain=1.0): 15 | 16 | self.net = net 17 | self.depth = len(list(net.parameters())) 18 | self.gain = gain 19 | 20 | for p in net.parameters(): 21 | if p.dim() == 1: raise Exception("Biases are not supported.") 22 | if p.dim() == 2: orthogonal_(p) 23 | if p.dim() == 4: 24 | for kx in range(p.shape[2]): 25 | for ky in range(p.shape[3]): 26 | orthogonal_(p[:,:,kx,ky]) 27 | p *= singular_value(p) 28 | 29 | @torch.no_grad() 30 | def step(self): 31 | 32 | G = 0 33 | for p in self.net.parameters(): 34 | G += singular_value(p) * p.grad.norm(dim=(0,1)).sum() 35 | G /= self.depth 36 | 37 | log = 
math.log(0.5 * (1 + math.sqrt(1 + 4*G)))
38 | 
39 |         for p in self.net.parameters():
40 |             factor = singular_value(p) / p.grad.norm(dim=(0,1), keepdim=True)
41 |             p -= self.gain * log / self.depth * factor * p.grad
42 | 
--------------------------------------------------------------------------------
/latex/algorithm/agd.tex:
--------------------------------------------------------------------------------
 1 | \begin{algorithm}[t]
 2 | \caption{\captiontitle{Automatic gradient descent.} The matrix $\mW_k$ in $\R^{d_k \times d_{k-1}}$ is the weight matrix at layer $k$. The gradient $\nabla_{\mW_k} \el$ is with respect to the objective $\el$ evaluated on a mini-batch $B$ of training samples.}\label{alg:agd}
 3 | \begin{algorithmic}
 4 | \tt
 5 | \setstretch{1.8}\vspace{0.5em}
 6 | \DEF[initialise\_weights]
 7 | \FOR{layer $k$ in $\{1,...,L\}$:}
 8 | \STATE $\mW_k \sim \uniform(\mathtt{orthogonal}(d_k,d_{k-1}))$ \WCOMMENT{sample a semi-orthogonal matrix}
 9 | \STATE $\mW_k \gets \mW_k \cdot \sqrt{\frac{d_k}{d_{k-1}}}$ \WCOMMENT{rescale its singular values}
10 | \ENDFOR
11 | \ENDDEF
12 | \vspace{-1.6ex}\DEF[update\_weights]
13 | \STATE $G \gets \frac{1}{L}\sum_{k=1}^L \norm{\nabla_{\mW_k} \el}_F \cdot \sqrt{\frac{d_k}{d_{k-1}}}$ \WCOMMENT{get gradient summary}
14 | \STATE $\smash{\eta \gets \log\frac{1 + \sqrt{1+ 4G }}{2}}$ \WCOMMENT{set automatic learning rate}
15 | \FOR{layer $k$ in $\{1,...,L\}$:}
16 | \STATE $\mW_k \gets \mW_k - \frac{\eta}{L} \cdot \frac{\nabla_{\mW_k} \el}{\norm{\nabla_{\mW_k} \el}_F} \cdot \sqrt{\frac{d_k}{d_{k-1}}}$ \WCOMMENT{update weights}
17 | \ENDFOR
18 | \ENDDEF
19 | \setstretch{1.0}
20 | \end{algorithmic}
21 | \end{algorithm}
--------------------------------------------------------------------------------
/latex/figures/apbs.tex:
--------------------------------------------------------------------------------
 1 | \begin{figure}
 2 | \centering
 3 | \begin{tikzpicture}
 4 | [thick, block/.style={draw, minimum size=1.1cm}]
 5 | 
 6 | \node [block,fill=red!30] (aa) {$\el$};
 7 | \node [block,fill=green!30] (bb) [right =of aa] {$\ell$};
 8 | 
 9 | \node [block,fill=orange!30] (a) [right =of bb] {$\vf$};
10 | \node [block,fill=blue!30] (b) [right =of a] {$\mW_L$};
11 | \node [block,fill=blue!30] (c) [right =of b] {$\mW_2$};
12 | \node [block,fill=blue!30] (d) [right =of c] {$\mW_1$};
13 | 
14 | \node [block,fill=orange!10] (v) [below =of a] {$\Delta \vf$};
15 | \node [block,fill=blue!10] (x) [above =of b] {$\Delta \mW_L$};
16 | \node [block,fill=blue!10] (y) [above =of c] {$\Delta \mW_2$};
17 | \node [block,fill=blue!10] (z) [above =of d] {$\Delta \mW_1$};
18 | 
19 | \node [block,fill=red!10] (uu) [below =of aa] {$\Delta \el$};
20 | \node [block,fill=green!10] (vv) [below =of bb] {$\Delta \ell$};
21 | 
22 | \node (q) [above =of y, yshift=-0.9cm, align=center] {perturbations applied by optimiser\\[0.5ex]$\overbrace{\hspace{15.5em}}$};
23 | \node (p) [below =of vv, yshift=0.9cm, align=center] {$\underbrace{\hspace{15.5em}}$\\[0.5ex]perturbations induced by optimiser};
24 | 
25 | \node [below =of d, yshift=0.9cm, align=center] {layer $1$};
26 | \node [below =of c, yshift=0.9cm, align=center] {layer $2$};
27 | \node [below =of b, yshift=0.9cm, align=center] {layer $L$};
28 | \node [above =of a, yshift=-0.89cm, align=center] {output};
29 | \node [above =of bb, yshift=-0.825cm, align=center] {loss};
30 | \node [above =of aa, yshift=-0.89cm, align=center] {objective};
31 | 
32 | \draw[-latex] (bb) edge (aa);
33 | \draw[-latex] (a) edge (bb);
34 | \draw[-latex] (b) edge (a);
35 | \draw[-latex, dashed] 
(c) edge (b); 36 | \draw[-latex] (d) edge (c); 37 | 38 | \draw[-latex] (a) edge (v); 39 | \draw[-latex] (x) edge (b); 40 | \draw[-latex] (y) edge (c); 41 | \draw[-latex] (z) edge (d); 42 | 43 | \draw[-latex] (aa) edge (uu); 44 | \draw[-latex] (bb) edge (vv); 45 | 46 | \end{tikzpicture} 47 | \caption{\captiontitle{Perturbation hierarchy of a deep neural network.} When training a neural network, the optimiser applies structured perturbations to the weights, in the form of one perturbation matrix $\Delta \mW_k$ per weight matrix $\mW_k$. Deep relative trust \citep{my-fromage} provides a tool to understand how structured weight perturbations of this form affect the network output $\vf$. Combining deep relative trust with a Bregman divergence \citep{bregman1967relaxation} allows us to analyse the full perturbation hierarchy.} 48 | \label{fig:apbs} 49 | \end{figure} -------------------------------------------------------------------------------- /latex/figures/comparison-table.tex: -------------------------------------------------------------------------------- 1 | \begin{table} 2 | \centering 3 | \begin{tabularx}{\textwidth}{lXccccc} 4 | \toprule 5 | \textbf{Optimiser} & 6 | \textbf{Reference} & 7 | \makecell{\textbf{Hyperparameter}\\\textbf{Free}} & 8 | \makecell{\textbf{Width}\\\textbf{Scaling}} & \makecell{\textbf{Depth}\\\textbf{Scaling}} & 9 | \makecell{\textbf{Automatic}\\\textbf{Schedule}} & 10 | \makecell{\textbf{Memory}\\\textbf{Cost}} \\ 11 | \midrule 12 | Adam & $\mathrlap{\text{\citet{kingma_adam:_2015}}}$ & \xmark & \xmark & \xmark & \xmark & $3 \times \#$weights\\ 13 | SGD + mom. & $\mathrlap{\text{\citet{bottou}}}$ & \xmark & \xmark & \xmark & \xmark & $2\times\#$weights\\ 14 | SGD + muP & $\mathrlap{\text{\citet{Yang2021TensorPI}}}$ & \xmark & \cmark & \xmark & \xmark & $1\times\#$weights\\ 15 | AGD & this paper & \cmark & \cmark & \cmark & \cmark & $1\times\#$weights\\ 16 | \bottomrule 17 | \end{tabularx} 18 | \caption{\captiontitle{Comparing practical optimisers.} Adam and momentum-SGD employ running estimates of gradient statistics and thereby use more memory than AGD. In addition, Adam and SGD do not provide guidance on scaling hyperparameters with network architecture, although muP fixes this for the case of width scaling.} 19 | \label{tab:practice} 20 | \end{table} -------------------------------------------------------------------------------- /latex/figures/exp1.tex: -------------------------------------------------------------------------------- 1 | \begin{figure} 2 | \centering 3 | \includegraphics[width=\textwidth]{figures/pdf/plot1_full} 4 | \caption{\captiontitle{Benchmarking automatic gradient descent on a range of architectures and datasets.} Solid lines are AGD and faint dashed lines are tuned Adam except for ImageNet where the dashed line is SGD with a fixed learning rate of 0.1. ImageNet used cross-entropy loss with a mini-batch size of 1024. The other experiments used square loss with a mini-batch size of 128. 5 | The \captiontitle{top row} plots the automatic learning rate ($\eta$ in the main text) and objective value. The maximum and minimum learning rate for each epoch is included in addition to the mean for the first three plots. The \captiontitle{bottom row} shows the train and test accuracy. 
6 | }\label{fig:1} 7 | \end{figure} -------------------------------------------------------------------------------- /latex/figures/exp2.tex: -------------------------------------------------------------------------------- 1 | \begin{figure} 2 | \centering 3 | \includegraphics[width=\textwidth]{figures/pdf/plot2} 4 | \caption{\captiontitle{Comparing automatic gradient descent to tuned Adam and SGD.} An eight-layer fully-connected network was trained on CIFAR-10 with square loss. Dotted lines show test and solid lines show train performance. 5 | The \captiontitle{left panel} shows the objective value: AGD and Adam attained a smaller training objective than SGD. The \captiontitle{middle panel} shows train and test accuracies. The \captiontitle{right panel} shows the relative update size averaged over layers: $\tfrac{1}{L}\sum_{k=1}^L \norm{\Delta \mW_k}_F/\norm{\mW_k}_F$. We plot the maximum, minimum and mean over an epoch.} \label{fig:2} 6 | \end{figure} -------------------------------------------------------------------------------- /latex/figures/exp3.tex: -------------------------------------------------------------------------------- 1 | \begin{figure} 2 | \centering 3 | \includegraphics[width=\textwidth]{figures/pdf/plots3_0} 4 | \caption{\captiontitle{Benchmarking automatic gradient descent on networks of varying width and depth.} We trained fully-connected networks on CIFAR-10 with square loss and a mini-batch size of 128. The depth ranged from $2$ to $32$, and the width from $64$ to $2048$, in powers of two. In terms of training performance, wider was always better, while depth 8 and depth 16 were superior to depth 32. In terms of test accuracy, the best performance was achieved at depth 4 and width 2048: 63.7\%. The worst test performance was achieved by the smallest network of depth 2 and width 64: 42.55\%. 5 | Larger networks display two broadly distinct phases of training: the automatic learning rate increases slowly while the objective decreases slowly, followed by a rapid decrease in the automatic learning rate and objective. This second phase typically coincides with reaching 100\% train accuracy. See \cref{fig:2} for a comparison between Adam, SGD and AGD for the 256-width 8-layer FCN.} \label{fig:3} 6 | \end{figure} -------------------------------------------------------------------------------- /latex/figures/exp4.tex: -------------------------------------------------------------------------------- 1 | \begin{figure} 2 | \centering 3 | \includegraphics[width=\textwidth]{figures/pdf/plot4} 4 | \caption{\captiontitle{Benchmarking automatic gradient descent at varying mini-batch size.} We trained four-layer fully-connected networks on CIFAR-10. The mini-batch size ranged from 32 to 4096. Test accuracy generally improved with increasing mini-batch size: the final test accuracies, in order of increasing mini-batch size, were 55.0\%, 58.0\%, 60.0\% and 59.8\%. The automatic learning rate seemed to initially dip, and this effect was more pronounced for larger mini-batch sizes. Metrics were computed every iteration during the first epoch and once per epoch from thereon---this explains the kinks visible in the plots. 
5 | } \label{fig:4} 6 | \end{figure} -------------------------------------------------------------------------------- /latex/figures/pdf/maj-min.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxbz/agd/1f8454d24335d55aeffa1249d34d84d00d37d5d4/latex/figures/pdf/maj-min.pdf -------------------------------------------------------------------------------- /latex/figures/pdf/plot0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxbz/agd/1f8454d24335d55aeffa1249d34d84d00d37d5d4/latex/figures/pdf/plot0.pdf -------------------------------------------------------------------------------- /latex/figures/pdf/plot1_full.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxbz/agd/1f8454d24335d55aeffa1249d34d84d00d37d5d4/latex/figures/pdf/plot1_full.pdf -------------------------------------------------------------------------------- /latex/figures/pdf/plot2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxbz/agd/1f8454d24335d55aeffa1249d34d84d00d37d5d4/latex/figures/pdf/plot2.pdf -------------------------------------------------------------------------------- /latex/figures/pdf/plot4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxbz/agd/1f8454d24335d55aeffa1249d34d84d00d37d5d4/latex/figures/pdf/plot4.pdf -------------------------------------------------------------------------------- /latex/figures/pdf/plots3_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxbz/agd/1f8454d24335d55aeffa1249d34d84d00d37d5d4/latex/figures/pdf/plots3_0.pdf -------------------------------------------------------------------------------- /latex/figures/schematic.tex: -------------------------------------------------------------------------------- 1 | \begin{figure} 2 | \begin{minipage}{\textwidth} 3 | \centering 4 | \raisebox{-0.5\height}{\includegraphics{figures/pdf/maj-min}} \hfill \raisebox{-0.5\height}{ 5 | \begin{tikzpicture} 6 | [thick, block/.style={draw, minimum size=1.1cm}] 7 | 8 | \node [block,fill=red!30] (a) {$\el$}; 9 | \node [block,fill=green!30] (b) [right =of a] {$\ell$}; 10 | \node [block,fill=orange!30] (c) [right =of b] {$\vf$}; 11 | \node [block,fill=blue!30] (d) [right =of c] {$\vw$}; 12 | 13 | \node [block,fill=red!10] (u) [below =of a] {$\Delta \el$}; 14 | \node [block,fill=green!10] (v) [below =of b] {$\Delta \ell$}; 15 | \node [block,fill=orange!10] (w) [below =of c] {$\Delta \vf$}; 16 | \node [block,fill=blue!10] (x) [above =of d] {$\Delta \vw$}; 17 | 18 | \node (p) [left =of x, align=right, xshift=0.8cm, yshift=-0.04cm] {perturbation applied\\by optimiser}; 19 | \node (q) [below =of v, yshift=0.9cm, align=center] {$\underbrace{\hspace{14.5em}}$\\[0.5ex]perturbations induced by optimiser}; 20 | \node [below =of d, yshift=0.9cm, align=center] {weights}; 21 | \node [above =of c, yshift=-0.825cm, align=center] {model}; 22 | \node [above =of b, yshift=-0.825cm, align=center] {loss}; 23 | \node [above =of a, yshift=-0.89cm, align=center] {objective}; 24 | 25 | \draw[-latex] (b) edge (a); 26 | \draw[-latex] (c) edge (b); 27 | \draw[-latex] (d) edge (c); 28 | 29 | \draw[-latex] (a) edge (u); 30 | \draw[-latex] (b) edge (v); 31 | \draw[-latex] (c) edge (w); 32 | \draw[-latex] (x) 
edge (d); 33 | \end{tikzpicture}} 34 | \end{minipage} 35 | \caption{\captiontitle{Majorise-minimise and the perturbation hierarchy.} The \captiontitle{left panel} depicts the majorise-minimise meta-algorithm \citep{mm}, which is an algorithmic pattern for reducing an objective (blue) by minimising a sequence of upper bounds (one shown in red). The upper bounds, known as a \textit{majorisation}, must lie tangent to the objective to guarantee an improvement in one step of the meta-algorithm. The \captiontitle{right panel} depicts the perturbation hierarchy of a generic machine learning model: the optimiser perturbs the weights and this induces perturbations to the model output, the loss on individual training examples and ultimately the overall objective. Majorising machine learning objective functions requires addressing the full perturbation hierarchy.} 36 | \label{fig:maj-min} 37 | \end{figure} -------------------------------------------------------------------------------- /latex/figures/showcase.tex: -------------------------------------------------------------------------------- 1 | \begin{figure} 2 | \centering 3 | \includegraphics[width=\textwidth]{figures/pdf/plot0} 4 | \caption{\captiontitle{Automatic gradient descent trains neural networks reliably without hyperparameters.} Solid lines show train accuracy and dotted lines show test accuracy. The networks are unregularised with biases and affine parameters disabled, as these features are not yet supported by AGD. In the \captiontitle{left panel}---unlike AGD---Adam and SGD failed to train a 32-layer fully-connected network on CIFAR-10 with their default learning rates of 0.001 for Adam and 0.1 for SGD. The \captiontitle{middle panel} displays a learning rate grid search for ResNet-18 trained on CIFAR-10. AGD attained performance comparable to the best tuned performance of Adam and SGD. In the \captiontitle{right panel}, AGD trained ResNet-50 on ImageNet to a top-1 test accuracy of 65.5\%. 
The ImageNet baseline is SGD with a learning rate of 0.1 and no learning rate decay schedule.} 5 | \label{fig:showcase} 6 | \end{figure} -------------------------------------------------------------------------------- /latex/figures/theory-table.tex: -------------------------------------------------------------------------------- 1 | \begin{table} 2 | \centering 3 | \begin{tabularx}{\textwidth}{Xlcc} 4 | \toprule 5 | \textbf{Theory} & \textbf{Reference} & \makecell{\textbf{Handles the Loss}\\\begin{tikzpicture}[thick, block/.style={draw, minimum size=0.6cm}] 6 | \node [block,fill=red!30] (a) {$\el$}; 7 | \node [block,fill=green!30] (b) [right=0.5cm of a] {$\ell$}; 8 | \node [block,fill=orange!30] (c) [right=0.5cm of b] {$\vf$}; 9 | \draw[-latex] (b) edge (a); 10 | \draw[-latex] (c) edge (b); 11 | \end{tikzpicture}} & \makecell{\textbf{Non-Linear Network}\\ \begin{tikzpicture} 12 | [thick, block/.style={draw, minimum size=0.6cm}] 13 | \node [block,fill=orange!30] (c) {$\vf$}; 14 | \node [block,fill=blue!30] (d) [right= 0.5cm of c] {$\vw$}; 15 | \draw[-latex] (d) edge (c); 16 | \end{tikzpicture}} \\ 17 | \midrule 18 | mirror descent & $\mathrlap{\text{\citet{nemirovsky_yudin_1983}}}$\hspace{10em} & \cmark & \xmark \\ 19 | Gauss-Newton method &\citet{gauss-newton}& \cmark & \xmark \\ 20 | natural gradient descent &\citet{amari}& \cmark & \xmark \\ 21 | neural tangent kernel &\citet{NTKjacot}& \cmark & \xmark \\ 22 | deep relative trust &\citet{my-fromage}& \xmark & \cmark \\ 23 | tensor programs & \citet{Yang2021TensorPI}& \xmark & \cmark \\ 24 | automatic gradient descent & this paper &\cmark & \cmark \\ 25 | \bottomrule 26 | \end{tabularx} 27 | \caption{\captiontitle{Comparing popular frameworks for first-order optimisation theory.} Frameworks differ in whether they can handle the interaction between the model output $\vf$ and the objective $\el$, and the complex non-linear interaction between the weights $\vw$ and the model output $\vf$. Our framework handles both aspects.} 28 | \label{tab:theory} 29 | \end{table} -------------------------------------------------------------------------------- /latex/format.tex: -------------------------------------------------------------------------------- 1 | % References 2 | \renewcommand*{\backrefalt}[4]{\ifcase #1 No citations. \or Cited on page #2. \else Cited on pages #2. 
\fi} 3 | 4 | % Page numbers 5 | \fancypagestyle{empty}{\fancyhf{}} 6 | \fancypagestyle{plain}{\fancyfoot[C]{\sffamily\selectfont\thepage}} 7 | \pagestyle{plain} 8 | 9 | % Captions 10 | \renewcommand\AlCapFnt{\bf\sffamily} 11 | \captionsetup[figure]{labelfont={bf,sf}} 12 | \captionsetup[table]{labelfont={bf,sf}} 13 | \newcommand{\captiontitle}[1]{\textsf{\textbf{#1}}} 14 | 15 | % Paragraph 16 | \makeatletter 17 | \renewcommand\paragraph{\@startsection{paragraph}{4}{\z@}{0ex \@plus1ex \@minus.2ex}{-1em}{\bfseries\sffamily}} 18 | \makeatother 19 | 20 | % Theorem 21 | \makeatletter 22 | \xpatchcmd \thmt@restatable{\thmt@toks{}} 23 | {\def\thmt@tmp@restatename{#3}\thmt@toks{}}{}{\fail} 24 | 25 | \renewtheoremstyle{plain}{} 26 | {\item[\ifcsname thmt@tmp@restatename\endcsname 27 | \ifthmt@thisistheone\hskip\labelsep\mbox{\hyperref[proof:\thmt@tmp@restatename]{\bf\sffamily ##1\ ##2}} \fi 28 | \else \hskip\labelsep\bf\sffamily ##1\ ##2 \fi \normalfont\sffamily(##3)\theorem@separator]} 29 | \makeatother 30 | 31 | \theoremstyle{plain} 32 | \theorembodyfont{\normalfont} 33 | \newtheorem{approximation}{Approximation} 34 | 35 | \theoremstyle{plain} 36 | \theorembodyfont{\normalfont} 37 | \newtheorem{fact}{Fact} 38 | 39 | \theoremstyle{plain} 40 | \theorembodyfont{\normalfont} 41 | \newtheorem{prescription}{Prescription} 42 | \crefname{prescription}{Prescription}{Prescriptions} 43 | 44 | \theoremstyle{plain} 45 | \theorembodyfont{\normalfont} 46 | \newtheorem{proposition}{Proposition} 47 | 48 | \theoremstyle{plain} 49 | \theorembodyfont{\normalfont} 50 | \newtheorem{theorem}{Theorem} 51 | 52 | \theoremstyle{plain} 53 | \theorembodyfont{\normalfont} 54 | \newtheorem{lemma}{Lemma} 55 | 56 | \theoremstyle{plain} 57 | \theorembodyfont{\normalfont} 58 | \newtheorem{corollary}{Corollary} 59 | 60 | \theoremstyle{plain} 61 | \theorembodyfont{\normalfont} 62 | \newtheorem{ansatz}{Ansatz} 63 | 64 | \theoremstyle{plain} 65 | \theorembodyfont{\normalfont} 66 | \newtheorem{assumption}{Assumption} 67 | \crefname{assumption}{Assumption}{Assumptions} 68 | 69 | \theoremstyle{plain} 70 | \theorembodyfont{\normalfont} 71 | \newtheorem{definition}{Definition} 72 | 73 | \theoremstyle{plain} 74 | \theorembodyfont{\normalfont} 75 | \newtheorem{example}{Example} 76 | 77 | % Algorithm 78 | \newcommand{\WCOMMENT}[1]{\hfill\begin{minipage}{20em}\COMMENT{#1}\end{minipage}} 79 | \SetAlCapSkip{1em} 80 | \SetKwComment{Comment}{\small$\#$ }{} 81 | 82 | \algrenewcommand\algorithmicindent{2.0em} 83 | 84 | \renewcommand\algorithmicdo{} 85 | \renewcommand\algorithmicthen{} 86 | 87 | 88 | \algrenewcommand\alglinenumber[1]{\sffamily\footnotesize #1.} 89 | 90 | \algblockdefx[DEF]{DEF}{ENDDEF}% 91 | [1][function]{\textbf{def} \texttt{#1()}:} 92 | 93 | \makeatletter 94 | \ifthenelse{\equal{\ALG@noend}{t}}{\algtext*{ENDDEF}} 95 | \makeatother 96 | 97 | \makeatletter 98 | \xpatchcmd{\algorithmic}{\labelsep 0.5em}{\labelsep 1.0em}{\typeout{Success!}}{\typeout{Oh dear!}} 99 | \makeatother -------------------------------------------------------------------------------- /latex/macros.tex: -------------------------------------------------------------------------------- 1 | \newcommand{\cmark}{\textcolor{green!80!black}{\ding{51}}} 2 | \newcommand{\xmark}{\textcolor{red!80!black}{\ding{55}}} 3 | 4 | \DeclareMathOperator*{\argmax}{arg\,max} 5 | \DeclareMathOperator*{\argmin}{arg\,min} 6 | 7 | \DeclareMathOperator*{\bi}{{(i)}} 8 | \newcommand{\defeq}{\vcentcolon=} 9 | \newcommand{\el}{\mathcal{L}} 10 | \newcommand\mydots{\makebox[1em][c]{.\hfil.\hfil.}} 
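% Usage sketch for the macros in this file (illustrative comments only; not part of the build):
%   $\el$                                                -> the objective, rendered as script L
%   $G \defeq \frac{1}{L}\sum_{k=1}^L \fnorm{\nabla_{\mW_k} \el}$
%                                                        -> \defeq gives the definition symbol,
%                                                           \fnorm{...} a Frobenius norm, \mW a bold matrix
%   $\ip{\vx}{\vy}$                                      -> angle-bracket inner product of bold vectors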
11 | 12 | \newcommand{\softmax}{\operatorname{softmax}} 13 | \newcommand{\relu}{\operatorname{relu}} 14 | \newcommand{\bregman}{\operatorname{bregman}} 15 | \newcommand{\kl}{D_\mathrm{KL}} 16 | \newcommand{\vect}{\operatorname{vec}} 17 | 18 | %%% Annotation 19 | 20 | \newcommand{\kevin}[1]{\textcolor{red}{[Kevin: #1]}} 21 | \newcommand{\chris}[1]{\textcolor{blue}{[Chris: #1]}} 22 | 23 | %%% Macros 24 | 25 | % Constants 26 | 27 | \newcommand{\econst}{\mathrm{e}} 28 | \newcommand{\iunit}{\mathrm{i}} 29 | 30 | % Abbreviations 31 | 32 | \newcommand{\eps}{\varepsilon} 33 | \newcommand{\oldphi}{\phi} 34 | \renewcommand{\phi}{\varphi} 35 | 36 | % Typesetting 37 | 38 | \newcommand{\vct}[1]{\bm{#1}} 39 | \newcommand{\mtx}[1]{\bm{#1}} 40 | \newcommand{\set}[1]{\mathsf{#1}} 41 | \newcommand{\coll}[1]{\mathcal{#1}} 42 | 43 | \newcommand{\half}{\tfrac{1}{2}} 44 | 45 | % Sets 46 | \newcommand{\N}{\mathbb{N}} 47 | \newcommand{\Z}{\mathbb{Z}} 48 | \newcommand{\Q}{\mathbb{Q}} 49 | \newcommand{\R}{\mathbb{R}} 50 | \newcommand{\Sph}{\mathbb{S}} 51 | % \renewcommand{\C}{\mathbb{C}} 52 | \newcommand{\F}{\mathbb{F}} 53 | 54 | \newcommand{\Sym}{\mathbb{H}} 55 | 56 | \newcommand{\comp}{\textsf{c}} 57 | 58 | % Elementary functions 59 | 60 | \newcommand{\sign}{\operatorname{sign}} 61 | \newcommand{\erf}{\operatorname{erf}} 62 | \newcommand{\Normal}{\operatorname{Normal}} 63 | \newcommand{\Orthant}{\operatorname{Orthant}} 64 | 65 | % Asymptotics 66 | 67 | \newcommand{\bigO}{O} 68 | 69 | % Linear algebra 70 | 71 | \newcommand{\range}{\operatorname{range}} 72 | \newcommand{\nullsp}{\operatorname{null}} 73 | 74 | \newcommand{\lspan}{\operatorname{lin}} 75 | 76 | \newcommand{\rank}{\operatorname{rank}} 77 | \newcommand{\srank}{\operatorname{rank_{stable}}} 78 | \newcommand{\trace}{\operatorname{tr}} 79 | \newcommand{\diag}{\operatorname{diag}} 80 | \newcommand{\Id}{\mathbf{I}} 81 | 82 | \newcommand{\pinv}{\dagger} 83 | 84 | \newcommand{\psdle}{\preccurlyeq} 85 | \newcommand{\psdge}{\succcurlyeq} 86 | \newcommand{\psdlt}{\prec} 87 | \newcommand{\psdgt}{\succ} 88 | 89 | % Mensuration 90 | 91 | \newcommand{\abs}[1]{\vert {#1} \vert} 92 | \newcommand{\norm}[1]{\Vert {#1} \Vert} 93 | \newcommand{\ip}[2]{\langle {#1}, \ {#2} \rangle} 94 | \newcommand{\absip}[2]{\abs{\ip{#1}{#2}}} 95 | 96 | \newcommand{\abssq}[1]{\abs{#1}^2} 97 | \newcommand{\pnorm}[2]{\norm{#2}_{#1}} 98 | \newcommand{\normsq}[1]{\norm{#1}^2} 99 | \newcommand{\fnorm}[1]{\norm{#1}_{\mathrm{F}}} 100 | \newcommand{\specnorm}[1]{\Vert #1 \Vert_{*}} 101 | \newcommand{\fnormsq}[1]{\norm{#1}_{\mathrm{F}}^2} 102 | 103 | \newcommand{\labs}[1]{\left\vert {#1} \right\vert} 104 | \newcommand{\lnorm}[1]{\left\Vert {#1} \right\Vert} 105 | 106 | % Calculus 107 | 108 | \newcommand{\diff}{\mathrm{d}} 109 | \newcommand{\idiff}{\,\diff} 110 | \newcommand{\ddx}[1]{\frac{\diff}{\diff{#1}}} 111 | \newcommand{\dydx}[2]{\frac{\diff{#1}}{\diff{#2}}} 112 | \newcommand{\pypx}[2]{\frac{\partial{#1}}{\partial{#2}}} 113 | 114 | % Probability 115 | 116 | \newcommand{\Expect}{\operatorname{\mathbb{E}}} 117 | \newcommand{\Var}{\operatorname{Var}} 118 | \newcommand{\Cov}{\operatorname{Cov}} 119 | 120 | \newcommand{\Probe}{\mathbb{P}} 121 | \newcommand{\Prob}[1]{\Probe\left\{ #1 \right\}} 122 | \newcommand{\Probc}[2]{\Probe_{#1}\left\{ #2 \right\}} 123 | 124 | \newcommand{\condbar}{\, \vert \,} 125 | \newcommand{\lcondbar}{\, \big\vert \,} 126 | 127 | \newcommand{\normal}{\textsc{normal}} 128 | \newcommand{\uniform}{\textsc{uniform}} 129 | \newcommand{\orthogonal}{\mathrm{orthogonal}} 130 | 131 
| \newcommand{\comple}{\blacktriangleleft} 132 | 133 | 134 | % Convex analysis 135 | 136 | \newcommand{\conv}{\operatorname{conv}} 137 | 138 | % Misc 139 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 140 | % BOLD LETTERS FOR VECTORS AND MATRICES 141 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 142 | % Vectors 143 | \def\vzero{{\bm{0}}} 144 | \def\vone{{\bm{1}}} 145 | \def\vmu{{\bm{\mu}}} 146 | \def\vtheta{{\bm{\theta}}} 147 | \def\va{{\bm{a}}} 148 | \def\vb{{\bm{b}}} 149 | \def\vc{{\bm{c}}} 150 | \def\vd{{\bm{d}}} 151 | \def\ve{{\bm{e}}} 152 | \def\vf{{\bm{f}}} 153 | \def\vg{{\bm{g}}} 154 | \def\vh{{\bm{h}}} 155 | \def\vi{{\bm{i}}} 156 | \def\vj{{\bm{j}}} 157 | \def\vk{{\bm{k}}} 158 | \def\vl{{\bm{l}}} 159 | \def\vm{{\bm{m}}} 160 | \def\vn{{\bm{n}}} 161 | \def\vo{{\bm{o}}} 162 | \def\vp{{\bm{p}}} 163 | \def\vq{{\bm{q}}} 164 | \def\vr{{\bm{r}}} 165 | \def\vs{{\bm{s}}} 166 | \def\vt{{\bm{t}}} 167 | \def\vu{{\bm{u}}} 168 | \def\vv{{\bm{v}}} 169 | \def\vw{{\bm{w}}} 170 | \def\vx{{\bm{x}}} 171 | \def\vy{{\bm{y}}} 172 | \def\vz{{\bm{z}}} 173 | 174 | % Matrices 175 | \def\mA{{\bm{A}}} 176 | \def\mB{{\bm{B}}} 177 | \def\mC{{\bm{C}}} 178 | \def\mD{{\bm{D}}} 179 | \def\mE{{\bm{E}}} 180 | \def\mF{{\bm{F}}} 181 | \def\mG{{\bm{G}}} 182 | \def\mH{{\bm{H}}} 183 | \def\mI{{\bm{I}}} 184 | \def\mJ{{\bm{J}}} 185 | \def\mK{{\bm{K}}} 186 | \def\mL{{\bm{L}}} 187 | \def\mM{{\bm{M}}} 188 | \def\mN{{\bm{N}}} 189 | \def\mO{{\bm{O}}} 190 | \def\mP{{\bm{P}}} 191 | \def\mQ{{\bm{Q}}} 192 | \def\mR{{\bm{R}}} 193 | \def\mS{{\bm{S}}} 194 | \def\mT{{\bm{T}}} 195 | \def\mU{{\bm{U}}} 196 | \def\mV{{\bm{V}}} 197 | \def\mW{{\bm{W}}} 198 | \def\mX{{\bm{X}}} 199 | \def\mY{{\bm{Y}}} 200 | \def\mZ{{\bm{Z}}} 201 | \def\mDelta{{\bm{\Delta}}} 202 | \def\mGamma{{\bm{\Gamma}}} 203 | \def\mLambda{{\bm{\Lambda}}} 204 | \def\mPhi{{\bm{\Phi}}} 205 | \def\mSigma{{\bm{\Sigma}}} -------------------------------------------------------------------------------- /latex/main.tex: -------------------------------------------------------------------------------- 1 | \documentclass[10pt]{article} 2 | \usepackage[preprint]{tmlr/tmlr} 3 | 4 | \input{packages} 5 | \input{macros} 6 | \input{format} 7 | 8 | \title{Automatic Gradient Descent:\\Deep Learning without Hyperparameters} 9 | 10 | \newcommand{\authspace}{\hspace{3.19em}} 11 | \newcommand{\auth}[2]{\begin{tabular}{@{}l@{}}{#1}\\\normalfont{#2}\end{tabular}} 12 | \newcommand{\cred}[1]{{\color{red} #1}} 13 | 14 | \author{\sffamily\auth{Jeremy Bernstein$^\star$}{MIT}\authspace\auth{\hspace{-5pt}Chris Mingard$^\star$}{\hspace{-5pt}U.\ Oxford}\authspace\auth{Kevin Huang}{U.\ Washington}\authspace\auth{Navid Azizan}{MIT} \authspace\auth{Yisong Yue}{Caltech}} 15 | 16 | \begin{document} 17 | 18 | \maketitle 19 | \thispagestyle{empty} 20 | 21 | \vspace{-5ex} 22 | \hfill$\star$ denotes equal contribution.\\ 23 | 24 | \input{section/00-abstract} 25 | {\sffamily\textbf{Keywords:}} majorise-minimise meta-algorithm, operator perturbation theory, architecture-aware optimisation 26 | 27 | \sffamily 28 | \setstretch{0} 29 | \tableofcontents 30 | \setstretch{1} 31 | \normalfont 32 | 33 | \input{figures/showcase} 34 | 35 | \input{section/01-intro} 36 | \input{section/03-sketch} 37 | \input{section/04-bounds} 38 | \input{section/07-experiments} 39 | \input{section/08-discuss} 40 | \input{section/98-ack} 41 | 42 | \bibliography{refs} 43 | \bibliographystyle{tmlr/tmlr} 44 | 45 | \newpage 46 | \appendix 47 | \input{section/99-appendix} 48 | 49 | \end{document} 50 | 
-------------------------------------------------------------------------------- /latex/minted-output/default-pyg-prefix.pygstyle: -------------------------------------------------------------------------------- 1 | 2 | \makeatletter 3 | \def\PYG@reset{\let\PYG@it=\relax \let\PYG@bf=\relax% 4 | \let\PYG@ul=\relax \let\PYG@tc=\relax% 5 | \let\PYG@bc=\relax \let\PYG@ff=\relax} 6 | \def\PYG@tok#1{\csname PYG@tok@#1\endcsname} 7 | \def\PYG@toks#1+{\ifx\relax#1\empty\else% 8 | \PYG@tok{#1}\expandafter\PYG@toks\fi} 9 | \def\PYG@do#1{\PYG@bc{\PYG@tc{\PYG@ul{% 10 | \PYG@it{\PYG@bf{\PYG@ff{#1}}}}}}} 11 | \def\PYG#1#2{\PYG@reset\PYG@toks#1+\relax+\PYG@do{#2}} 12 | 13 | \@namedef{PYG@tok@w}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.73,0.73}{##1}}} 14 | \@namedef{PYG@tok@c}{\let\PYG@it=\textit\def\PYG@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 15 | \@namedef{PYG@tok@cp}{\def\PYG@tc##1{\textcolor[rgb]{0.74,0.48,0.00}{##1}}} 16 | \@namedef{PYG@tok@k}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 17 | \@namedef{PYG@tok@kp}{\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 18 | \@namedef{PYG@tok@kt}{\def\PYG@tc##1{\textcolor[rgb]{0.69,0.00,0.25}{##1}}} 19 | \@namedef{PYG@tok@o}{\def\PYG@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 20 | \@namedef{PYG@tok@ow}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.67,0.13,1.00}{##1}}} 21 | \@namedef{PYG@tok@nb}{\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 22 | \@namedef{PYG@tok@nf}{\def\PYG@tc##1{\textcolor[rgb]{0.00,0.00,1.00}{##1}}} 23 | \@namedef{PYG@tok@nc}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.00,1.00}{##1}}} 24 | \@namedef{PYG@tok@nn}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.00,1.00}{##1}}} 25 | \@namedef{PYG@tok@ne}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.82,0.25,0.23}{##1}}} 26 | \@namedef{PYG@tok@nv}{\def\PYG@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 27 | \@namedef{PYG@tok@no}{\def\PYG@tc##1{\textcolor[rgb]{0.53,0.00,0.00}{##1}}} 28 | \@namedef{PYG@tok@nl}{\def\PYG@tc##1{\textcolor[rgb]{0.63,0.63,0.00}{##1}}} 29 | \@namedef{PYG@tok@ni}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.60,0.60,0.60}{##1}}} 30 | \@namedef{PYG@tok@na}{\def\PYG@tc##1{\textcolor[rgb]{0.49,0.56,0.16}{##1}}} 31 | \@namedef{PYG@tok@nt}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 32 | \@namedef{PYG@tok@nd}{\def\PYG@tc##1{\textcolor[rgb]{0.67,0.13,1.00}{##1}}} 33 | \@namedef{PYG@tok@s}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 34 | \@namedef{PYG@tok@sd}{\let\PYG@it=\textit\def\PYG@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 35 | \@namedef{PYG@tok@si}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.73,0.40,0.53}{##1}}} 36 | \@namedef{PYG@tok@se}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.73,0.40,0.13}{##1}}} 37 | \@namedef{PYG@tok@sr}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.40,0.53}{##1}}} 38 | \@namedef{PYG@tok@ss}{\def\PYG@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 39 | \@namedef{PYG@tok@sx}{\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 40 | \@namedef{PYG@tok@m}{\def\PYG@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 41 | \@namedef{PYG@tok@gh}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.00,0.50}{##1}}} 42 | \@namedef{PYG@tok@gu}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.50,0.00,0.50}{##1}}} 43 | \@namedef{PYG@tok@gd}{\def\PYG@tc##1{\textcolor[rgb]{0.63,0.00,0.00}{##1}}} 44 | \@namedef{PYG@tok@gi}{\def\PYG@tc##1{\textcolor[rgb]{0.00,0.63,0.00}{##1}}} 45 | 
\@namedef{PYG@tok@gr}{\def\PYG@tc##1{\textcolor[rgb]{1.00,0.00,0.00}{##1}}} 46 | \@namedef{PYG@tok@ge}{\let\PYG@it=\textit} 47 | \@namedef{PYG@tok@gs}{\let\PYG@bf=\textbf} 48 | \@namedef{PYG@tok@gp}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.00,0.50}{##1}}} 49 | \@namedef{PYG@tok@go}{\def\PYG@tc##1{\textcolor[rgb]{0.53,0.53,0.53}{##1}}} 50 | \@namedef{PYG@tok@gt}{\def\PYG@tc##1{\textcolor[rgb]{0.00,0.27,0.87}{##1}}} 51 | \@namedef{PYG@tok@err}{\def\PYG@bc##1{{\setlength{\fboxsep}{\string -\fboxrule}\fcolorbox[rgb]{1.00,0.00,0.00}{1,1,1}{\strut ##1}}}} 52 | \@namedef{PYG@tok@kc}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 53 | \@namedef{PYG@tok@kd}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 54 | \@namedef{PYG@tok@kn}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 55 | \@namedef{PYG@tok@kr}{\let\PYG@bf=\textbf\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 56 | \@namedef{PYG@tok@bp}{\def\PYG@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 57 | \@namedef{PYG@tok@fm}{\def\PYG@tc##1{\textcolor[rgb]{0.00,0.00,1.00}{##1}}} 58 | \@namedef{PYG@tok@vc}{\def\PYG@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 59 | \@namedef{PYG@tok@vg}{\def\PYG@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 60 | \@namedef{PYG@tok@vi}{\def\PYG@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 61 | \@namedef{PYG@tok@vm}{\def\PYG@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 62 | \@namedef{PYG@tok@sa}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 63 | \@namedef{PYG@tok@sb}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 64 | \@namedef{PYG@tok@sc}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 65 | \@namedef{PYG@tok@dl}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 66 | \@namedef{PYG@tok@s2}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 67 | \@namedef{PYG@tok@sh}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 68 | \@namedef{PYG@tok@s1}{\def\PYG@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 69 | \@namedef{PYG@tok@mb}{\def\PYG@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 70 | \@namedef{PYG@tok@mf}{\def\PYG@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 71 | \@namedef{PYG@tok@mh}{\def\PYG@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 72 | \@namedef{PYG@tok@mi}{\def\PYG@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 73 | \@namedef{PYG@tok@il}{\def\PYG@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 74 | \@namedef{PYG@tok@mo}{\def\PYG@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 75 | \@namedef{PYG@tok@ch}{\let\PYG@it=\textit\def\PYG@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 76 | \@namedef{PYG@tok@cm}{\let\PYG@it=\textit\def\PYG@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 77 | \@namedef{PYG@tok@cpf}{\let\PYG@it=\textit\def\PYG@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 78 | \@namedef{PYG@tok@c1}{\let\PYG@it=\textit\def\PYG@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 79 | \@namedef{PYG@tok@cs}{\let\PYG@it=\textit\def\PYG@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 80 | 81 | \def\PYGZbs{\char`\\} 82 | \def\PYGZus{\char`\_} 83 | \def\PYGZob{\char`\{} 84 | \def\PYGZcb{\char`\}} 85 | \def\PYGZca{\char`\^} 86 | \def\PYGZam{\char`\&} 87 | \def\PYGZlt{\char`\<} 88 | \def\PYGZgt{\char`\>} 89 | \def\PYGZsh{\char`\#} 90 | \def\PYGZpc{\char`\%} 91 | \def\PYGZdl{\char`\$} 92 | \def\PYGZhy{\char`\-} 93 | \def\PYGZsq{\char`\'} 94 | \def\PYGZdq{\char`\"} 95 | \def\PYGZti{\char`\~} 96 | % for compatibility with earlier versions 97 | \def\PYGZat{@} 98 | \def\PYGZlb{[} 99 | \def\PYGZrb{]} 100 | \makeatother 101 | 
102 | -------------------------------------------------------------------------------- /latex/minted-output/default.pygstyle: -------------------------------------------------------------------------------- 1 | 2 | \makeatletter 3 | \def\PYGdefault@reset{\let\PYGdefault@it=\relax \let\PYGdefault@bf=\relax% 4 | \let\PYGdefault@ul=\relax \let\PYGdefault@tc=\relax% 5 | \let\PYGdefault@bc=\relax \let\PYGdefault@ff=\relax} 6 | \def\PYGdefault@tok#1{\csname PYGdefault@tok@#1\endcsname} 7 | \def\PYGdefault@toks#1+{\ifx\relax#1\empty\else% 8 | \PYGdefault@tok{#1}\expandafter\PYGdefault@toks\fi} 9 | \def\PYGdefault@do#1{\PYGdefault@bc{\PYGdefault@tc{\PYGdefault@ul{% 10 | \PYGdefault@it{\PYGdefault@bf{\PYGdefault@ff{#1}}}}}}} 11 | \def\PYGdefault#1#2{\PYGdefault@reset\PYGdefault@toks#1+\relax+\PYGdefault@do{#2}} 12 | 13 | \@namedef{PYGdefault@tok@w}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.73,0.73}{##1}}} 14 | \@namedef{PYGdefault@tok@c}{\let\PYGdefault@it=\textit\def\PYGdefault@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 15 | \@namedef{PYGdefault@tok@cp}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.74,0.48,0.00}{##1}}} 16 | \@namedef{PYGdefault@tok@k}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 17 | \@namedef{PYGdefault@tok@kp}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 18 | \@namedef{PYGdefault@tok@kt}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.69,0.00,0.25}{##1}}} 19 | \@namedef{PYGdefault@tok@o}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 20 | \@namedef{PYGdefault@tok@ow}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.67,0.13,1.00}{##1}}} 21 | \@namedef{PYGdefault@tok@nb}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 22 | \@namedef{PYGdefault@tok@nf}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.00,1.00}{##1}}} 23 | \@namedef{PYGdefault@tok@nc}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.00,1.00}{##1}}} 24 | \@namedef{PYGdefault@tok@nn}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.00,1.00}{##1}}} 25 | \@namedef{PYGdefault@tok@ne}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.82,0.25,0.23}{##1}}} 26 | \@namedef{PYGdefault@tok@nv}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 27 | \@namedef{PYGdefault@tok@no}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.53,0.00,0.00}{##1}}} 28 | \@namedef{PYGdefault@tok@nl}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.63,0.63,0.00}{##1}}} 29 | \@namedef{PYGdefault@tok@ni}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.60,0.60,0.60}{##1}}} 30 | \@namedef{PYGdefault@tok@na}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.49,0.56,0.16}{##1}}} 31 | \@namedef{PYGdefault@tok@nt}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 32 | \@namedef{PYGdefault@tok@nd}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.67,0.13,1.00}{##1}}} 33 | \@namedef{PYGdefault@tok@s}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 34 | \@namedef{PYGdefault@tok@sd}{\let\PYGdefault@it=\textit\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 35 | \@namedef{PYGdefault@tok@si}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.40,0.53}{##1}}} 36 | \@namedef{PYGdefault@tok@se}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.40,0.13}{##1}}} 37 | \@namedef{PYGdefault@tok@sr}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.40,0.53}{##1}}} 38 | 
\@namedef{PYGdefault@tok@ss}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 39 | \@namedef{PYGdefault@tok@sx}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 40 | \@namedef{PYGdefault@tok@m}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 41 | \@namedef{PYGdefault@tok@gh}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.00,0.50}{##1}}} 42 | \@namedef{PYGdefault@tok@gu}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.50,0.00,0.50}{##1}}} 43 | \@namedef{PYGdefault@tok@gd}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.63,0.00,0.00}{##1}}} 44 | \@namedef{PYGdefault@tok@gi}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.63,0.00}{##1}}} 45 | \@namedef{PYGdefault@tok@gr}{\def\PYGdefault@tc##1{\textcolor[rgb]{1.00,0.00,0.00}{##1}}} 46 | \@namedef{PYGdefault@tok@ge}{\let\PYGdefault@it=\textit} 47 | \@namedef{PYGdefault@tok@gs}{\let\PYGdefault@bf=\textbf} 48 | \@namedef{PYGdefault@tok@gp}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.00,0.50}{##1}}} 49 | \@namedef{PYGdefault@tok@go}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.53,0.53,0.53}{##1}}} 50 | \@namedef{PYGdefault@tok@gt}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.27,0.87}{##1}}} 51 | \@namedef{PYGdefault@tok@err}{\def\PYGdefault@bc##1{{\setlength{\fboxsep}{\string -\fboxrule}\fcolorbox[rgb]{1.00,0.00,0.00}{1,1,1}{\strut ##1}}}} 52 | \@namedef{PYGdefault@tok@kc}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 53 | \@namedef{PYGdefault@tok@kd}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 54 | \@namedef{PYGdefault@tok@kn}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 55 | \@namedef{PYGdefault@tok@kr}{\let\PYGdefault@bf=\textbf\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 56 | \@namedef{PYGdefault@tok@bp}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.50,0.00}{##1}}} 57 | \@namedef{PYGdefault@tok@fm}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.00,0.00,1.00}{##1}}} 58 | \@namedef{PYGdefault@tok@vc}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 59 | \@namedef{PYGdefault@tok@vg}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 60 | \@namedef{PYGdefault@tok@vi}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 61 | \@namedef{PYGdefault@tok@vm}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.10,0.09,0.49}{##1}}} 62 | \@namedef{PYGdefault@tok@sa}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 63 | \@namedef{PYGdefault@tok@sb}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 64 | \@namedef{PYGdefault@tok@sc}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 65 | \@namedef{PYGdefault@tok@dl}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 66 | \@namedef{PYGdefault@tok@s2}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 67 | \@namedef{PYGdefault@tok@sh}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 68 | \@namedef{PYGdefault@tok@s1}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.73,0.13,0.13}{##1}}} 69 | \@namedef{PYGdefault@tok@mb}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 70 | \@namedef{PYGdefault@tok@mf}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 71 | \@namedef{PYGdefault@tok@mh}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 72 | \@namedef{PYGdefault@tok@mi}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 73 | 
\@namedef{PYGdefault@tok@il}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 74 | \@namedef{PYGdefault@tok@mo}{\def\PYGdefault@tc##1{\textcolor[rgb]{0.40,0.40,0.40}{##1}}} 75 | \@namedef{PYGdefault@tok@ch}{\let\PYGdefault@it=\textit\def\PYGdefault@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 76 | \@namedef{PYGdefault@tok@cm}{\let\PYGdefault@it=\textit\def\PYGdefault@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 77 | \@namedef{PYGdefault@tok@cpf}{\let\PYGdefault@it=\textit\def\PYGdefault@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 78 | \@namedef{PYGdefault@tok@c1}{\let\PYGdefault@it=\textit\def\PYGdefault@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 79 | \@namedef{PYGdefault@tok@cs}{\let\PYGdefault@it=\textit\def\PYGdefault@tc##1{\textcolor[rgb]{0.25,0.50,0.50}{##1}}} 80 | 81 | \def\PYGdefaultZbs{\char`\\} 82 | \def\PYGdefaultZus{\char`\_} 83 | \def\PYGdefaultZob{\char`\{} 84 | \def\PYGdefaultZcb{\char`\}} 85 | \def\PYGdefaultZca{\char`\^} 86 | \def\PYGdefaultZam{\char`\&} 87 | \def\PYGdefaultZlt{\char`\<} 88 | \def\PYGdefaultZgt{\char`\>} 89 | \def\PYGdefaultZsh{\char`\#} 90 | \def\PYGdefaultZpc{\char`\%} 91 | \def\PYGdefaultZdl{\char`\$} 92 | \def\PYGdefaultZhy{\char`\-} 93 | \def\PYGdefaultZsq{\char`\'} 94 | \def\PYGdefaultZdq{\char`\"} 95 | \def\PYGdefaultZti{\char`\~} 96 | % for compatibility with earlier versions 97 | \def\PYGdefaultZat{@} 98 | \def\PYGdefaultZlb{[} 99 | \def\PYGdefaultZrb{]} 100 | \makeatother 101 | 102 | -------------------------------------------------------------------------------- /latex/minted-output/listing1.pygtex: -------------------------------------------------------------------------------- 1 | \begin{Verbatim}[commandchars=\\\{\}] 2 | \PYG{k+kn}{import} \PYG{n+nn}{math} 3 | \PYG{k+kn}{import} \PYG{n+nn}{torch} 4 | 5 | \PYG{k+kn}{from} \PYG{n+nn}{torch.nn.init} \PYG{k+kn}{import} \PYG{n}{orthogonal\PYGZus{}} 6 | 7 | \PYG{k}{def} \PYG{n+nf}{singular\PYGZus{}value}\PYG{p}{(}\PYG{n}{p}\PYG{p}{):} 8 | \PYG{n}{sv} \PYG{o}{=} \PYG{n}{math}\PYG{o}{.}\PYG{n}{sqrt}\PYG{p}{(}\PYG{n}{p}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{l+m+mi}{0}\PYG{p}{]} \PYG{o}{/} \PYG{n}{p}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{l+m+mi}{1}\PYG{p}{])} 9 | \PYG{k}{if} \PYG{n}{p}\PYG{o}{.}\PYG{n}{dim}\PYG{p}{()} \PYG{o}{==} \PYG{l+m+mi}{4}\PYG{p}{:} 10 | \PYG{n}{sv} \PYG{o}{/=} \PYG{n}{math}\PYG{o}{.}\PYG{n}{sqrt}\PYG{p}{(}\PYG{n}{p}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{l+m+mi}{2}\PYG{p}{]} \PYG{o}{*} \PYG{n}{p}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{l+m+mi}{3}\PYG{p}{])} 11 | \PYG{k}{return} \PYG{n}{sv} 12 | 13 | \PYG{k}{class} \PYG{n+nc}{AGD}\PYG{p}{:} 14 | \PYG{n+nd}{@torch}\PYG{o}{.}\PYG{n}{no\PYGZus{}grad}\PYG{p}{()} 15 | \PYG{k}{def} \PYG{n+nf+fm}{\PYGZus{}\PYGZus{}init\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{net}\PYG{p}{,} \PYG{n}{gain}\PYG{o}{=}\PYG{l+m+mf}{1.0}\PYG{p}{):} 16 | 17 | \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{net} \PYG{o}{=} \PYG{n}{net} 18 | \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{depth} \PYG{o}{=} \PYG{n+nb}{len}\PYG{p}{(}\PYG{n+nb}{list}\PYG{p}{(}\PYG{n}{net}\PYG{o}{.}\PYG{n}{parameters}\PYG{p}{()))} 19 | \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{gain} \PYG{o}{=} \PYG{n}{gain} 20 | 21 | \PYG{k}{for} \PYG{n}{p} \PYG{o+ow}{in} \PYG{n}{net}\PYG{o}{.}\PYG{n}{parameters}\PYG{p}{():} 22 | \PYG{k}{if} \PYG{n}{p}\PYG{o}{.}\PYG{n}{dim}\PYG{p}{()} \PYG{o}{==} \PYG{l+m+mi}{1}\PYG{p}{:} \PYG{k}{raise} \PYG{n+ne}{Exception}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}Biases are not supported.\PYGZdq{}}\PYG{p}{)} 23 | \PYG{k}{if} 
\PYG{n}{p}\PYG{o}{.}\PYG{n}{dim}\PYG{p}{()} \PYG{o}{==} \PYG{l+m+mi}{2}\PYG{p}{:} \PYG{n}{orthogonal\PYGZus{}}\PYG{p}{(}\PYG{n}{p}\PYG{p}{)} 24 | \PYG{k}{if} \PYG{n}{p}\PYG{o}{.}\PYG{n}{dim}\PYG{p}{()} \PYG{o}{==} \PYG{l+m+mi}{4}\PYG{p}{:} 25 | \PYG{k}{for} \PYG{n}{kx} \PYG{o+ow}{in} \PYG{n+nb}{range}\PYG{p}{(}\PYG{n}{p}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{l+m+mi}{2}\PYG{p}{]):} 26 | \PYG{k}{for} \PYG{n}{ky} \PYG{o+ow}{in} \PYG{n+nb}{range}\PYG{p}{(}\PYG{n}{p}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{l+m+mi}{3}\PYG{p}{]):} 27 | \PYG{n}{orthogonal\PYGZus{}}\PYG{p}{(}\PYG{n}{p}\PYG{p}{[:,:,}\PYG{n}{kx}\PYG{p}{,}\PYG{n}{ky}\PYG{p}{])} 28 | \PYG{n}{p} \PYG{o}{*=} \PYG{n}{singular\PYGZus{}value}\PYG{p}{(}\PYG{n}{p}\PYG{p}{)} 29 | 30 | \PYG{n+nd}{@torch}\PYG{o}{.}\PYG{n}{no\PYGZus{}grad}\PYG{p}{()} 31 | \PYG{k}{def} \PYG{n+nf}{step}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{):} 32 | 33 | \PYG{n}{G} \PYG{o}{=} \PYG{l+m+mi}{0} 34 | \PYG{k}{for} \PYG{n}{p} \PYG{o+ow}{in} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{net}\PYG{o}{.}\PYG{n}{parameters}\PYG{p}{():} 35 | \PYG{n}{G} \PYG{o}{+=} \PYG{n}{singular\PYGZus{}value}\PYG{p}{(}\PYG{n}{p}\PYG{p}{)} \PYG{o}{*} \PYG{n}{p}\PYG{o}{.}\PYG{n}{grad}\PYG{o}{.}\PYG{n}{norm}\PYG{p}{(}\PYG{n}{dim}\PYG{o}{=}\PYG{p}{(}\PYG{l+m+mi}{0}\PYG{p}{,}\PYG{l+m+mi}{1}\PYG{p}{))}\PYG{o}{.}\PYG{n}{sum}\PYG{p}{()} 36 | \PYG{n}{G} \PYG{o}{/=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{depth} 37 | 38 | \PYG{n}{log} \PYG{o}{=} \PYG{n}{math}\PYG{o}{.}\PYG{n}{log}\PYG{p}{(}\PYG{l+m+mf}{0.5} \PYG{o}{*} \PYG{p}{(}\PYG{l+m+mi}{1} \PYG{o}{+} \PYG{n}{math}\PYG{o}{.}\PYG{n}{sqrt}\PYG{p}{(}\PYG{l+m+mi}{1} \PYG{o}{+} \PYG{l+m+mi}{4}\PYG{o}{*}\PYG{n}{G}\PYG{p}{)))} 39 | 40 | \PYG{k}{for} \PYG{n}{p} \PYG{o+ow}{in} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{net}\PYG{o}{.}\PYG{n}{parameters}\PYG{p}{():} 41 | \PYG{n}{factor} \PYG{o}{=} \PYG{n}{singular\PYGZus{}value}\PYG{p}{(}\PYG{n}{p}\PYG{p}{)} \PYG{o}{/} \PYG{n}{p}\PYG{o}{.}\PYG{n}{grad}\PYG{o}{.}\PYG{n}{norm}\PYG{p}{(}\PYG{n}{dim}\PYG{o}{=}\PYG{p}{(}\PYG{l+m+mi}{0}\PYG{p}{,}\PYG{l+m+mi}{1}\PYG{p}{),} \PYG{n}{keepdim}\PYG{o}{=}\PYG{k+kc}{True}\PYG{p}{)} 42 | \PYG{n}{p} \PYG{o}{\PYGZhy{}=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{gain} \PYG{o}{*} \PYG{n}{log} \PYG{o}{/} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{depth} \PYG{o}{*} \PYG{n}{factor} \PYG{o}{*} \PYG{n}{p}\PYG{o}{.}\PYG{n}{grad} 43 | \end{Verbatim} 44 | -------------------------------------------------------------------------------- /latex/packages.tex: -------------------------------------------------------------------------------- 1 | % fonts 2 | \usepackage[utf8]{inputenc} 3 | \usepackage[T1]{fontenc} 4 | \usepackage[final]{microtype} 5 | \usepackage{fontawesome} 6 | \usepackage{dsfont} 7 | \usepackage{stmaryrd} 8 | \usepackage{verbatim} 9 | \usepackage{pifont} 10 | 11 | % links 12 | \usepackage[pagebackref=true, hidelinks]{hyperref} 13 | \usepackage{natbib} 14 | \usepackage{url} 15 | 16 | % figures 17 | \usepackage[final,pdftex]{graphicx} 18 | \usepackage{float} 19 | \usepackage{caption} 20 | \usepackage{tcolorbox} 21 | \usepackage{booktabs,array} 22 | \usepackage{tabularx} 23 | \usepackage{makecell} 24 | \usepackage{multirow} 25 | \usepackage{tikz} 26 | \usetikzlibrary{shapes.geometric,positioning} 27 | 28 | % algorithms 29 | \usepackage[frozencache,cachedir=minted-output]{minted} 30 | 31 | \usepackage[boxed]{algorithm2e} 32 | \usepackage[noend]{algcompatible} 33 | 34 | % math 35 | \usepackage[amsthm, thmmarks]{ntheorem} 36 | \usepackage{amsfonts,amssymb,amsmath} 37 | 
\usepackage[noabbrev,nameinlink,capitalise,compress]{cleveref} 38 | \usepackage{mathtools} 39 | \usepackage{bm} 40 | \usepackage{mathabx} 41 | \usepackage{thmtools,thm-restate} 42 | \usepackage{resizegather} 43 | \usepackage{nicefrac} 44 | 45 | % miscellaneous 46 | \usepackage{enumitem} 47 | \usepackage{setspace} 48 | \usepackage{regexpatch} -------------------------------------------------------------------------------- /latex/section/00-abstract.tex: -------------------------------------------------------------------------------- 1 | \begin{abstract} 2 | The architecture of a deep neural network is defined explicitly in terms of the number of layers, the width of each layer and the general network topology. Existing optimisation frameworks neglect this information in favour of implicit architectural information (e.g.~second-order methods) or architecture-agnostic distance functions (e.g.~mirror descent). Meanwhile, the most popular optimiser in practice---Adam---is based on heuristics. This paper builds a new framework for deriving optimisation algorithms that explicitly leverage neural architecture. The theory extends mirror descent to non-convex composite objective functions: the idea is to transform a Bregman divergence to account for the non-linear structure of neural architecture. Working through the details for deep fully-connected networks yields \textit{automatic gradient descent}: a first-order optimiser without any hyperparameters. Automatic gradient descent trains both fully-connected and convolutional networks out-of-the-box and at ImageNet scale. A PyTorch implementation is available at \url{https://github.com/jxbz/agd} and also in \cref{app:pytorch}. Overall, the paper supplies a rigorous theoretical foundation for a next generation of architecture-dependent optimisers that work automatically and without hyperparameters. 3 | \end{abstract} -------------------------------------------------------------------------------- /latex/section/01-intro.tex: -------------------------------------------------------------------------------- 1 | \section{Introduction} 2 | 3 | Automatic differentiation has contributed to the rapid pace of innovation in the field of deep learning. Software packages such as PyTorch \citep{pytorch} and Theano \citep{theano} have advanced a programming paradigm where the user (1) defines a neural network architecture by composing differentiable operators and (2) supplies training data. The package then automatically computes the gradient of the error on the training data via recursive application of the chain rule. At this point, the user must become involved again by (3) selecting one of numerous optimisation algorithms and (4) manually tuning its hyperparameters: in particular, the initial learning rate and the learning rate decay schedule \citep{Goodfellow-et-al-2016}. 4 | 5 | But manually tuning hyperparameters is irksome. An abundance of hyperparameters makes it difficult to rank the performance of different deep learning algorithms \citep{Lucic2017AreGC,crowded_valley} and difficult to reproduce results in the literature \citep{deeprlmatters}. Hyperparameters confound our efforts to build a scientific understanding of generalisation in deep learning \citep{jiang2019fantastic, my-margin}. And, when training neural networks at the largest scale, in pursuit of stronger forms of artificial intelligence, hyperparameter grid search can rack up millions of dollars in compute costs \citep{Sharir2020TheCO}. 6 | 7 | Are hyperparameters just a fact of life?
The thesis of this paper is that \textit{no: they are not}. Deep learning involves fitting a known function to known data via minimising a known objective. If we could characterise these components both individually and in how they interact, then---in principle---there should be no leftover degrees of freedom to be tuned \citep{tutorial}. Taking this idea and running with it leads to \textit{automatic gradient descent} (AGD): a neural network optimiser without any hyperparameters. AGD is complementary to automatic differentiation and could help to automate general machine learning workflows. 8 | 9 | Two existing tools are central to our derivation, and it is their novel combination that presents the main theoretical contribution of this paper. First, a classic tool from convex analysis known as the \textit{Bregman divergence} \citep{bregman1967relaxation,bregman} is used to characterise how the neural network interacts with the loss function. And second, a tool called \textit{deep relative trust} \citep{my-fromage} is used to characterise the highly non-linear interaction between the weights and the network output. With these tools in hand, we can apply the \textit{majorise-minimise meta-algorithm} \citep{mm} to derive an optimiser explicitly tailored to deep network objective functions. To summarise, the derivation of AGD follows three main steps: 10 | 11 | \begin{enumerate}[label=Step \arabic*:, leftmargin=*, font=\sffamily] 12 | \item \textsf{Functional expansion}. We use a \textit{Bregman divergence} to express the linearisation error of the objective function $\el(\vw)$ in terms of the functional perturbation $\Delta \vf$ to the network $\vf$. 13 | \item \textsf{Architectural perturbation bounds.} We use \textit{deep relative trust} to relate the size and structure of the weight perturbation $\Delta \vw$ to the size of the induced functional perturbation $\Delta \vf$. 14 | \item \textsf{Majorise-minimise.} We substitute deep relative trust into the Bregman divergence to obtain an explicitly architecture-dependent majorisation. Minimising with respect to $\Delta \vw$ yields an optimiser. 15 | \end{enumerate} 16 | 17 | \paragraph{Summary of contributions} This paper derives automatic gradient descent (AGD) by applying the majorise-minimise meta-algorithm to deep network objective functions. AGD trains all tested network architectures without hyperparameters, and scales to deep networks such as ResNet-50 and large datasets such as ImageNet. AGD trains out-of-the-box even when Adam and SGD fail to train with their default hyperparameters. 18 | 19 | \input{figures/comparison-table} 20 | 21 | \subsection{Related work} 22 | 23 | \paragraph{Optimisation theory} First-order optimisers leverage the first-order Taylor expansion of the objective function $\el(\vw)$---in particular, the gradient $\nabla_\vw\el(\vw)$. Theoretical treatments include mirror descent \citep{nemirovsky_yudin_1983}, 24 | natural gradient descent \citep{amari} and the Gauss-Newton method \citep{gauss-newton}. These methods have been explored in the context of deep learning \citep{revisiting-ngd,azizan2018stochastic,sun2022mirror}. First-order methods are amenable to deep learning since the gradient of the objective is available via recursive application of the chain rule---a.k.a.\ error back-propagation \citep{Rumelhart1986LearningRB}. 
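To make the first-order setting concrete, a generic first-order update can be sketched in a few lines of PyTorch. This is an illustrative sketch only (the function and its arguments are not part of this paper's codebase): loosely speaking, each of the methods above applies its own transform to the raw gradient before taking a step, with vanilla gradient descent corresponding to the identity transform.
\begin{verbatim}
import torch

def first_order_step(net, transform, lr):
    # generic first-order update: w <- w - lr * transform(grad L(w))
    with torch.no_grad():
        for p in net.parameters():
            p -= lr * transform(p.grad)

# vanilla gradient descent corresponds to the identity transform:
# first_order_step(net, lambda g: g, lr=0.1)
\end{verbatim}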
25 | 26 | Second-order optimisers leverage the second-order Taylor expansion of the objective function $\el(\vw)$---in particular, the gradient $\nabla_\vw\el(\vw)$ and Hessian $\nabla^2_\vw\el(\vw)$. Examples include Newton's method \citep{Nocedal1999NumericalO} and cubic-regularised Newton's method \citep{Nesterov2006CubicRO}. Naïvely, second-order methods are less amenable to deep learning since the cost of the relevant Hessian computations is prohibitive at high dimension. That being said, efforts have been made to circumvent this issue \citep{hessian-linear}. 27 | 28 | The majorise-minimise meta-algorithm \citep{mm} is an algorithmic pattern that can be used to derive optimisers. To apply the meta-algorithm, one must first derive an upper bound on the objective which matches the objective up to $k$th-order in its Taylor series for some integer $k$. This \textit{majorisation} can then be minimised as a proxy for reducing the original objective. \cref{fig:maj-min} illustrates the meta-algorithm for $k=1$. 29 | 30 | \paragraph{Deep learning theory} The \textit{Lipschitz smoothness assumption}---a global constraint on the eigenvalues of the Hessian---is often used to derive and analyse neural network optimisers \citep{Agarwal2016FindingAL}. But this assumption has been questioned \citep{Zhang2020Why} and evidence has even been found for the reverse relationship, where the Hessian spectrum is highly sensitive to the choice of optimiser \citep{cohen2021gradient}. 31 | 32 | These considerations motivate the development of theory that is more explicitly tailored to neural architecture. For instance, \citet{my-fromage} used an architectural perturbation bound termed \textit{deep relative trust} to characterise the neural network optimisation landscape as a function of network depth. Similarly, \citet{Yang2021TensorPI} sought to understand the role of width, leading to their \textit{maximal update parameterisation}. \cref{tab:practice,tab:theory} provide some points of comparison between automatic gradient descent and these and other frameworks. 33 | 34 | \input{figures/schematic} 35 | 36 | \subsection{Preliminaries} 37 | 38 | Given a vector $\vv$ in $\R^n$, we will need to measure its size in three different ways: 39 | 40 | \begin{definition}[Manhattan norm] The \textit{Manhattan norm} $\norm{\,\cdot\,}_1$ of a vector $\vv$ is defined by $\norm{\vv}_1 \defeq \sum_{i} \abs{\vv_i}$. 41 | \end{definition} 42 | 43 | \begin{definition}[Euclidean norm] The \textit{Euclidean norm} $\norm{\,\cdot\,}_2$ of a vector $\vv$ is defined by $\smash{\norm{\vv}_2 \defeq \sqrt{\sum_{i} \vv_i^2}}$. 44 | \end{definition} 45 | 46 | \begin{definition}[Infinity norm] The \textit{infinity norm} $\norm{\,\cdot\,}_\infty$ of a vector $\vv$ is defined by $\norm{\vv}_\infty \defeq \max_{i} \abs{\vv_i}$. 47 | \end{definition} 48 | 49 | The reader should be aware that every matrix $\mM$ in $\R^{m \times n}$ has a singular value decomposition: 50 | \begin{fact}[SVD] Every matrix $\mM$ in $\R^{m\times n}$ admits a \textit{singular value decomposition} (SVD) of the form $\mM = \sum_{i=1}^{\min(m,n)} \sigma_i(\mM) \cdot \vu_i \vv_i^\top$ where the \textit{left singular vectors} $\{\vu_i\}$ are orthonormal vectors in $\R^{m}$, the \textit{right singular vectors} $\{\vv_i\}$ are orthonormal vectors in $\R^{n}$ and the \textit{singular values} $\{\sigma_i(\mM)\}$ are non-negative scalars.
51 | \end{fact} 52 | 53 | The singular value decomposition allows us to measure the size of a matrix in two different ways: 54 | 55 | \begin{definition}[Frobenius norm] The \textit{Frobenius norm} $\norm{\,\cdot\,}_F$ of a matrix $\mM$ is given by $\norm{\mM}_F \defeq \sqrt{\sum_{i} \sigma_i(\mM)^2}$. 56 | \end{definition} 57 | \begin{definition}[Operator norm] The \textit{operator norm} $\norm{\,\cdot\,}_*$ of a matrix $\mM$ is given by $\norm{\mM}_* \defeq \max_i \sigma_i(\mM)$. 58 | \end{definition} 59 | While the operator norm $\norm{\mM}_*$ reports the largest singular value, the quantity $\norm{\mM}_F / \sqrt{\min(m,n)}$ reports the root mean square singular value. Finally, we will need to understand two aspects of matrix conditioning: 60 | \begin{definition}[Rank] The \textit{rank} of a matrix counts the number of non-zero singular values. 61 | \end{definition} 62 | \begin{definition}[Stable rank] 63 | The \textit{stable rank} of a matrix $\mM$ is defined by $\srank \mM \defeq \norm{\mM}_F^2 / \norm{\mM}_*^2$. 64 | \end{definition} 65 | The stable rank provides an approximation to the rank that ignores the presence of very small singular values. Let us consider the extremes. An orthogonal matrix $\mO\in\R^{m\times n}$ has both full rank and full stable rank: $\rank \mO = \srank \mO = \min(m,n)$. A rank-one matrix $\mP$ has unit stable rank and satisfies $\norm{\mP}_* = \norm{\mP}_F$. -------------------------------------------------------------------------------- /latex/section/03-sketch.tex: -------------------------------------------------------------------------------- 1 | \section{Majorise-Minimise for Generic Learning Problems} 2 | \label{sec:mm-ml} 3 | 4 | \input{figures/theory-table} 5 | 6 | This section develops a framework for applying the majorise-minimise meta-algorithm to generic optimisation problems in machine learning. In particular, the novel technique of \textit{functional expansion} is introduced. \cref{sec:mm-dnn} will apply this technique to deep neural networks. All proofs are supplied in \cref{app:proofs}. 7 | 8 | Given a machine learning model and a set of training data, our objective is to minimise the error of the model, averaged over the training data. Formally, we would like to minimise the following function: 9 | 10 | \begin{definition}[Composite objective] Consider a machine learning model $\vf$ that maps an input $\vx$ and a weight vector $\vw$ to output $\vf(\vx;\vw)$. Given data $\set{S}$ and a convex loss function $\ell$, the \textit{objective} $\el(\vw)$ is defined by: 11 | \begin{equation*} 12 | \el(\vw) \defeq \frac{1}{|\set{S}|}\sum_{(\vx,\vy) \in \set{S}} \ell(\vf(\vx;\vw), \vy). 13 | \end{equation*} 14 | \end{definition} 15 | We refer to this objective as \textit{composite} since the loss function $\ell$ is \textit{composed} with a machine learning model $\vf$. While the loss function itself is convex, the overall composite is often non-convex due to the non-linear machine learning model. Common convex loss functions include the square loss and the cross-entropy loss: 16 | 17 | \begin{example}[Square loss]\label{ex:sq-loss} The \textit{square loss} is defined by: $\ell(\vf(\vx; \vw), \vy) \defeq \frac{1}{2d_L} \norm{\vf(\vx; \vw) - \vy}_2^2$. 18 | \end{example} 19 | \begin{example}[Xent loss]\label{ex:xent-loss} The \textit{cross-entropy (xent) loss} is defined by: $\ell(\vf(\vx), \vy) \defeq - \log [\softmax(\vf(\vx))]^\top \vy$, where the softmax function is defined by $\softmax(\vf(\vx))\defeq \exp \vf(\vx) / \norm{\exp \vf(\vx)}_1$. 
20 | \end{example} 21 | 22 | \subsection{Decomposition of linearisation error} 23 | 24 | First-order optimisers leverage the linearisation of the objective at the current iterate. To design such methods, we must understand the realm of validity of this linearisation. To that end, we derive a very general decomposition of the linearisation error of a machine learning system. The result is stated in terms of a \textit{perturbation hierarchy}. In particular, perturbing the weight vector of a machine learning model $\vw \to \vw + \Delta \vw$ induces perturbations to the model output $\vf \to \vf + \Delta \vf$, to the loss on individual data samples $\ell \to \ell + \Delta \ell$ and, at last, to the overall objective function $\el \to \el + \Delta \el$. Formally, a weight perturbation $\Delta \vw$ induces: 25 | \begin{flalign*} 26 | &\Delta \vf(\vx) &&\coloneqq \vf(\vx;\vw+\Delta \vw) - \vf(\vx; \vw); \hspace{16em} \tag{functional perturbation}\\ 27 | &\Delta \ell(\vf(\vx), \vy) &&\coloneqq \ell(\vf(\vx)+\Delta \vf(\vx),\vy) - \ell(\vf(\vx),\vy); \tag{loss perturbation}\\ 28 | &\Delta \el(\vw) &&\coloneqq \textstyle\frac{1}{|\set{S}|}\sum_{(\vx,\vy) \in \set{S}} \Delta \ell(\vf(\vx), \vy) \tag{objective perturbation}. 29 | \end{flalign*} 30 | We have adopted a compact notation where the dependence of $\vf(\vx;\vw)$ on $\vw$ is at times suppressed. The perturbation hierarchies of a generic machine learning model and a deep neural network are visualised in \cref{fig:maj-min,fig:apbs}, respectively. The linearisation error of the objective perturbation $\Delta \el$ decomposes as: 31 | 32 | \begin{restatable}[Decomposition of linearisation error]{proposition}{decomposition}\label{thm:decomposition}For any differentiable loss $\ell$ and any differentiable machine learning model $\vf$, the linearisation error of the objective function $\el$ admits the following decomposition: 33 | \begin{align*} 34 | \quad\quad\quad\underbrace{\Delta \el(\vw) - \nabla_\vw\el(\vw)^\top \Delta \vw}_{\mathclap{\text{linearisation error of objective}}} \quad\quad&= &&\frac{1}{|\set{S}|}\sum_{(\vx,\vy)\in \set{S}} \nabla_{\vf(\vx)} \ell(\vf(\vx),\vy)^\top \underbrace{\left[\Delta \vf(\vx) - \nabla_\vw \vf(\vx) \Delta \vw \right]}_{\mathclap{\text{linearisation error of model}}} \\ &&+\,&\frac{1}{|\set{S}|}\sum_{(\vx,\vy)\in \set{S}}\underbrace{\Delta \ell(\vf(\vx), \vy) -\nabla_{\vf(\vx)}\ell(\vf(\vx),\vy)^\top \Delta \vf(\vx)}_{\text{linearisation error of loss}}.\quad\quad\quad \nonumber 35 | \end{align*} 36 | \end{restatable} 37 | In words: the linearisation error of the objective decomposes into two terms. The first depends on the linearisation error of the machine learning model, and the second on that of the loss. This decomposition relies on nothing but differentiability. For a convex loss, the second term may be interpreted as a Bregman divergence: 38 | 39 | \begin{definition}[Bregman divergence of loss]\label{def:bregman} For any convex loss $\ell$: 40 | \begin{flalign*} 41 | \qquad\qquad\qquad\qquad\bregman_{\ell(\cdot,\vy)}(\vf(\vx), \Delta \vf(\vx)) \defeq \Delta \ell(\vf(\vx), \vy) -\nabla_{\vf(\vx)}\ell(\vf(\vx),\vy)^\top \Delta \vf(\vx). && 42 | \end{flalign*} 43 | \end{definition} 44 | 45 | A Bregman divergence is just the linearisation error of a convex function.
Two important examples are: 46 | 47 | \input{figures/apbs} 48 | 49 | \begin{restatable}[Bregman divergence of square loss]{lemma}{squarebreg}\label{lem:sq-bregman} 50 | When $\ell$ is set to square loss, then: 51 | \begin{flalign*} 52 | \qquad\qquad\qquad\qquad\bregman_{\ell(\cdot,\vy)}(\vf(\vx), \Delta \vf(\vx)) = \tfrac{1}{2d_L} \norm{\Delta \vf(\vx)}_2^2.&& 53 | \end{flalign*} 54 | \end{restatable} 55 | 56 | \begin{restatable}[Bregman divergence of xent loss]{lemma}{xentbreg} \label{lem:xent-bregman} 57 | When $\ell$ is set to cross-entropy loss, and if $\vy^\top \bm{1} =1$, then: 58 | \begin{flalign*} 59 | \qquad\qquad\qquad\qquad\bregman_{\ell(\cdot,\vy)}(\vf(\vx), \Delta \vf(\vx)) &= \kl \Big(\softmax(\vf(\vx))\,\Big|\Big|\, \softmax(\vf(\vx)+\Delta \vf(\vx))\Big)&& \\ 60 | &\leq \half\norm{\Delta \vf(\vx)}_\infty^2 + \mathcal{O}(\Delta \vf^3).&& 61 | \end{flalign*} 62 | \end{restatable} 63 | 64 | Our methods may be applied to other convex losses by calculating or bounding their Bregman divergence. 65 | 66 | \subsection{Functional expansion and functional majorisation} 67 | 68 | Before continuing, we make one simplifying assumption. Observe that the first term on the right-hand side of \cref{thm:decomposition} is a high-dimensional inner product between two vectors. Since there is no clear reason why these two vectors should be aligned, let us assume that their inner product is zero: 69 | \begin{assumption}[Orthogonality of model linearisation error]\label{ass:orthog} 70 | In the same setting as \cref{thm:decomposition}: 71 | \begin{equation*} 72 | \frac{1}{|\set{S}|}\sum_{(\vx,\vy)\in \set{S}} \nabla_{\vf(\vx)} \ell(\vf(\vx),\vy)^\top \underbrace{\left[\Delta \vf(\vx) - \nabla_\vw \vf(\vx) \Delta \vw \right]}_{\mathclap{\text{linearisation error of model}}} = 0. 73 | \end{equation*} 74 | \end{assumption} 75 | 76 | While it is possible to work without this assumption \citep{bernstein-thesis}, we found that its inclusion simplifies the analysis and in practice did not lead to a discernible weakening of the resulting algorithm. In any case, this assumption is considerably milder than the common assumption in the literature \citep{revisiting-ngd,NEURIPS2019_0d1a9651} that the model linearisation error is itself zero: $\left[\Delta \vf(\vx) - \nabla_\vw \vf(\vx) \Delta \vw \right] = 0$. 77 | 78 | Armed with \cref{thm:decomposition} and \cref{ass:orthog}, we are ready to introduce functional expansion and majorisation: 79 | 80 | \begin{restatable}[Functional expansion]{theorem}{functmajor}\label{thm:functmajor}Consider a convex differentiable loss $\ell$ and a differentiable machine learning model $\vf$. Under \cref{ass:orthog}, the corresponding composite objective $\el$ admits the expansion: 81 | \begin{align*} 82 | \el(\vw + \Delta \vw) = \underbrace{\el(\vw) + \nabla_\vw\el(\vw)^\top \Delta \vw}_{\text{first-order Taylor series}} +\frac{1}{|\set{S}|}\sum_{(\vx,\vy)\in \set{S}}\bregman_{\ell(\cdot,\vy)}(\vf(\vx), \Delta \vf(\vx)). 83 | \end{align*} 84 | \end{restatable} 85 | So the perturbed objective $\el(\vw+\Delta \vw)$ may be written as the sum of its first-order Taylor expansion with a Bregman divergence in the model outputs averaged over the training set. 
86 | It is straightforward to specialise this result to different losses by substituting in their Bregman divergence: 87 | 88 | \begin{restatable}[Functional expansion of mean squared error]{corollary}{sqmajor}\label{lem:sq-major} Under \cref{ass:orthog}, for square loss: 89 | \begin{flalign*} 90 | \qquad\qquad\qquad\qquad\el(\vw + \Delta \vw) = \el(\vw) + \nabla_\vw\el(\vw)^\top \Delta \vw +\frac{1}{|\set{S}|}\sum_{(\vx,\vy)\in \set{S}}\tfrac{1}{2d_L} \norm{\Delta \vf(\vx)}_2^2.&& 91 | \end{flalign*} 92 | \end{restatable} 93 | 94 | \begin{restatable}[Functional majorisation for xent loss]{corollary}{xentmajor}\label{lem:xent-major} 95 | Under \cref{ass:orthog}, for cross-entropy loss, if $\vy^\top \bm{1} =1$: 96 | \begin{flalign*} 97 | \qquad\qquad\qquad\qquad\el(\vw + \Delta \vw) \leq \el(\vw) + \nabla_\vw\el(\vw)^\top \Delta \vw +\frac{1}{|\set{S}|}\sum_{(\vx,\vy)\in \set{S}}\half\norm{\Delta \vf(\vx)}_\infty^2 + \mathcal{O}(\Delta \vf^3).&& 98 | \end{flalign*} 99 | \end{restatable} 100 | 101 | When the functional perturbation is reasonably ``spread out'', we would expect $\norm{\Delta \vf(\vx)}_\infty^2 \approx \norm{\Delta \vf(\vx)}_2^2/d_L$. In this setting, the functional majorisation of cross-entropy loss agrees with the functional expansion of mean squared error to second order. While the paper derives automatic gradient descent for the square loss, this observation justifies its application to cross-entropy loss, as in the case of the ImageNet experiments. 102 | 103 | \subsection{Recovering existing frameworks} 104 | \label{sec:recover} 105 | 106 | We briefly observe that three existing optimisation frameworks may be recovered efficiently from \cref{thm:functmajor}: 107 | 108 | \paragraph{Mirror descent} For linear models $\vf(\vx;\mW) \defeq \mW \vx$, the Bregman divergence $\bregman_{\ell(\cdot,\vy)}(\vf(\vx), \Delta \vf(\vx))$ may be written $\bregman_{\ell(\cdot,\vy)}(\mW\vx, \Delta\mW\vx)$. This is a convex function of the weight perturbation $\Delta \mW$. Substituting into \cref{thm:functmajor} and minimising with respect to $\Delta \mW$ is the starting point for mirror descent. 109 | 110 | \paragraph{Gauss-Newton method} Substituting the linearised functional perturbation $\Delta \vf(\vx) \approx \nabla_\vw \vf(\vx) \Delta \vw$ into \cref{lem:sq-major} and minimising with respect to $\Delta \vw$ is the starting point for the Gauss-Newton method. 111 | 112 | \paragraph{Natural gradient descent} Substituting the linearised functional perturbation $\Delta \vf(\vx) \approx \nabla_\vw \vf(\vx) \Delta \vw$ into \cref{lem:xent-major} and minimising with respect to $\Delta \vw$ is the starting point for natural gradient descent. -------------------------------------------------------------------------------- /latex/section/04-bounds.tex: -------------------------------------------------------------------------------- 1 | \section{Majorise-Minimise for Deep Learning Problems} 2 | \label{sec:mm-dnn} 3 | 4 | In this section, we will focus our efforts on deriving an optimiser for deep fully-connected networks trained with square loss. The derivation for cross-entropy loss is analogous. Proofs are relegated to \cref{app:proofs}. 
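For concreteness, the architecture analysed in this section can be written down in a few lines of PyTorch. The following is a minimal sketch of the bias-free fully-connected network formalised in \cref{def:dln} below; the function name and width-list convention are illustrative, not part of the repository's API.
\begin{verbatim}
import torch.nn as nn

def fcn(widths):
    # widths = [d_0, d_1, ..., d_L]: L matrix multiplications
    # interspersed with relu; the final layer is purely linear
    layers = []
    for k in range(1, len(widths)):
        layers.append(nn.Linear(widths[k - 1], widths[k], bias=False))
        if k < len(widths) - 1:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)
\end{verbatim}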
5 | 6 | \begin{definition}[Fully-connected network]\label{def:dln} 7 | A \textit{fully-connected network (FCN)} $\vf$ of depth $L$ maps an input $\vx\in\R^{d_0}$ to an output $\vf(\vx;\vw) \in \R^{d_L}$ via $L$ matrix multiplications interspersed by non-linearity $\relu(z) \defeq \max(0,z)$: 8 | \begin{equation*} 9 | \vf(\vx; \vw) \coloneqq \mW_L\circ(\relu{}\circ \mW_{L - 1}) \circ(\relu{}\circ \mW_{L - 2}) \circ \dots \circ (\relu{} \circ \mW_1 \vx). 10 | \end{equation*} 11 | \end{definition} 12 | 13 | In this expression, $\vw$ denotes the tuple of matrices $\vw = (\mW_1,...,\mW_L)$ with $k$th matrix $\mW_k$ in $\R^{d_k\times d_{k-1}}$. In what follows, we will find the following dimensional scaling to be particularly convenient: 14 | \begin{prescription}[Dimensional scaling]\label{prescription:norm} For $\eta>0$, the data $(\vx,\vy)$, weights $\mW_k$ and updates $\Delta\mW_k$ should obey: 15 | \begin{align*} 16 | \norm{\vx}_2 &= \sqrt{d_0}; \tag{input scaling} \\ 17 | \norm{\mW_k}_* &= \sqrt{d_k/d_{k-1}} \hspace{1.519em}\qquad\text{for all }k=1,...,L; \tag{weight scaling} \\ 18 | \norm{\Delta \mW_k}_* &= \sqrt{d_k/d_{k-1}} \cdot \tfrac{\eta}{L} \qquad\text{for all }k=1,...,L; \tag{update scaling}\\ 19 | \norm{\vy}_2 &= \sqrt{d_L}. \tag{target scaling} 20 | \end{align*} 21 | \end{prescription} 22 | While results can be derived without adopting \cref{prescription:norm}, the scalings substantially simplify our formulae. One reason for this is that, under \cref{prescription:norm}, we have the telescoping property that $\prod_{k=1}^L \norm{\mW_k}_* = \sqrt{d_L/d_0}$. For a concrete example of how this helps, consider the following bound on the norm of the network outputs: 23 | 24 | \begin{restatable}[Output bound]{lemma}{outbound} 25 | \label{lem:outbound} The output norm of a fully-connected network $\vf$ obeys the following bound: 26 | \begin{align*} 27 | \norm{\vf(\vx;\vw)}_2 &\leq \left[\prod_{k=1}^L \norm{\mW_k}_* \right] \times \norm{\vx}_2 = \sqrt{d_L} \text{ under \cref{prescription:norm}}. 28 | \end{align*} 29 | \end{restatable} 30 | 31 | So, under \cref{prescription:norm}, the bound is simple. Furthermore, the scaling of the update with a single parameter $\eta$ reduces the problem of solving for an optimiser to a single parameter problem. To see how this might make life easier, consider the following lemma that relates weight perturbations to functional perturbations: 32 | 33 | \begin{restatable}[Deep relative trust]{lemma}{archbounds} 34 | \label{lem:deep_perturbation_bounds} 35 | When adjusting the weights $\vw = (\mW_1,...,\mW_L)$ of a fully-connected network $\vf$ by $\Delta\vw = (\Delta\mW_1,...,\Delta\mW_L)$, the induced functional perturbation $\Delta \vf(\vx)\defeq\vf(\vx;\vw+\Delta\vw)-\vf(\vx;\vw)$ obeys: 36 | \begin{align*} 37 | \norm{\Delta\vf(\vx)}_2 &\leq \left[\prod_{k=1}^L \norm{\mW_k}_* \right] \times \norm{\vx}_2 \times \left[ \prod_{k = 1}^L \left( 1 + \frac{\Vert \Delta \mW_k \Vert_{*}}{\Vert \mW_k \Vert_{*}}\right) - 1 \right] \leq \sqrt{d_L}\times(\exp \eta - 1) \text{ under \cref{prescription:norm}}. 38 | \end{align*} 39 | \end{restatable} 40 | So, under \cref{prescription:norm}, the single parameter $\eta$ directly controls the size of functional perturbations. 41 | 42 | In terms of enforcing \cref{prescription:norm} in practice, the norms of the data $(\vx,\vy)$ may be set via pre-processing, the norm of the update $\Delta \mW_k$ may be set via the optimisation algorithm and the norm of the weight matrix $\mW_k$ may be set by the choice of initialisation. 
While, yes, $\norm{\mW_k}_*$ may drift during training, the amount by which this can happen is limited by \citet{Weyl1912}'s inequality for singular values. In particular, after one step the perturbed operator norm $\norm{\mW_k + \Delta \mW_k}_*$ is sandwiched like $(1-\eta/L) \cdot \norm{\mW_k}_* \leq \norm{\mW_k + \Delta \mW_k}_* \leq (1+\eta/L) \cdot\norm{\mW_k}_*$. 43 | 44 | \input{algorithm/agd} 45 | 46 | \subsection{Deriving automatic gradient descent} 47 | 48 | With both functional majorisation and deep relative trust in hand, we can majorise the deep network objective: 49 | 50 | 51 | 52 | \begin{restatable}[Exponential majorisation]{lemma}{majordnn}\label{lem:sq-major-nn} 53 | For an FCN with square loss, under \cref{ass:orthog} and \cref{prescription:norm}: 54 | \begin{equation*} 55 | \el(\vw+\Delta \vw) \leq \el(\vw) + \frac{\eta}{L}\sum_{k=1}^L\left[\sqrt{d_k/d_{k-1}} \times\trace\frac{\Delta \mW_k^\top\nabla_{\mW_k}\el}{\norm{\Delta \mW_k}_*}\right] + \tfrac{1}{2} \,(\exp \eta -1)^2. 56 | \end{equation*} 57 | \end{restatable} 58 | 59 | Observe that the majorisation only depends on the magnitude of the scalar $\eta$ and on some notion of angle $\trace\Delta \mW_k^\top\nabla_{\mW_k}\el/\norm{\Delta \mW_k}_*$ between the perturbation matrix $\Delta \mW_k$ and the gradient matrix $\nabla_{\mW_k}\el$. To derive an optimiser, we would now like to minimise this majorisation with respect to $\eta$ and this angle. First, let us introduce one additional assumption and one additional definition: 60 | \begin{assumption}[Gradient conditioning]\label{approx:g-cond} The gradient satisfies $\srank\nabla_{\mW_k}\el=1$ at all layers $k=1,...,L$. 61 | \end{assumption} 62 | This assumption implies that the Frobenius norm $\norm{\nabla_{\mW_k}\el}_F$ and operator norm $\norm{\nabla_{\mW_k}\el}_*$ of the gradient at layer $k$ are equal. It is not immediately obvious why this should be a good assumption. After all, the gradient is a sum of $\abs{\set{S}}$ rank-one matrices: $\nabla_{\mW_k}\el = \tfrac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in\set{S}} \nabla_{\vh_k}\ell(\vf(\vx),\vy) \otimes \vh_{k-1}$, where $\vh_{k-1}(\vx)$ and $\vh_k(\vx)$ denote the inputs and outputs of the weight matrix $\mW_k$ at layer $k$, and $\otimes$ denotes the outer product. So, naïvely, one might expect the gradient $\nabla_{\mW_k}\el$ to have a stable rank of $\min(d_k,d_{k-1},\abs{\set{S}})$. But it turns out to be a good assumption in practice \citep{Yang2021TensorPI,yang2021tuning}. And for the definition: 63 | 64 | \begin{definition}[Gradient summary]\label{def:gsummary} 65 | At a weight setting $\vw$, the \textit{gradient summary} $G$ is given by: 66 | \begin{align*} 67 | G & \defeq \frac{1}{L}\sum_{k=1}^L \sqrt{d_k/d_{k-1}} \cdot \norm{ \nabla_{\mW_k} \el(\vw)}_F. 68 | \end{align*} 69 | \end{definition} 70 | The gradient summary is a weighted average of gradient norms over layers. It can be thought of as a way to measure the size of the gradient while accounting for the fact that the weight matrices at different layers may be on different scales. This is related to the concept of the \textit{gradient scale coefficient} of \citet{Philipp2017TheEG}.
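As a concrete illustration, the gradient summary can be computed in a few lines of PyTorch. The sketch below assumes a bias-free network whose parameters are exactly the weight matrices $\mW_k\in\R^{d_k\times d_{k-1}}$ and that gradients have already been populated by a backward pass; it is a minimal illustration rather than the full implementation of \cref{app:pytorch}, which also handles convolutional tensors.
\begin{verbatim}
import torch

def gradient_summary(net):
    # G = (1/L) * sum_k sqrt(d_k/d_{k-1}) * ||grad_{W_k} L||_F
    G, L = 0.0, 0
    for p in net.parameters():  # each p is a d_k x d_{k-1} matrix
        G += (p.shape[0] / p.shape[1]) ** 0.5 * p.grad.norm().item()
        L += 1
    return G / L
\end{verbatim}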
71 | 72 | We now have everything we need to derive automatic gradient descent via the majorise-minimise principle: 73 | 74 | \begin{restatable}[Automatic gradient descent]{theorem}{loglr}\label{thm:log-lr} 75 | For a deep fully-connected network, under \cref{ass:orthog,approx:g-cond} and \cref{prescription:norm}, the majorisation of square loss given in \cref{lem:sq-major-nn} is minimised by setting: 76 | \begin{align*} 77 | \eta = \log\frac{1 + \sqrt{1+4G}}{2},\qquad 78 | \Delta \mW_k = - \frac{\eta}{L}\cdot \sqrt{d_k/d_{k-1}} \cdot\frac{\nabla_{\mW_k} \el}{\norm{\nabla_{\mW_k} \el}_F}, \qquad \text{for all layers } k=1,...,L. 79 | \end{align*} 80 | \end{restatable} 81 | 82 | We present pseudocode for this theorem in \cref{alg:agd}, and a PyTorch implementation in \cref{app:pytorch}. Via a simple derivation based on clear algorithmic principles, automatic gradient descent unifies various heuristic and theoretical ideas that have appeared in the literature: 83 | \begin{itemize}[leftmargin=*] 84 | \item \textit{Relative updates.} The update is scaled relative to the norm of the weight matrix to which it is applied---assuming the weight matrices are scaled according to \cref{prescription:norm}. Such a scaling was proposed by \citet{You:EECS-2017-156} and further explored by \citet{carbonnelle2019layer} and \citet{my-fromage}. There is evidence that such relative synaptic updates may occur in neuroscience \citep{Loewenstein9481}. 85 | \item \textit{Depth scaling.} Scaling the perturbation strength like $1/L$ for networks of depth $L$ was proposed on theoretical grounds by \citet{my-fromage} based on analysis via deep relative trust. 86 | \item \textit{Width scaling.} The dimensional factors of $d_k$ and $d_{k-1}$ that appear closely relate to the maximal update parameterisation of \citet{Yang2021TensorPI} designed to ensure hyperparameter transfer across network width. 87 | \item \textit{Gradient clipping.} The logarithmic dependence of the update on the gradient summary may be seen as an automatic form of \textit{adaptive gradient clipping} \citep{pmlr-v139-brock21a}---a technique which clips the gradient once its magnitude surpasses a certain threshold set by a hyperparameter. 88 | \end{itemize} 89 | 90 | \subsection{Convergence analysis} 91 | 92 | This section presents theoretical convergence rates for automatic gradient descent. While the spirit of the analysis is standard in optimisation theory, the details may still prove interesting for their careful characterisation of the optimisation properties of deep networks. For instance, we propose a novel Polyak-Łojasiewicz inequality tailored to the operator structure of deep networks.
We begin with two observations: 93 | 94 | \begin{restatable}[Bounded objective]{lemma}{objectivebound}\label{lem:objectivebound} 95 | For square loss, the objective is bounded as follows: 96 | \begin{align*} 97 | \el(\vw) &\leq \frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{\norm{\vf(\vx;\vw)}_2^2 +\norm{\vy}_2^2}{2d_L} \leq 1 \text{ under \cref{prescription:norm}.} 98 | \end{align*} 99 | \end{restatable} 100 | 101 | \begin{restatable}[Bounded gradient]{lemma}{gradientbound}\label{lem:gradientbound} 102 | For square loss, the norm of the gradient at layer $k$ is bounded as follows: 103 | \begin{align*} 104 | \norm{\nabla_{\mW_k}\el}_F &\leq \frac{\prod_{l=1}^L\norm{\mW_l}_*}{\norm{\mW_k}_*} \cdot \sqrt{\frac{2\el(\vw)}{d_L}} \cdot \sqrt{\frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\norm{\vx}_2^2} \leq \sqrt{2\cdot\frac{d_{k-1}}{d_k}} \text{ under \cref{prescription:norm}.} 105 | \end{align*} 106 | \end{restatable} 107 | 108 | These results help us prove that automatic gradient descent converges to a point where the gradient vanishes: 109 | 110 | \begin{restatable}[Convergence rate to critical point]{lemma}{criticalrate}\label{lem:criticalrate} 111 | Consider a fully-connected network trained by automatic gradient descent (\cref{thm:log-lr}) and square loss for $T$ iterations. Let $G_t$ denote the gradient summary (\cref{def:gsummary}) at step $t\leq T$. Under \cref{ass:orthog,approx:g-cond} and \cref{prescription:norm}, AGD converges at the following rate:\vspace{-0.5em} 112 | \begin{equation*} 113 | \min_{t\in\{1,...,T\}} G_t^2 \leq \frac{11}{T}. 114 | \end{equation*} 115 | \end{restatable} 116 | 117 | This lemma can be converted into a convergence rate to a global minimum with one additional assumption: 118 | 119 | \begin{assumption}[Deep Polyak-Łojasiewicz inequality] \label{ass:pl} 120 | For some $\alpha>0$, the gradient norm is lower bounded by: 121 | \begin{align*} 122 | \norm{\nabla_{\mW_k}\el}_F &\geq \alpha \times \frac{\prod_{l=1}^L\norm{\mW_l}_*}{\norm{\mW_k}_*} \cdot \sqrt{\frac{2\el(\vw)}{d_L}} \cdot \sqrt{\frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\norm{\vx}_2^2} = \alpha \times \sqrt{2\cdot\el(\vw)\cdot\frac{d_{k-1}}{d_k}} \text{ under \cref{prescription:norm}.} 123 | \end{align*} 124 | \end{assumption} 125 | This lower bound mirrors the structure of the upper bound in \cref{lem:gradientbound}. The parameter $\alpha$ captures how much of the gradient is attenuated by small singular values in the weights and by deactivated $\relu$ units. While Polyak-Łojasiewicz inequalities are common in the literature \citep{LIU202285}, our assumption is novel in that it pays attention to the operator structure of the network. \cref{ass:pl} leads to the following theorem: 126 | 127 | \begin{restatable}[Convergence rate to global minima]{theorem}{globalrate}\label{thm:globalrate} 128 | For automatic gradient descent (\cref{thm:log-lr}) in the same setting as \cref{lem:criticalrate} but with the addition of \cref{ass:pl}, the mean squared error objective at step $T$ obeys: 129 | \begin{align*} 130 | \el(\vw_T) \leq \frac{1}{\alpha^2}\times\frac{6}{T}. 131 | \end{align*} 132 | \end{restatable} -------------------------------------------------------------------------------- /latex/section/07-experiments.tex: -------------------------------------------------------------------------------- 1 | \subsection{Experiments} 2 | 3 | \input{figures/exp1} 4 | 5 | The goal of our experiments was twofold. 
First, we wanted to test automatic gradient descent (AGD, \cref{alg:agd}) on a broad variety of network architectures and datasets to check that it actually works. In particular, we tested AGD on fully-connected networks (FCNs, \cref{def:dln}), and both VGG-style \citep{simonyan2015a} and ResNet-style \citep{He2015DeepRL} convolutional neural networks on the CIFAR-10, CIFAR-100 \citep{Krizhevsky09learningmultiple} and ImageNet \citep[ILSVRC2012]{deng2009imagenet} datasets with standard data augmentation. And second, to see what AGD may have to offer beyond the status quo, we wanted to compare AGD to tuned Adam and SGD baselines, as well as Adam and SGD run with their default hyperparameters. 6 | 7 | To get AGD working with convolutional layers, we adopted a per-submatrix normalisation scheme. Specifically, for a convolutional tensor with filters of size $\mathtt{k_x} \times \mathtt{k_y}$, we implemented the normalisation separately for each of the $\mathtt{k_x} \times \mathtt{k_y}$ submatrices of dimension $\mathtt{channels_{in}} \times \mathtt{channels_{out}}$. Since AGD does not yet support biases or affine parameters in batchnorm, we disabled these parameters in all architectures. To at least adhere to \cref{prescription:norm} at initialisation, AGD draws initial weight matrices uniformly at random from the set of semi-orthogonal matrices, re-scaled by a factor of $\sqrt{\mathtt{fan\_out}/\mathtt{fan\_in}}$. Adam and SGD baselines used the PyTorch default initialisation. A PyTorch implementation of AGD reflecting these details is given in \cref{app:pytorch}. All experiments used square loss except ImageNet, which used cross-entropy loss. Cross-entropy loss has been found to be superior to square loss for datasets with a large number of classes \citep{Demirkaya2020ExploringTR,HuiSquareCrossEntropy}. 8 | 9 | % We will quote the test accuracy of a model to be that of the epoch with the lowest training loss. 10 | 11 | Our experimental results are spread across five figures: 12 | \begin{itemize}[leftmargin=*] 13 | \item \cref{fig:showcase} presents some highlights of our results: First, AGD can train networks that Adam and SGD with default hyperparameters cannot. Second, for ResNet-18 on CIFAR-10, AGD attained performance comparable to the best-tuned performance of Adam and SGD. And third, AGD scales up to ImageNet. 14 | \item \cref{fig:1} displays the breadth of our experiments: from training a 16-layer fully-connected network on CIFAR-10 to training ResNet-50 on ImageNet. Adam's learning rate was tuned over the logarithmic grid $\{10^{-5},10^{-4},...,10^{-1}\}$, while on ImageNet we used SGD with a default learning rate of 0.1 and no manual decay. AGD and Adam performed almost equally well on the depth-16 width-512 fully-connected network: 52.7\% test accuracy for AGD compared to 53.5\% for Adam. 15 | For ResNet-18 on CIFAR-10, Adam attained 92.9\% test accuracy compared to AGD's 91.2\%. On this benchmark, a fully-tuned SGD with learning rate schedule, weight decay, cross-entropy loss and bias and affine parameters can attain 93.0\% test accuracy \citep{kuangliu}. For VGG-16 on CIFAR-100, AGD achieved 67.4\% test accuracy compared to Adam's 69.7\%. 16 | Finally, on ImageNet AGD achieved a top-1 test accuracy of 65.5\% after 350 epochs. 17 | \item \cref{fig:2} compares AGD to Adam and SGD for training an eight-layer fully-connected network of width 256. Adam and SGD's learning rates were tuned over the logarithmic grid $\{10^{-5},10^{-4},...,10^{-1}\}$.
Adam's optimal learning rate of $10^{-4}$ was three orders of magnitude smaller than SGD's optimal learning rate of $10^{-1}$. SGD did not attain as low an objective value as Adam or AGD. 18 | \item \cref{fig:3} shows that AGD can train FCNs with width ranging from 64 to 2048 and depth from 2 to 32, while \cref{fig:4} shows that AGD successfully trains a four-layer FCN at varying mini-batch size: from 32 to 4096. 19 | \end{itemize} 20 | 21 | \input{figures/exp2} 22 | \input{figures/exp3} 23 | \input{figures/exp4} 24 | -------------------------------------------------------------------------------- /latex/section/08-discuss.tex: -------------------------------------------------------------------------------- 1 | \section{Discussion} 2 | 3 | This paper has proposed a new framework for deriving optimisation algorithms for non-convex composite objective functions, which are particularly prevalent in the field of machine learning and the subfield of deep learning. What we have proposed is truly a \textit{framework}: it can be applied to a new loss function by writing down its Bregman divergence, or to a new machine learning model by writing down its architectural perturbation bound. The framework is properly placed in the context of existing frameworks such as the majorise-minimise meta-algorithm, mirror descent and natural gradient descent. 4 | 5 | Recent papers have proposed a paradigm of \textit{hyperparameter transfer} where a small network is tuned and the resulting hyperparameters are transferred to a larger network \citep{yang2021tuning, bernstein-thesis}. The methods and results in this paper suggest a stronger paradigm of \textit{hyperparameter elimination}: by detailed analysis of the structure and interactions between different components of a machine learning system, we may hope---if not to outright outlaw hyperparameters---at least to reduce their abundance and opacity. 6 | 7 | The main product of this research is automatic gradient descent (AGD), with pseudocode given in \cref{alg:agd} and PyTorch code given in \cref{app:pytorch}. We have found AGD to be genuinely useful, and believe that it may complement automatic differentiation in helping to automate general machine learning workflows. 8 | 9 | The analysis leading to automatic gradient descent is elementary: we leverage basic concepts in linear algebra such as matrix and vector norms, and use simple bounds such as the triangle inequality for vector--vector sums, and the operator norm bound for matrix--vector products. The analysis is non-asymptotic: it does not rely on taking dimensions to infinity. It is also deterministic: it does not involve random matrix theory. We believe that the accessibility of the analysis could make this paper a good starting point for future developments. 10 | 11 | \paragraph{Directions for future work} Here we list some promising avenues for theoretical and practical research. We are exploring some of these ideas in our development codebase: \url{https://github.com/C1510/agd_exp}. 12 | 13 | \begin{itemize}[leftmargin=*] 14 | \item \textit{Stochastic optimisation.} Automatic gradient descent is derived in the full-batch optimisation setting, but the algorithm is evaluated experimentally in the mini-batch setting. It would be interesting to try to extend our theoretical and practical methods to more faithfully address stochastic optimisation. 15 | \item \textit{More architectures.} Automatic gradient descent is derived for fully-connected networks and extended heuristically to convolutional networks.
We are curious to extend the methods to more varied architectures such as transformers \citep{NIPS2017_3f5ee243} and architectural components such as biases. Since most neural networks resemble fully-connected networks in the sense that they are all just deep compound operators, we expect much of the structure of automatic gradient descent as presented to carry through. 16 | \item \textit{Regularisation.} The present paper deals purely with the optimisation structure of deep neural networks, and little thought is given to either generalisation or regularisation. Future work could look at both theoretical and practical regularisation schemes for automatic gradient descent. It would be interesting to try to do this without introducing hyperparameters, although we suspect that when it comes to regularisation at least one hyperparameter may become necessary. 17 | \item \textit{Acceleration.} We have found in some preliminary experiments that slightly increasing the update size of automatic gradient descent with a gain hyperparameter, or introducing a momentum hyperparameter, can lead to faster convergence. We emphasise that no experiment in this paper used such hyperparameters. Still, these observations may provide a valuable starting point for improving AGD in future work. 18 | \item \textit{Operator perturbation theory.} Part of the inspiration for this paper was the idea of applying operator perturbation theory to deep learning. While perturbation theory is well-studied in the context of linear operators \citep{Weyl1912,Kato:1966:PTL,STEWART200653}, in deep learning we are concerned with non-linear compound operators. It may be interesting to try to further extend results in perturbation theory to deep neural networks. One could imagine cataloguing the perturbation structure of different neural network building blocks, and using a result similar to deep relative trust (\cref{lem:deep_perturbation_bounds}) to describe how they compound. 19 | \end{itemize} -------------------------------------------------------------------------------- /latex/section/98-ack.tex: -------------------------------------------------------------------------------- 1 | \subsubsection*{Acknowledgments} 2 | 3 | The authors are grateful to MIT SuperCloud, Oxford Hydra, NVIDIA and Virgile Richard for providing GPUs. Thanks are due to Greg Yang and Jamie Simon for helpful discussions. A paper with Greg and Jamie is in preparation to explain the relationship between muP \citep{Yang2021TensorPI} and the operator norm. -------------------------------------------------------------------------------- /latex/section/99-appendix.tex: -------------------------------------------------------------------------------- 1 | \section{Proofs} 2 | \label{app:proofs} 3 | 4 | Here are the proofs for the theoretical results in the main text. 5 | 6 | \decomposition* 7 | \begin{proof}[\mbox{\hyperref[thm:decomposition]{Proof}}]\label{proof:decomposition} 8 | By the chain rule, $\nabla_\vw\el(\vw)^\top \Delta \vw = \frac{1}{|\set{S}|}\sum_{(\vx,\vy)\in \set{S}} \nabla_{\vf(\vx)} \ell(\vf(\vx),\vy)^\top \nabla_\vw \vf(\vx) \Delta \vw$. Therefore: 9 | \begin{equation*} 10 | \Delta \el(\vw) - \nabla_\vw\el(\vw)^\top \Delta \vw = \frac{1}{|\set{S}|}\sum_{(\vx,\vy)\in \set{S}}\Delta \ell(\vf(\vx), \vy) - \nabla_{\vf(\vx)} \ell(\vf(\vx),\vy)^\top \nabla_\vw \vf(\vx) \Delta \vw.
11 | \end{equation*} 12 | Adding and subtracting $\frac{1}{|\set{S}|}\sum_{(\vx,\vy)\in \set{S}}\nabla_{\vf(\vx)}\ell(\vf(\vx),\vy)^\top \Delta \vf(\vx)$ on the right-hand side yields the result. 13 | \end{proof} 14 | 15 | \squarebreg* 16 | \begin{proof}[\mbox{\hyperref[lem:sq-bregman]{Proof}}]\label{proof:squarebreg} 17 | Expanding the Euclidean norms in the loss perturbation $\Delta \ell$ yields: 18 | \begin{align*} 19 | \Delta \ell(\vf(\vx), \vy) & = \tfrac{1}{2d_L} \norm{\vf(\vx) + \Delta \vf(\vx) - \vy}_2^2 - \tfrac{1}{2d_L} \norm{\vf(\vx) - \vy}_2^2 \\ 20 | &= \tfrac{1}{2d_L} \norm{\Delta \vf(\vx)}_2^2 + \tfrac{1}{d_L}(\vf(\vx) - \vy)^\top \Delta \vf(\vx). 21 | \end{align*} 22 | The result follows by identifying that $\nabla_{\vf(\vx)}\ell(\vf(\vx),\vy)^\top \Delta \vf(\vx) = \tfrac{1}{d_L}(\vf(\vx) - \vy)^\top \Delta \vf(\vx)$. 23 | \end{proof} 24 | 25 | \xentbreg* 26 | \begin{proof}[\mbox{\hyperref[lem:xent-bregman]{Proof}}]\label{proof:xentbreg} 27 | First, since $\sum_i \vy_i =1$, cross-entropy loss may be re-written: 28 | \begin{align*} 29 | \ell(\vf(\vx), \vy) \defeq - \log [\softmax(\vf(\vx))]^\top \vy = - \vf(\vx)^\top \vy + \log \norm{\exp \vf(\vx)}_1. 30 | \end{align*} 31 | The linear term $- \vf(\vx)^\top \vy$ does not contribute to the linearisation error and may be neglected. Therefore: 32 | \begin{align*} 33 | &\Delta \ell(\vf(\vx), \vy) -\nabla_{\vf(\vx)}\ell(\vf(\vx),\vy)^\top \Delta \vf(\vx) \\ 34 | &\quad\quad= \log \norm{\exp (\vf(\vx)+\Delta \vf(\vx))}_1 - \log \norm{\exp \vf(\vx)}_1 - \nabla_{\vf(\vx)}\log \norm{\exp \vf(\vx)}_1^\top \Delta \vf(\vx) \\ 35 | &\quad\quad= \log \frac{1/\norm{\exp \vf(\vx)}_1}{1/\norm{\exp (\vf(\vx)+\Delta \vf(\vx))}_1} - \frac{\exp\vf(\vx)^\top}{\norm{\exp \vf(\vx)}_1} \Delta \vf(\vx)\\ 36 | &\quad\quad=\frac{\exp\vf(\vx)^\top}{\norm{\exp \vf(\vx)}_1} \log \frac{\exp \vf(\vx)/\norm{\exp \vf(\vx)}_1}{\exp (\vf(\vx)+\Delta \vf(\vx))/\norm{\exp (\vf(\vx)+\Delta \vf(\vx))}_1}. 37 | \end{align*} 38 | The final line is equivalent to $\kl \Big(\softmax(\vf(\vx))\,\Big|\Big|\, \softmax(\vf(\vx)+\Delta \vf(\vx))\Big)$ establishing the first equality. 39 | 40 | To establish the inequality, let $\otimes$ denote the outer product and define $p \defeq \softmax(\vf(\vx))$. Then we have: 41 | \begin{align*} 42 | \Delta \ell(\vf(\vx), \vy) -\nabla_{\vf(\vx)}\ell(\vf(\vx),\vy)^\top \Delta \vf(\vx) &= \frac{1}{2}\Delta \vf(\vx)^\top \nabla^2_{\vf(\vx)}\ell(\vf(\vx), \vy) \Delta \vf(\vx) + \mathcal{O}(\Delta \vf^3) \\ 43 | &= \frac{1}{2}\Delta \vf(\vx)^\top \nabla^2_{\vf(\vx)}\log \norm{\exp \vf(\vx)}_1 \Delta \vf(\vx) + \mathcal{O}(\Delta \vf^3)\\ 44 | &= \frac{1}{2}\Delta \vf(\vx)^\top [\diag (p) - p \otimes p] \Delta \vf(\vx) + \mathcal{O}(\Delta \vf^3)\\ 45 | &\leq \frac{1}{2}\Delta \vf(\vx)^\top \diag (p) \Delta \vf(\vx) + \mathcal{O}(\Delta \vf^3)\\ 46 | &\leq \frac{1}{2}\norm{\Delta \vf(\vx)}_\infty^2 + \mathcal{O}(\Delta \vf^3), 47 | \end{align*} 48 | where we have used that $p\otimes p$ is positive semi-definite and then applied H\"older's inequality with $\norm{p}_1 = 1$. 49 | \end{proof} 50 | 51 | \functmajor* 52 | \begin{proof}[\mbox{\hyperref[thm:functmajor]{Proof}}]\label{proof:functmajor} 53 | The result follows by substituting \cref{ass:orthog} into \cref{thm:decomposition} and applying \cref{def:bregman}. 54 | \end{proof} 55 | 56 | \sqmajor* 57 | \begin{proof}[\mbox{\hyperref[lem:sq-major]{Proof}}]\label{proof:sqmajor} Combine \cref{lem:sq-bregman} with \cref{thm:functmajor} to obtain the result.
58 | \end{proof} 59 | 60 | \xentmajor* 61 | \begin{proof}[\mbox{\hyperref[lem:xent-major]{Proof}}]\label{proof:xentmajor} Combine \cref{lem:xent-bregman} with \cref{thm:functmajor} to obtain the result. 62 | \end{proof} 63 | 64 | \outbound* 65 | \begin{proof}[\mbox{\hyperref[lem:outbound]{Proof}}]\label{proof:outbound} 66 | For any vector $\vv$ and matrix $\mM$ with compatible dimensions, we have that $\norm{\mM \vv}_2 \leq \norm{\mM}_* \cdot \norm{\vv}_2$ and $\norm{\relu \vv}_2 \leq \norm{\vv}_2$. The lemma follows by applying these results recursively over the depth of the network. 67 | \end{proof} 68 | 69 | \archbounds* 70 | \begin{proof}[\mbox{\hyperref[lem:deep_perturbation_bounds]{Proof}}]\label{proof:archbounds} We proceed by induction. First, consider a network with $L=1$ layers: $\vf(\vx) = \mW_1 \vx$. Observe that $\norm{\Delta \vf(\vx)}_2 = \norm{\Delta \mW_1 \vx}_2 \leq \norm{\Delta \mW_1}_*\cdot \norm{\vx}_2$ as required. Next, assume that the result holds for a network $\vg(\vx)$ with $L-1$ layers and consider adding a layer to obtain $\vf(\vx) = \mW_L\circ \relu{}\circ \vg(\vx)$. Then: 71 | \begin{align*} 72 | \norm{\Delta \vf(\vx)}_2 &= \norm{(\mW_L+\Delta \mW_L)\circ \relu{} \circ (\vg(\vx)+\Delta \vg(\vx)) - \mW_L \circ \relu{} \circ \vg(\vx)}_2 \\ 73 | &= \norm{\mW_L \left(\relu{} \circ (\vg(\vx)+\Delta 74 | \vg(\vx)) - \relu{} \circ \vg(\vx)\right) + \Delta \mW_L \left( \relu{} \circ (\vg(\vx)+\Delta \vg(\vx)) - \relu(0)\right)}_2 \\ 75 | &\leq \norm{\mW_L}_*\cdot\norm{\Delta \vg(\vx)}_2 + \norm{\Delta \mW_L}_*\cdot(\norm{\vg(\vx)}_2 + \norm{\Delta \vg(\vx)}_2)\\ 76 | &= (\norm{\mW_L}_*+\norm{\Delta \mW_L}_*)\cdot \norm{\Delta \vg(\vx)}_2 + \norm{\Delta \mW_L}_*\cdot \norm{\vg(\vx)}_2, 77 | \end{align*} 78 | where the inequality follows by applying the triangle inequality, the operator norm bound, the fact that $\relu{}$ is one-Lipschitz, and a further application of the triangle inequality. But by the inductive hypothesis and \cref{lem:outbound}, the right-hand side is bounded by: 79 | \begin{align*} 80 | (\norm{\mW_L}_*&+\norm{\Delta \mW_L}_*) \left[ \prod_{k = 1}^{L-1} \left( 1 + \frac{\Vert \Delta \mW_k \Vert_{*}}{\Vert \mW_k \Vert_{*}}\right) - 1 \right] \times \left[\prod_{k=1}^{L-1} \norm{\mW_k}_* \right] \times \norm{\vx}_2 + \norm{\Delta \mW_L}_* \times \left[\prod_{k=1}^{L-1} \norm{\mW_k}_* \right] \times \norm{\vx}_2\\ 81 | &= \left[ \prod_{k = 1}^L \left( 1 + \frac{\Vert \Delta \mW_k \Vert_{*}}{\Vert \mW_k \Vert_{*}}\right) - 1 \right] \times \left[\prod_{k=1}^L \norm{\mW_k}_* \right] \times \norm{\vx}_2. 82 | \end{align*} 83 | The induction is complete. To further bound this result under \cref{prescription:norm}, observe that the product $\left[\prod_{k=1}^L \norm{\mW_k}_* \right] \times \norm{\vx}_2$ telescopes to just $\sqrt{d_L}$, while the other product satisfies: 84 | \begin{equation*} 85 | \left[ \prod_{k = 1}^L \left( 1 + \frac{\Vert \Delta \mW_k \Vert_{*}}{\Vert \mW_k \Vert_{*}}\right) - 1 \right] = \left(1+\frac{\eta}{L}\right)^L -1 \leq \lim_{L\to\infty}\left(1+\frac{\eta}{L}\right)^L-1 = \exp\eta - 1. 86 | \end{equation*} 87 | Combining these observations yields the result. 88 | \end{proof} 89 | 90 | \majordnn* 91 | \begin{proof}[\mbox{\hyperref[lem:sq-major-nn]{Proof}}]\label{proof:majordnn} 92 | Substitute \cref{lem:deep_perturbation_bounds} into \cref{lem:sq-major} and decompose $\nabla_\vw\el(\vw)^\top \Delta \vw = \sum_{k=1}^L \trace (\Delta \mW_k^\top \nabla_{\mW_k}\el)$. 
The result follows by realising that under \cref{prescription:norm}, the perturbations satisfy $\norm{\Delta \mW_k}_* = \sqrt{d_k/d_{k-1}} \cdot \frac{\eta}{L}$. 93 | \end{proof} 94 | 95 | \loglr* 96 | \begin{proof}[\mbox{\hyperref[thm:log-lr]{Proof}}]\label{proof:loglr} The inner product $\trace\frac{\Delta \mW_k^\top\nabla_{\mW_k}\el}{\norm{\Delta \mW_k}_*}$ that appears in \cref{lem:sq-major-nn} is most negative when the perturbation $\Delta \mW_k$ satisfies $\Delta \mW_k/\norm{\Delta \mW_k}_* = - \nabla_{\mW_k}\el / \norm{\nabla_{\mW_k}\el}_*$. Substituting this result back into \cref{lem:sq-major-nn} yields: 97 | \begin{equation*} 98 | \el(\vw+\Delta \vw) \leq \el(\vw) - \frac{\eta}{L}\sum_{k=1}^L\left[\sqrt{d_k/d_{k-1}} \times\frac{\norm{\nabla_{\mW_k}\el}_F^2}{\norm{\nabla_{\mW_k}\el}_*}\right] + \tfrac{1}{2} \,(\exp \eta -1)^2. 99 | \end{equation*} 100 | Under \cref{approx:g-cond}, we have that $\norm{\nabla_{\mW_k}\el}_F^2/\norm{\nabla_{\mW_k}\el}_* = \norm{\nabla_{\mW_k}\el}_F$ and so this inequality simplifies to: 101 | \begin{equation*} 102 | \el(\vw+\Delta \vw) \leq \el(\vw) - \eta\cdot G + \tfrac{1}{2} \,(\exp \eta -1)^2. 103 | \end{equation*} 104 | Taking the derivative of the right-hand side with respect to $\eta$ and setting it to zero yields $(\exp\eta-1)\exp\eta = G$. Applying the quadratic formula and retaining the positive solution yields $\exp \eta = \half(1+\sqrt{1+4G})$. Combining this with the relation that $\Delta \mW_k/\norm{\Delta \mW_k}_* = - \nabla_{\mW_k}\el / \norm{\nabla_{\mW_k}\el}_*$ and applying that $\norm{\Delta \mW_k}_* = \sqrt{d_k/d_{k-1}} \cdot \frac{\eta}{L}$ by \cref{prescription:norm} yields the result. 105 | \end{proof} 106 | 107 | \objectivebound* 108 | \begin{proof}[\mbox{\hyperref[lem:objectivebound]{Proof}}]\label{proof:objectivebound} 109 | The result follows by the following chain of inequalities: 110 | \begin{align*} 111 | \el(\vw) \defeq \frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{1}{2d_L}\norm{\vf(\vx;\vw) - \vy}_2^2 \leq \frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{1}{d_L}(\norm{\vf(\vx;\vw)}_2^2 +\norm{\vy}_2^2) \leq \frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{d_L+d_L}{d_L} = 2, 112 | \end{align*} 113 | where the first inequality applies the triangle inequality followed by $(a+b)^2 \leq 2a^2 + 2b^2$, and the second inequality holds under \cref{prescription:norm}. 114 | \end{proof} 115 | 116 | \gradientbound* 117 | \begin{proof}[\mbox{\hyperref[lem:gradientbound]{Proof}}]\label{proof:gradientbound} 118 | By the chain rule, the gradient of the mean square error objective may be written: 119 | \begin{align*} 120 | \nabla_{\mW_k} \el(\vw) = \frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{1}{d_L}(\vf(\vx;\vw) - \vy)^\top \mW_L \cdot \mD_{L-1}\mW_{L-1} \dots \mD_{k+1}\mW_{k+1} \cdot \mD_{k} \otimes \mD_{k-1} \mW_{k-1}\dots \mD_1 \mW_1 \vx, 121 | \end{align*} 122 | where $\otimes$ denotes the outer product and $\mD_k$ denotes a diagonal matrix whose entries are one when $\relu$ is active and zero when $\relu$ is inactive.
Since the operator norm $\norm{\mD_k}_* = 1$, we have that the Frobenius norm $\norm{\nabla_{\mW_k} \el(\vw)}_F$ is bounded from above by: 123 | \begin{align*} 124 | &\frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{1}{d_L}\norm{(\vf(\vx;\vw) - \vy)^\top \mW_L \cdot \mD_{L-1}\mW_{L-1} \dots \mD_{k+1}\mW_{k+1} \cdot \mD_{k} \otimes \mD_{k-1} \mW_{k-1}\dots \mD_1 \mW_1 \vx}_F\\ 125 | &\hspace{3em}= \frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{1}{d_L}\norm{(\vf(\vx;\vw) - \vy)^\top \mW_L \cdot \mD_{L-1}\mW_{L-1} \dots \mD_{k+1}\mW_{k+1} \cdot \mD_{k}}_2 \cdot \norm{\mD_{k-1} \mW_{k-1}\dots \mD_1 \mW_1 \vx}_2\\ 126 | &\hspace{3em}\leq \frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{1}{d_L}\norm{\vf(\vx;\vw) - \vy}_2\cdot \norm{\mW_L}_*\cdot \norm{\mW_{L-1}}_* \dots \norm{\mW_{k+1}}_*\cdot \norm{\mW_{k-1}}_*\dots \norm{\mW_1}_*\cdot \norm{\vx}_2 \\ 127 | &\hspace{3em}= \frac{\prod_{l=1}^L\norm{\mW_l}_*}{\norm{\mW_k}_*} \times \frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{1}{d_L}\norm{\vf(\vx;\vw) - \vy}_2 \cdot \norm{\vx}_2 \\ 128 | &\hspace{3em}\leq \frac{\prod_{l=1}^L\norm{\mW_l}_*}{\norm{\mW_k}_*} \cdot\frac{1}{\sqrt{d_L}} \sqrt{\frac{2}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\frac{1}{2d_L}\norm{\vf(\vx;\vw) - \vy}_2^2} \cdot \sqrt{\frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\norm{\vx}_2^2}\\ 129 | &\hspace{3em}= \frac{\prod_{l=1}^L\norm{\mW_l}_*}{\norm{\mW_k}_*} \cdot \sqrt{\frac{2\el(\vw)}{d_L}} \cdot \sqrt{\frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\norm{\vx}_2^2}. 130 | \end{align*} 131 | In the above argument, the first inequality follows by recursive application of the operator norm upper bound, and the second inequality follows from the Cauchy-Schwarz inequality. The right-hand side simplifies under \cref{prescription:norm}, and we may apply \cref{lem:objectivebound} (so that $\el(\vw) \leq 2$) to obtain: 132 | \begin{align*} 133 | \norm{\nabla_{\mW_k} \el(\vw)}_F \leq \frac{\prod_{l=1}^L\norm{\mW_l}_*}{\norm{\mW_k}_*} \cdot \sqrt{\frac{2\el(\vw)}{d_L}} \cdot \sqrt{\frac{1}{\abs{\set{S}}} \sum_{(\vx,\vy)\in \set{S}}\norm{\vx}_2^2} \leq \frac{\sqrt{d_L/d_0}}{\sqrt{d_k / d_{k-1}}} \cdot \sqrt{\frac{4}{d_L}}\cdot \sqrt{d_0} = 2\cdot \sqrt{\frac{d_{k-1}}{d_k}}. 134 | \end{align*} 135 | \end{proof} 136 | 137 | \criticalrate* 138 | \begin{proof}[\mbox{\hyperref[lem:criticalrate]{Proof}}]\label{proof:criticalrate} 139 | \cref{thm:log-lr} prescribes that $\exp\eta = \half(1+\sqrt{1+4G})$, and so $\eta = \log\big(1+\frac{\sqrt{1+4G}-1}{2}\big)$. We begin by proving some useful auxiliary bounds. By \cref{lem:gradientbound} and \cref{prescription:norm}, the gradient summary is bounded by: 140 | \begin{align*} 141 | G \defeq \frac{1}{L}\sum_{k=1}^L \sqrt{d_k/d_{k-1}} \cdot \norm{ \nabla_{\mW_k} \el(\vw)}_F \leq \frac{1}{L}\sum_{k=1}^L 2 = 2. 142 | \end{align*} 143 | The fact that the gradient summary $G$ is at most two is important because, for $x\leq 1$, we have that $\log(1+x) \geq x \log 2$. In turn, this implies that since $G\leq2$, we have that $\eta = \log \frac{1+\sqrt{1+4G}}{2} \geq \frac{\sqrt{1+4G} - 1}{2} \log 2$. It will also be important to know that for $G\leq2$, we have that $\half\cdot G \leq \tfrac{\sqrt{1+4G} - 1}{2} \leq G$. 144 | 145 | With these bounds in hand, the analysis becomes fairly standard.
By an intermediate step in the proof of \cref{thm:log-lr}, the change in objective across a single step is bounded by: 146 | \begin{align*} 147 | \el(\vw+\Delta \vw)- \el(\vw)&\leq - \eta\cdot G + \tfrac{1}{2} \,(\exp \eta -1)^2 \\ 148 | &\leq - \tfrac{\sqrt{1+4G} - 1}{2} (G \log 2 - \half \tfrac{\sqrt{1+4G} - 1}{2})\\ 149 | &\leq -\half \cdot (\log 2 - \half)\cdot G^2 150 | \leq -G^2 / 11, 151 | \end{align*} 152 | where the second and third inequalities follow by our auxiliary bounds. Letting $G_t$ denote the gradient summary at step $t$, averaging this bound over time steps and applying the telescoping property yields: 153 | \begin{equation*} 154 | \min_{t\in[1,...,T]} G_t^2 \leq \frac{1}{T}\sum_{t=1}^{T} G_t^2 \leq \frac{11}{T}\sum_{t=1}^{T} \left[\el(\vw_t) - \el(\vw_{t+1})\right] = \frac{11}{T}\cdot (\el(\vw_1) - \el(\vw_{T+1})) \leq \frac{22}{T}, 155 | \end{equation*} 156 | where the final inequality follows by \cref{lem:objectivebound} and the fact that $\el(\vw_{T+1})\geq0$. 157 | 158 | 159 | \end{proof} 160 | 161 | \globalrate* 162 | \begin{proof}[\mbox{\hyperref[thm:globalrate]{Proof}}]\label{proof:globalrate} 163 | 164 | By \cref{ass:pl}, the gradient summary at time step $t$ must satisfy $G_t \geq \alpha \times \sqrt{2\cdot\el(\vw_t)}$. Therefore the objective at time step $t$ is bounded by $\el(\vw_t) \leq G_t^2/(2\alpha^2)$. Since the first display in the proof of \cref{lem:criticalrate} shows that the objective is non-increasing from step to step, combining with \cref{lem:criticalrate} then yields that: 165 | \begin{equation*} 166 | \el(\vw_T) = \min_{t\in[1,...,T]} \el(\vw_t) \leq \frac{1}{2\alpha^2}\min_{t\in[1,...,T]}G_t^2 \leq \frac{11}{\alpha^2T}. 167 | \end{equation*} 168 | The proof is complete. 169 | \end{proof} 170 | 171 | \newpage 172 | \section{PyTorch Implementation} 173 | \label{app:pytorch} 174 | 175 | The following code implements automatic gradient descent in PyTorch \citep{pytorch}. We include a single gain hyperparameter which controls the update size and may be increased from its default value of 1.0 to slightly accelerate training. We emphasise that all the results reported in the paper used a gain of unity. 176 | 177 | \inputminted[ 178 | frame=single, 179 | framesep=2mm, 180 | ]{python}{algorithm/agd.py} 181 | 182 | -------------------------------------------------------------------------------- /latex/tmlr/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /latex/tmlr/fancyhdr.sty: -------------------------------------------------------------------------------- 1 | %% 2 | %% This is file `fancyhdr.sty', 3 | %% generated with the docstrip utility. 4 | %% 5 | %% The original source files were: 6 | %% 7 | %% fancyhdr.dtx (with options: `fancyhdr') 8 | %% 9 | %% This is a generated file. 10 | %% 11 | %% This file may be distributed and/or modified under the conditions of 12 | %% the LaTeX Project Public License, either version 1.3 of this license 13 | %% or (at your option) any later version. The latest version of this 14 | %% license is in: 15 | %% 16 | %% http://www.latex-project.org/lppl.txt 17 | %% 18 | %% and version 1.3 or later is part of all distributions of LaTeX version 19 | %% 2005/12/01 or later. 20 | %% 21 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 22 | \NeedsTeXFormat{LaTeX2e} 23 | \ProvidesPackage{fancyhdr}% 24 | [2021/01/28 v4.0.1 25 | Extensive control of page headers and footers]% 26 | % Copyright (C) 1994-2021 by Pieter van Oostrum