├── .gitignore
├── LICENSE
├── README.md
├── airbench
│   ├── __init__.py
│   ├── lib_airbench93.py
│   ├── lib_airbench94.py
│   ├── lib_airbench95.py
│   ├── lib_airbench96.py
│   └── utils.py
├── airbench94_muon.py
├── airbench96_faster.py
├── img
│   ├── airbench94_intro.png
│   └── alternating_flip.png
├── legacy
│   ├── airbench94.py
│   ├── airbench94_compiled.py
│   ├── airbench95.py
│   └── airbench96.py
├── research
│   └── airbench94_muon_simple.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | cifar10/
2 | dist/
3 | *.egg-info/
4 | *__pycache__/
5 | upload.sh
6 | trash/
7 | logs/
8 |
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024 Keller Jordan
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining
4 | a copy of this software and associated documentation files (the
5 | "Software"), to deal in the Software without restriction, including
6 | without limitation the rights to use, copy, modify, merge, publish,
7 | distribute, sublicense, and/or sell copies of the Software, and to
8 | permit persons to whom the Software is furnished to do so, subject to
9 | the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be
12 | included in all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CIFAR-10 Airbench 💨
2 |
3 | This repo contains the two fastest known algorithms for training a neural network to 94% or 96% accuracy on CIFAR-10 using a single NVIDIA A100 GPU.
4 |
5 | | Script | Mean accuracy | Time | PFLOPs |
6 | | - | - | - | - |
7 | | [airbench94_muon.py](airbench94_muon.py) | 94.01% | 2.59s | 0.29 |
8 | | [airbench96_faster.py](airbench96_faster.py) | 96.00% | 27.3s | 3.1 |
9 |
10 | (Timings done using `torch==2.4.1` on a 400W NVIDIA A100)
11 |
12 | For comparison, the standard training used in most studies on CIFAR-10 is much slower:
13 |
14 | | Baseline | Mean accuracy | Time | PFLOPs |
15 | | - | - | - | - |
16 | | Standard ResNet-18 training | 96.0% | 7min | 32.3 |
17 |
18 | ## Quickstart
19 |
20 | The current speedrun record for 94% on CIFAR-10 can be run via:
21 | ```
22 | git clone https://github.com/KellerJordan/cifar10-airbench.git
23 | cd cifar10-airbench
24 | python airbench94_muon.py
25 | ```
26 |
27 | (requires `torch` and `torchvision` to be installed)
28 |
29 | ---
30 | ---
31 |
32 | ## Methods
33 |
34 | The set of methods used to obtain these training speeds is described in [the paper](https://arxiv.org/abs/2404.00498).
35 |
36 | In addition, [airbench94_muon.py](airbench94_muon.py) uses the [Muon optimizer](https://kellerjordan.github.io/posts/muon/) and [airbench96_faster.py](airbench96_faster.py) uses a form of data filtering. These are both new records since the paper.
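At its core, Muon replaces each convolution layer's momentum-smoothed gradient with an (approximately) orthogonalized version of itself before taking a step. A minimal sketch of the idea (simplified from the full optimizer in [airbench94_muon.py](airbench94_muon.py), which additionally handles Nesterov momentum, transposition of tall matrices, and weight renormalization):
```
import torch

def orthogonalize(G, steps=3, eps=1e-7):
    # Quintic Newton-Schulz iteration that maps G toward the nearest orthogonal matrix.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G / (G.norm() + eps)  # the Frobenius norm bounds the top singular value by 1
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X

# For a weight matrix p with gradient g and momentum buffer buf, one Muon step is roughly:
#   buf = momentum * buf + g
#   p  -= lr * orthogonalize(buf.reshape(len(buf), -1)).view_as(buf)
```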
37 | We have preserved the paper's old records as well in the `legacy` folder: 38 | 39 | | Script | Mean accuracy | Time | PFLOPs | 40 | | - | - | - | - | 41 | | [airbench94_compiled.py](legacy/airbench94_compiled.py) | 94.01% | 3.09s | 0.36 | 42 | | [airbench94.py](legacy/airbench94.py) | 94.01% | 3.83s | 0.36 | 43 | | [airbench95.py](legacy/airbench95.py) | 95.01% | 10.4s | 1.4 | 44 | | [airbench96.py](legacy/airbench96.py) | 96.03% | 34.7s | 4.9 | 45 | 46 | ![alt](img/alternating_flip.png) 47 | ![curve](img/airbench94_intro.png) 48 | 49 | ## Motivation 50 | 51 | CIFAR-10 is one of the most widely used datasets in machine learning, facilitating [thousands of research projects per year](https://paperswithcode.com/dataset/cifar-10). 52 | This repo provides fast and stable training baselines for CIFAR-10 in order to help accelerate this research. 53 | The trainings are provided as easily runnable dependency-free PyTorch scripts, and can replace classic baselines like training ResNet-20 or ResNet-18. 54 | 55 | ## Using the GPU-accelerated dataloader independently 56 | 57 | For writing custom CIFAR-10 experiments or trainings, you may find it useful to use the GPU-accelerated dataloader independently. 58 | ``` 59 | import airbench 60 | train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500) 61 | test_loader = airbench.CifarLoader('/tmp/cifar10', train=False, batch_size=1000) 62 | 63 | for epoch in range(200): 64 | for inputs, labels in train_loader: 65 | # outputs = model(inputs) 66 | # loss = F.cross_entropy(outputs, labels) 67 | ... 68 | ``` 69 | 70 | If you wish to modify the data in the loader, it can be done like so: 71 | ``` 72 | import airbench 73 | train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500) 74 | mask = (train_loader.labels < 6) # (this is just an example, the mask can be anything) 75 | train_loader.images = train_loader.images[mask] 76 | train_loader.labels = train_loader.labels[mask] 77 | print(len(train_loader)) # The loader now contains 30,000 images and has batch size 500, so this prints 60. 78 | ``` 79 | 80 | ## Example data-selection experiment 81 | 82 | Airbench can be used as a platform for experiments in data selection and active learning. 83 | The following is an example experiment which demonstrates the classic result that low-confidence examples provide more training signal than random examples. 84 | It runs in <20 seconds on an A100. 
85 |
86 | ```
87 | import torch
88 | from airbench import train94, infer, evaluate, CifarLoader
89 |
90 | net = train94(label_smoothing=0) # train this network without label smoothing to get a better confidence signal
91 |
92 | loader = CifarLoader('cifar10', train=True, batch_size=1000)
93 | logits = infer(net, loader)
94 | conf = logits.log_softmax(1).amax(1) # confidence
95 |
96 | train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
97 | mask = (torch.rand(len(train_loader.labels)) < 0.6)
98 | print('Training on %d images selected randomly' % mask.sum())
99 | train_loader.images = train_loader.images[mask]
100 | train_loader.labels = train_loader.labels[mask]
101 | train94(train_loader, epochs=16) # yields around 93% accuracy
102 |
103 | train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
104 | mask = (conf < conf.float().quantile(0.6))
105 | print('Training on %d images selected based on minimum confidence' % mask.sum())
106 | train_loader.images = train_loader.images[mask]
107 | train_loader.labels = train_loader.labels[mask]
108 | train94(train_loader, epochs=16) # yields around 94% accuracy => low-confidence sampling is better than random.
109 | ```
110 |
111 | ## Prior work
112 |
113 | This project builds on the excellent previous record https://github.com/tysam-code/hlb-CIFAR10 (6.3 A100-seconds to 94%).
114 |
115 | That record itself builds on the amazing series https://myrtle.ai/learn/how-to-train-your-resnet/ (26 V100-seconds to 94%, which is >=8 A100-seconds).
116 |

--------------------------------------------------------------------------------
/airbench/__init__.py:
--------------------------------------------------------------------------------
1 | from .lib_airbench93 import train93, make_net93
2 | from .lib_airbench94 import train94, make_net94
3 | from .lib_airbench95 import train95, make_net95
4 | from .lib_airbench96 import train96, make_net96
5 | from .utils import infer, evaluate, CifarLoader
6 |
7 | def warmup93(*args, **kwargs):
8 |     return train93(*args, run=-1, **kwargs)
9 | def warmup94(*args, **kwargs):
10 |     return train94(*args, run=-1, **kwargs)
11 | def warmup95(*args, **kwargs):
12 |     return train95(*args, run=-1, **kwargs)
13 | def warmup96(*args, **kwargs):
14 |     return train96(*args, run=-1, **kwargs)
15 |

--------------------------------------------------------------------------------
/airbench/lib_airbench93.py:
--------------------------------------------------------------------------------
1 | # 93.00 in n=50 runs
2 | # Achieves 93% in roughly 1/4 of the FLOPs of 94% (but only ~40% less wallclock time on A100)
3 |
4 | from .utils import train, CifarLoader
5 |
6 | #############################################
7 | #           Setup/Hyperparameters           #
8 | #############################################
9 |
10 | import torch
11 | from torch import nn
12 | import torch.nn.functional as F
13 |
14 | torch.backends.cudnn.benchmark = True
15 |
16 | # We express the main training hyperparameters (batch size, learning rate, momentum, and weight decay)
17 | # in decoupled form, so that each one can be tuned independently. This accomplishes the following:
18 | # * Assuming time-constant gradients, the average step size is decoupled from everything but the lr.
19 | # * The size of the weight decay update is decoupled from everything but the wd.
20 | # In contrast, normally when we increase the (Nesterov) momentum, this also scales up the step size
21 | # proportionally to 1 + 1 / (1 - momentum), meaning we cannot change momentum without having to re-tune
22 | # the learning rate. Similarly, normally when we increase the learning rate this also increases the size
23 | # of the weight decay, requiring a proportional decrease in the wd to maintain the same decay strength.
24 | #
25 | # The practical impact is that hyperparameter tuning is faster, since this parametrization allows each
26 | # one to be tuned independently. See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/.
27 |
28 | hyp = {
29 |     'opt': {
30 |         'train_epochs': 11.0,
31 |         'batch_size': 1024,
32 |         'lr': 11.5,              # learning rate per 1024 examples
33 |         'momentum': 0.85,
34 |         'weight_decay': 0.0153,  # weight decay per 1024 examples (decoupled from learning rate)
35 |         'bias_scaler': 64.0,     # scales up learning rate (but not weight decay) for BatchNorm biases
36 |         'label_smoothing': 0.2,
37 |         'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing
38 |     },
39 |     'aug': {
40 |         'flip': True,
41 |         'translate': 2,
42 |     },
43 |     'net': {
44 |         'widths': {
45 |             'block1': 64,
46 |             'block2': 128,
47 |             'block3': 128,
48 |         },
49 |         'batchnorm_momentum': 0.6,
50 |         'scaling_factor': 1/9,
51 |         'tta_level': 2,          # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate
52 |     },
53 | }
54 |
55 | #############################################
56 | #            Network Components             #
57 | #############################################
58 |
59 | class Flatten(nn.Module):
60 |     def forward(self, x):
61 |         return x.view(x.size(0), -1)
62 |
63 | class Mul(nn.Module):
64 |     def __init__(self, scale):
65 |         super().__init__()
66 |         self.scale = scale
67 |     def forward(self, x):
68 |         return x * self.scale
69 |
70 | class BatchNorm(nn.BatchNorm2d):
71 |     def __init__(self, num_features, momentum, eps=1e-12,
72 |                  weight=False, bias=True):
73 |         super().__init__(num_features, eps=eps, momentum=1-momentum)
74 |         self.weight.requires_grad = weight
75 |         self.bias.requires_grad = bias
76 |         # Note that PyTorch already initializes the weights to one and bias to zero
77 |
78 | class Conv(nn.Conv2d):
79 |     def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False):
80 |         super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
81 |
82 |     def reset_parameters(self):
83 |         super().reset_parameters()
84 |         if self.bias is not None:
85 |             self.bias.data.zero_()
86 |         w = self.weight.data
87 |         torch.nn.init.dirac_(w[:w.size(1)])
88 |
89 | class ConvGroup(nn.Module):
90 |     def __init__(self, channels_in, channels_out, batchnorm_momentum):
91 |         super().__init__()
92 |         self.conv1 = Conv(channels_in, channels_out)
93 |         self.pool = nn.MaxPool2d(2)
94 |         self.norm1 = BatchNorm(channels_out, batchnorm_momentum)
95 |         self.conv2 = Conv(channels_out, channels_out)
96 |         self.norm2 = BatchNorm(channels_out, batchnorm_momentum)
97 |         self.activ = nn.GELU()
98 |
99 |     def forward(self, x):
100 |         x = self.conv1(x)
101 |         x = self.pool(x)
102 |         x = self.norm1(x)
103 |         x = self.activ(x)
104 |         x = self.conv2(x)
105 |         x = self.norm2(x)
106 |         x = self.activ(x)
107 |         return x
108 |
109 | #############################################
110 | #            Network Definition             #
111 | #############################################
112 |
113 | def make_net93(widths=hyp['net']['widths'], batchnorm_momentum=hyp['net']['batchnorm_momentum']):
114 |     whiten_kernel_size = 2
115 |     whiten_width = 2 * 3 * whiten_kernel_size**2
116 |     net = nn.Sequential(
117 |         Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True),
118 |         nn.GELU(),
119 |         ConvGroup(whiten_width, widths['block1'], batchnorm_momentum),
120 |         ConvGroup(widths['block1'], widths['block2'], batchnorm_momentum),
121 |         ConvGroup(widths['block2'], widths['block3'], batchnorm_momentum),
122 |         nn.MaxPool2d(3),
123 |         Flatten(),
124 |         nn.Linear(widths['block3'], 10, bias=False),
125 |         Mul(hyp['net']['scaling_factor']),
126 |     )
127 |     net[0].weight.requires_grad = False
128 |     net = net.half().cuda()
129 |     net = net.to(memory_format=torch.channels_last)
130 |     for mod in net.modules():
131 |         if isinstance(mod, BatchNorm):
132 |             mod.float()
133 |     return net
134 |
135 | ############################################
136 | #              Train and Eval              #
137 | ############################################
138 |
139 | def train93(train_loader=None, epochs=hyp['opt']['train_epochs'], label_smoothing=hyp['opt']['label_smoothing'],
140 |             learning_rate=hyp['opt']['lr'], bias_scaler=hyp['opt']['bias_scaler'],
141 |             momentum=hyp['opt']['momentum'], weight_decay=hyp['opt']['weight_decay'],
142 |             whiten_bias_epochs=hyp['opt']['whiten_bias_epochs'], tta_level=hyp['net']['tta_level'],
143 |             make_net=make_net93, run=0, verbose=True):
144 |
145 |     if train_loader is None:
146 |         train_loader = CifarLoader('cifar10', train=True, batch_size=hyp['opt']['batch_size'], aug=hyp['aug'], altflip=True)
147 |
148 |     return train(train_loader, epochs, label_smoothing, learning_rate, bias_scaler, momentum, weight_decay,
149 |                  whiten_bias_epochs, tta_level, make_net, run, verbose)
150 |

--------------------------------------------------------------------------------
/airbench/lib_airbench94.py:
--------------------------------------------------------------------------------
1 | # 94.01 in n=1000 runs
2 |
3 | from .utils import train, CifarLoader
4 |
5 | #############################################
6 | #           Setup/Hyperparameters           #
7 | #############################################
8 |
9 | import torch
10 | from torch import nn
11 | import torch.nn.functional as F
12 |
13 | torch.backends.cudnn.benchmark = True
14 |
15 | # We express the main training hyperparameters (batch size, learning rate, momentum, and weight decay)
16 | # in decoupled form, so that each one can be tuned independently. This accomplishes the following:
17 | # * Assuming time-constant gradients, the average step size is decoupled from everything but the lr.
18 | # * The size of the weight decay update is decoupled from everything but the wd.
19 | # In contrast, normally when we increase the (Nesterov) momentum, this also scales up the step size
20 | # proportionally to 1 + 1 / (1 - momentum), meaning we cannot change momentum without having to re-tune
21 | # the learning rate. Similarly, normally when we increase the learning rate this also increases the size
22 | # of the weight decay, requiring a proportional decrease in the wd to maintain the same decay strength.
23 | #
24 | # The practical impact is that hyperparameter tuning is faster, since this parametrization allows each
25 | # one to be tuned independently. See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/.
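# A worked example of this decoupling, assuming the defaults below (lr=11.5, momentum=0.85,
# weight_decay=0.0153, batch_size=1024) and the kilostep_scale formula in utils.train:
#   kilostep_scale = 1024 * (1 + 1/(1 - 0.85)) ~= 7851
#   raw SGD lr     = 11.5 / 7851              ~= 1.46e-3
#   raw SGD wd     = 0.0153 * 1024 / 7851     ~= 2.0e-3
# Changing momentum rescales both raw values automatically, so the effective step size and
# decay strength are preserved without re-tuning.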
26 | 27 | hyp = { 28 | 'opt': { 29 | 'train_epochs': 9.9, 30 | 'batch_size': 1024, 31 | 'lr': 11.5, # learning rate per 1024 examples 32 | 'momentum': 0.85, 33 | 'weight_decay': 0.0153, # weight decay per 1024 examples (decoupled from learning rate) 34 | 'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases 35 | 'label_smoothing': 0.2, 36 | 'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing 37 | }, 38 | 'aug': { 39 | 'flip': True, 40 | 'translate': 2, 41 | }, 42 | 'net': { 43 | 'widths': { 44 | 'block1': 64, 45 | 'block2': 256, 46 | 'block3': 256, 47 | }, 48 | 'batchnorm_momentum': 0.6, 49 | 'scaling_factor': 1/9, 50 | 'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate 51 | }, 52 | } 53 | 54 | ############################################# 55 | # Network Components # 56 | ############################################# 57 | 58 | class Flatten(nn.Module): 59 | def forward(self, x): 60 | return x.view(x.size(0), -1) 61 | 62 | class Mul(nn.Module): 63 | def __init__(self, scale): 64 | super().__init__() 65 | self.scale = scale 66 | def forward(self, x): 67 | return x * self.scale 68 | 69 | class BatchNorm(nn.BatchNorm2d): 70 | def __init__(self, num_features, momentum, eps=1e-12, 71 | weight=False, bias=True): 72 | super().__init__(num_features, eps=eps, momentum=1-momentum) 73 | self.weight.requires_grad = weight 74 | self.bias.requires_grad = bias 75 | # Note that PyTorch already initializes the weights to one and bias to zero 76 | 77 | class Conv(nn.Conv2d): 78 | def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False): 79 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 80 | 81 | def reset_parameters(self): 82 | super().reset_parameters() 83 | if self.bias is not None: 84 | self.bias.data.zero_() 85 | w = self.weight.data 86 | torch.nn.init.dirac_(w[:w.size(1)]) 87 | 88 | class ConvGroup(nn.Module): 89 | def __init__(self, channels_in, channels_out, batchnorm_momentum): 90 | super().__init__() 91 | self.conv1 = Conv(channels_in, channels_out) 92 | self.pool = nn.MaxPool2d(2) 93 | self.norm1 = BatchNorm(channels_out, batchnorm_momentum) 94 | self.conv2 = Conv(channels_out, channels_out) 95 | self.norm2 = BatchNorm(channels_out, batchnorm_momentum) 96 | self.activ = nn.GELU() 97 | 98 | def forward(self, x): 99 | x = self.conv1(x) 100 | x = self.pool(x) 101 | x = self.norm1(x) 102 | x = self.activ(x) 103 | x = self.conv2(x) 104 | x = self.norm2(x) 105 | x = self.activ(x) 106 | return x 107 | 108 | ############################################# 109 | # Network Definition # 110 | ############################################# 111 | 112 | def make_net94(widths=hyp['net']['widths'], batchnorm_momentum=hyp['net']['batchnorm_momentum']): 113 | whiten_kernel_size = 2 114 | whiten_width = 2 * 3 * whiten_kernel_size**2 115 | net = nn.Sequential( 116 | Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True), 117 | nn.GELU(), 118 | ConvGroup(whiten_width, widths['block1'], batchnorm_momentum), 119 | ConvGroup(widths['block1'], widths['block2'], batchnorm_momentum), 120 | ConvGroup(widths['block2'], widths['block3'], batchnorm_momentum), 121 | nn.MaxPool2d(3), 122 | Flatten(), 123 | nn.Linear(widths['block3'], 10, bias=False), 124 | Mul(hyp['net']['scaling_factor']), 125 | ) 126 | net[0].weight.requires_grad = False 127 | net = net.half().cuda() 128 | net = net.to(memory_format=torch.channels_last) 129 
| for mod in net.modules():
130 |         if isinstance(mod, BatchNorm):
131 |             mod.float()
132 |     return net
133 |
134 | ############################################
135 | #              Train and Eval              #
136 | ############################################
137 |
138 | def train94(train_loader=None, epochs=hyp['opt']['train_epochs'], label_smoothing=hyp['opt']['label_smoothing'],
139 |             learning_rate=hyp['opt']['lr'], bias_scaler=hyp['opt']['bias_scaler'],
140 |             momentum=hyp['opt']['momentum'], weight_decay=hyp['opt']['weight_decay'],
141 |             whiten_bias_epochs=hyp['opt']['whiten_bias_epochs'], tta_level=hyp['net']['tta_level'],
142 |             make_net=make_net94, run=0, verbose=True):
143 |
144 |     if train_loader is None:
145 |         train_loader = CifarLoader('cifar10', train=True, batch_size=hyp['opt']['batch_size'], aug=hyp['aug'], altflip=True)
146 |
147 |     return train(train_loader, epochs, label_smoothing, learning_rate, bias_scaler, momentum, weight_decay,
148 |                  whiten_bias_epochs, tta_level, make_net, run, verbose)
149 |
150 |

--------------------------------------------------------------------------------
/airbench/lib_airbench95.py:
--------------------------------------------------------------------------------
1 | # 95.01 in n=200 runs
2 |
3 | from .utils import train, CifarLoader
4 |
5 | #############################################
6 | #           Setup/Hyperparameters           #
7 | #############################################
8 |
9 | import torch
10 | from torch import nn
11 | import torch.nn.functional as F
12 |
13 | torch.backends.cudnn.benchmark = True
14 |
15 | # We express the main training hyperparameters (batch size, learning rate, momentum, and weight decay)
16 | # in decoupled form, so that each one can be tuned independently. This accomplishes the following:
17 | # * Assuming time-constant gradients, the average step size is decoupled from everything but the lr.
18 | # * The size of the weight decay update is decoupled from everything but the wd.
19 | # In contrast, normally when we increase the (Nesterov) momentum, this also scales up the step size
20 | # proportionally to 1 + 1 / (1 - momentum), meaning we cannot change momentum without having to re-tune
21 | # the learning rate. Similarly, normally when we increase the learning rate this also increases the size
22 | # of the weight decay, requiring a proportional decrease in the wd to maintain the same decay strength.
23 | #
24 | # The practical impact is that hyperparameter tuning is faster, since this parametrization allows each
25 | # one to be tuned independently. See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/.
26 | 27 | hyp = { 28 | 'opt': { 29 | 'train_epochs': 15.0, 30 | 'batch_size': 1024, 31 | 'lr': 10.0, # learning rate per 1024 examples 32 | 'momentum': 0.85, 33 | 'weight_decay': 0.0153, # weight decay per 1024 examples (decoupled from learning rate) 34 | 'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases 35 | 'label_smoothing': 0.2, 36 | 'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing 37 | }, 38 | 'aug': { 39 | 'flip': True, 40 | 'translate': 2, 41 | }, 42 | 'net': { 43 | 'widths': { 44 | 'block1': 128, 45 | 'block2': 384, 46 | 'block3': 384, 47 | }, 48 | 'batchnorm_momentum': 0.6, 49 | 'scaling_factor': 1/9, 50 | 'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate 51 | }, 52 | } 53 | 54 | ############################################# 55 | # Network Components # 56 | ############################################# 57 | 58 | class Flatten(nn.Module): 59 | def forward(self, x): 60 | return x.view(x.size(0), -1) 61 | 62 | class Mul(nn.Module): 63 | def __init__(self, scale): 64 | super().__init__() 65 | self.scale = scale 66 | def forward(self, x): 67 | return x * self.scale 68 | 69 | class BatchNorm(nn.BatchNorm2d): 70 | def __init__(self, num_features, momentum, eps=1e-12, 71 | weight=False, bias=True): 72 | super().__init__(num_features, eps=eps, momentum=1-momentum) 73 | self.weight.requires_grad = weight 74 | self.bias.requires_grad = bias 75 | # Note that PyTorch already initializes the weights to one and bias to zero 76 | 77 | class Conv(nn.Conv2d): 78 | def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False): 79 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 80 | 81 | def reset_parameters(self): 82 | super().reset_parameters() 83 | if self.bias is not None: 84 | self.bias.data.zero_() 85 | w = self.weight.data 86 | torch.nn.init.dirac_(w[:w.size(1)]) 87 | 88 | class ConvGroup(nn.Module): 89 | def __init__(self, channels_in, channels_out, batchnorm_momentum): 90 | super().__init__() 91 | self.conv1 = Conv(channels_in, channels_out) 92 | self.pool = nn.MaxPool2d(2) 93 | self.norm1 = BatchNorm(channels_out, batchnorm_momentum) 94 | self.conv2 = Conv(channels_out, channels_out) 95 | self.norm2 = BatchNorm(channels_out, batchnorm_momentum) 96 | self.activ = nn.GELU() 97 | 98 | def forward(self, x): 99 | x = self.conv1(x) 100 | x = self.pool(x) 101 | x = self.norm1(x) 102 | x = self.activ(x) 103 | x = self.conv2(x) 104 | x = self.norm2(x) 105 | x = self.activ(x) 106 | return x 107 | 108 | ############################################# 109 | # Network Definition # 110 | ############################################# 111 | 112 | def make_net95(widths=hyp['net']['widths'], batchnorm_momentum=hyp['net']['batchnorm_momentum']): 113 | whiten_kernel_size = 2 114 | whiten_width = 2 * 3 * whiten_kernel_size**2 115 | net = nn.Sequential( 116 | Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True), 117 | nn.GELU(), 118 | ConvGroup(whiten_width, widths['block1'], batchnorm_momentum), 119 | ConvGroup(widths['block1'], widths['block2'], batchnorm_momentum), 120 | ConvGroup(widths['block2'], widths['block3'], batchnorm_momentum), 121 | nn.MaxPool2d(3), 122 | Flatten(), 123 | nn.Linear(widths['block3'], 10, bias=False), 124 | Mul(hyp['net']['scaling_factor']), 125 | ) 126 | net[0].weight.requires_grad = False 127 | net = net.half().cuda() 128 | net = net.to(memory_format=torch.channels_last) 
129 |     for mod in net.modules():
130 |         if isinstance(mod, BatchNorm):
131 |             mod.float()
132 |     return net
133 |
134 | ############################################
135 | #              Train and Eval              #
136 | ############################################
137 |
138 | def train95(train_loader=None, epochs=hyp['opt']['train_epochs'], label_smoothing=hyp['opt']['label_smoothing'],
139 |             learning_rate=hyp['opt']['lr'], bias_scaler=hyp['opt']['bias_scaler'],
140 |             momentum=hyp['opt']['momentum'], weight_decay=hyp['opt']['weight_decay'],
141 |             whiten_bias_epochs=hyp['opt']['whiten_bias_epochs'], tta_level=hyp['net']['tta_level'],
142 |             make_net=make_net95, run=0, verbose=True):
143 |
144 |     if train_loader is None:
145 |         train_loader = CifarLoader('cifar10', train=True, batch_size=hyp['opt']['batch_size'], aug=hyp['aug'], altflip=True)
146 |
147 |     return train(train_loader, epochs, label_smoothing, learning_rate, bias_scaler, momentum, weight_decay,
148 |                  whiten_bias_epochs, tta_level, make_net, run, verbose)
149 |
150 |

--------------------------------------------------------------------------------
/airbench/lib_airbench96.py:
--------------------------------------------------------------------------------
1 | # 96.03 in n=200 runs.
2 |
3 | from .utils import train, CifarLoader
4 |
5 | #############################################
6 | #           Setup/Hyperparameters           #
7 | #############################################
8 |
9 | import torch
10 | from torch import nn
11 | import torch.nn.functional as F
12 |
13 | torch.backends.cudnn.benchmark = True
14 |
15 | # We express the main training hyperparameters (batch size, learning rate, momentum, and weight decay)
16 | # in decoupled form, so that each one can be tuned independently. This accomplishes the following:
17 | # * Assuming time-constant gradients, the average step size is decoupled from everything but the lr.
18 | # * The size of the weight decay update is decoupled from everything but the wd.
19 | # In contrast, normally when we increase the (Nesterov) momentum, this also scales up the step size
20 | # proportionally to 1 + 1 / (1 - momentum), meaning we cannot change momentum without having to re-tune
21 | # the learning rate. Similarly, normally when we increase the learning rate this also increases the size
22 | # of the weight decay, requiring a proportional decrease in the wd to maintain the same decay strength.
23 | #
24 | # The practical impact is that hyperparameter tuning is faster, since this parametrization allows each
25 | # one to be tuned independently. See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/.
26 | 27 | hyp = { 28 | 'opt': { 29 | 'train_epochs': 37.0, 30 | 'batch_size': 1024, 31 | 'lr': 9.0, # learning rate per 1024 examples 32 | 'momentum': 0.85, 33 | 'weight_decay': 0.012, # weight decay per 1024 examples (decoupled from learning rate) 34 | 'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases 35 | 'label_smoothing': 0.2, 36 | 'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing 37 | }, 38 | 'aug': { 39 | 'flip': True, 40 | 'translate': 4, 41 | 'cutout': 12, 42 | }, 43 | 'net': { 44 | 'widths': { 45 | 'block1': 128, 46 | 'block2': 384, 47 | 'block3': 512, 48 | }, 49 | 'scaling_factor': 1/9, 50 | 'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate 51 | }, 52 | } 53 | 54 | ############################################# 55 | # Network Components # 56 | ############################################# 57 | 58 | class Flatten(nn.Module): 59 | def forward(self, x): 60 | return x.view(x.size(0), -1) 61 | 62 | class Mul(nn.Module): 63 | def __init__(self, scale): 64 | super().__init__() 65 | self.scale = scale 66 | def forward(self, x): 67 | return x * self.scale 68 | 69 | class BatchNorm(nn.BatchNorm2d): 70 | def __init__(self, num_features, eps=1e-12, 71 | weight=False, bias=True): 72 | super().__init__(num_features, eps=eps) 73 | self.weight.requires_grad = weight 74 | self.bias.requires_grad = bias 75 | # Note that PyTorch already initializes the weights to one and bias to zero 76 | 77 | class Conv(nn.Conv2d): 78 | def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False): 79 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 80 | 81 | def reset_parameters(self): 82 | super().reset_parameters() 83 | if self.bias is not None: 84 | self.bias.data.zero_() 85 | w = self.weight.data 86 | torch.nn.init.dirac_(w[:w.size(1)]) 87 | 88 | class ConvGroup(nn.Module): 89 | def __init__(self, channels_in, channels_out): 90 | super().__init__() 91 | self.conv1 = Conv(channels_in, channels_out) 92 | self.pool = nn.MaxPool2d(2) 93 | self.norm1 = BatchNorm(channels_out) 94 | self.conv2 = Conv(channels_out, channels_out) 95 | self.norm2 = BatchNorm(channels_out) 96 | self.conv3 = Conv(channels_out, channels_out) 97 | self.norm3 = BatchNorm(channels_out) 98 | self.activ = nn.GELU() 99 | 100 | def forward(self, x): 101 | x = self.conv1(x) 102 | x = self.pool(x) 103 | x = self.norm1(x) 104 | x = self.activ(x) 105 | x0 = x 106 | x = self.conv2(x) 107 | x = self.norm2(x) 108 | x = self.activ(x) 109 | x = self.conv3(x) 110 | x = self.norm3(x) 111 | x += x0 112 | x = self.activ(x) 113 | return x 114 | 115 | ############################################# 116 | # Network Definition # 117 | ############################################# 118 | 119 | def make_net96(): 120 | widths = hyp['net']['widths'] 121 | whiten_kernel_size = 2 122 | whiten_width = 2 * 3 * whiten_kernel_size**2 123 | net = nn.Sequential( 124 | Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True), 125 | nn.GELU(), 126 | ConvGroup(whiten_width, widths['block1']), 127 | ConvGroup(widths['block1'], widths['block2']), 128 | ConvGroup(widths['block2'], widths['block3']), 129 | nn.MaxPool2d(3), 130 | Flatten(), 131 | nn.Linear(widths['block3'], 10, bias=False), 132 | Mul(hyp['net']['scaling_factor']), 133 | ) 134 | net[0].weight.requires_grad = False 135 | net = net.half().cuda() 136 | net = net.to(memory_format=torch.channels_last) 137 | for mod in 
net.modules():
138 |         if isinstance(mod, BatchNorm):
139 |             mod.float()
140 |     return net
141 |
142 | ############################################
143 | #              Train and Eval              #
144 | ############################################
145 |
146 | def train96(train_loader=None, epochs=hyp['opt']['train_epochs'], label_smoothing=hyp['opt']['label_smoothing'],
147 |             learning_rate=hyp['opt']['lr'], bias_scaler=hyp['opt']['bias_scaler'],
148 |             momentum=hyp['opt']['momentum'], weight_decay=hyp['opt']['weight_decay'],
149 |             whiten_bias_epochs=hyp['opt']['whiten_bias_epochs'], tta_level=hyp['net']['tta_level'],
150 |             make_net=make_net96, run=0, verbose=True, lr_peak=0.1, lr_end=0):
151 |
152 |     if train_loader is None:
153 |         train_loader = CifarLoader('cifar10', train=True, batch_size=hyp['opt']['batch_size'], aug=hyp['aug'], altflip=True)
154 |
155 |     return train(train_loader, epochs, label_smoothing, learning_rate, bias_scaler, momentum, weight_decay,
156 |                  whiten_bias_epochs, tta_level, make_net, run, verbose, lr_peak=lr_peak, lr_end=lr_end)
157 |

--------------------------------------------------------------------------------
/airbench/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from math import ceil
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torchvision
8 | import torchvision.transforms as T
9 |
10 |
11 | def infer(model, loader, tta_level=0):
12 |
13 |     def infer_basic(inputs, net):
14 |         return net(inputs).clone()
15 |
16 |     def infer_mirror(inputs, net):
17 |         return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1))
18 |
19 |     def infer_mirror_translate(inputs, net):
20 |         logits = infer_mirror(inputs, net)
21 |         pad = 1
22 |         padded_inputs = F.pad(inputs, (pad,)*4, 'reflect')
23 |         inputs_translate_list = [
24 |             padded_inputs[:, :, 0:32, 0:32],
25 |             padded_inputs[:, :, 2:34, 2:34],
26 |         ]
27 |         logits_translate_list = [infer_mirror(inputs_translate, net)
28 |                                  for inputs_translate in inputs_translate_list]
29 |         logits_translate = torch.stack(logits_translate_list).mean(0)
30 |         return 0.5 * logits + 0.5 * logits_translate
31 |
32 |     model.eval()
33 |     test_images = loader.normalize(loader.images)
34 |     infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level]
35 |     with torch.no_grad():
36 |         return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)])
37 |
38 | def evaluate(model, loader, tta_level=0):
39 |     logits = infer(model, loader, tta_level)
40 |     return (logits.argmax(1) == loader.labels).float().mean().item()
41 |
42 | #############################################
43 | #                DataLoader                 #
44 | #############################################
45 |
46 | CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465))
47 | CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616))
48 |
49 | def batch_flip_lr(inputs):
50 |     flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1)
51 |     return torch.where(flip_mask, inputs.flip(-1), inputs)
52 |
53 | def batch_crop(images, crop_size):
54 |     r = (images.size(-1) - crop_size)//2
55 |     shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device)
56 |     images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype)
57 |     # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2.
58 |     if r <= 2:
59 |         for sy in range(-r, r+1):
60 |             for sx in range(-r, r+1):
61 |                 mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx)
62 |                 images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size]
63 |     else:
64 |         images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype)
65 |         for s in range(-r, r+1):
66 |             mask = (shifts[:, 0] == s)
67 |             images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :]
68 |         for s in range(-r, r+1):
69 |             mask = (shifts[:, 1] == s)
70 |             images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size]
71 |     return images_out
72 |
73 | def make_random_square_masks(inputs, size):
74 |     is_even = int(size % 2 == 0)
75 |     n,c,h,w = inputs.shape
76 |
77 |     # seed top-left corners of squares to cutout boxes from, in one dimension each
78 |     corner_y = torch.randint(0, h-size+1, size=(n,), device=inputs.device)
79 |     corner_x = torch.randint(0, w-size+1, size=(n,), device=inputs.device)
80 |
81 |     # measure distance, using the center as a reference point
82 |     corner_y_dists = torch.arange(h, device=inputs.device).view(1, 1, h, 1) - corner_y.view(-1, 1, 1, 1)
83 |     corner_x_dists = torch.arange(w, device=inputs.device).view(1, 1, 1, w) - corner_x.view(-1, 1, 1, 1)
84 |
85 |     mask_y = (corner_y_dists >= 0) * (corner_y_dists < size)
86 |     mask_x = (corner_x_dists >= 0) * (corner_x_dists < size)
87 |
88 |     final_mask = mask_y * mask_x
89 |
90 |     return final_mask
91 |
92 | def batch_cutout(inputs, size):
93 |     cutout_masks = make_random_square_masks(inputs, size)
94 |     return inputs.masked_fill(cutout_masks, 0)
95 |
96 | class CifarLoader:
97 |
98 |     def __init__(self, path, train=True, batch_size=500, aug=None, drop_last=None, shuffle=None, altflip=False):
99 |
100 |         data_path = os.path.join(path, 'train.pt' if train else 'test.pt')
101 |         if not os.path.exists(data_path):
102 |             dset = torchvision.datasets.CIFAR10(path, download=True, train=train)
103 |             images = torch.tensor(dset.data)
104 |             labels = torch.tensor(dset.targets)
105 |             torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path)
106 |         data = torch.load(data_path, map_location='cuda')
107 |
108 |         self.epoch = 0
109 |         self.images, self.labels, self.classes = data['images'], data['labels'], data['classes']
110 |         # It's faster to load+process uint8 data than to load preprocessed fp16 data
111 |         self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last)
112 |
113 |         self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD)
114 |         self.proc_images = {} # Saved results of image processing to be done on the first epoch
115 |
116 |         self.aug = aug or {}
117 |         for k in self.aug.keys():
118 |             assert k in ['flip', 'translate', 'cutout'], 'Unrecognized key: %s' % k
119 |
120 |         self.batch_size = batch_size
121 |         self.drop_last = train if drop_last is None else drop_last
122 |         self.shuffle = train if shuffle is None else shuffle
123 |         self.altflip = altflip
124 |
125 |     def __len__(self):
126 |         return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size)
127 |
128 |     def __setattr__(self, k, v):
129 |         if k in ('images', 'labels'):
130 |             assert self.epoch == 0, 'Changing images or labels is only supported before iteration.'
131 | super().__setattr__(k, v) 132 | 133 | def __iter__(self): 134 | 135 | if self.epoch == 0: 136 | images = self.proc_images['norm'] = self.normalize(self.images) 137 | # Pre-flip images in order to do every-other epoch flipping scheme 138 | if self.aug.get('flip', False): 139 | images = self.proc_images['flip'] = batch_flip_lr(images) 140 | # Pre-pad images to save time when doing random translation 141 | pad = self.aug.get('translate', 0) 142 | if pad > 0: 143 | self.proc_images['pad'] = F.pad(images, (pad,)*4, 'reflect') 144 | 145 | if self.aug.get('translate', 0) > 0: 146 | images = batch_crop(self.proc_images['pad'], self.images.shape[-2]) 147 | elif self.aug.get('flip', False): 148 | images = self.proc_images['flip'] 149 | else: 150 | images = self.proc_images['norm'] 151 | # Flip all images together every other epoch. This increases diversity relative to random flipping 152 | if self.aug.get('flip', False): 153 | if self.altflip: 154 | if self.epoch % 2 == 1: 155 | images = images.flip(-1) 156 | else: 157 | images = batch_flip_lr(images) 158 | if self.aug.get('cutout', 0) > 0: 159 | images = batch_cutout(images, self.aug['cutout']) 160 | 161 | self.epoch += 1 162 | 163 | indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device) 164 | for i in range(len(self)): 165 | idxs = indices[i*self.batch_size:(i+1)*self.batch_size] 166 | yield (images[idxs], self.labels[idxs]) 167 | 168 | ############################################# 169 | # Whitening Conv Initialization # 170 | ############################################# 171 | 172 | def get_patches(x, patch_shape): 173 | c, (h, w) = x.shape[1], patch_shape 174 | return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float() 175 | 176 | def get_whitening_parameters(patches): 177 | n,c,h,w = patches.shape 178 | patches_flat = patches.view(n, -1) 179 | est_patch_covariance = (patches_flat.T @ patches_flat) / n 180 | eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U') 181 | return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w,c,h,w).flip(0) 182 | 183 | def init_whitening_conv(layer, train_set, eps=5e-4): 184 | patches = get_patches(train_set, patch_shape=layer.weight.data.shape[2:]) 185 | eigenvalues, eigenvectors = get_whitening_parameters(patches) 186 | eigenvectors_scaled = eigenvectors / torch.sqrt(eigenvalues + eps) 187 | layer.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled)) 188 | 189 | ############################################ 190 | # Lookahead # 191 | ############################################ 192 | 193 | class LookaheadState: 194 | def __init__(self, net): 195 | self.net_ema = {k: v.clone() for k, v in net.state_dict().items()} 196 | 197 | def update(self, net, decay): 198 | for ema_param, net_param in zip(self.net_ema.values(), net.state_dict().values()): 199 | if net_param.dtype in (torch.half, torch.float): 200 | ema_param.lerp_(net_param, 1-decay) 201 | net_param.copy_(ema_param) 202 | 203 | ############################################ 204 | # Logging # 205 | ############################################ 206 | 207 | def print_columns(columns_list, is_head=False, is_final_entry=False, print_cols=True): 208 | print_string = '' 209 | for col in columns_list: 210 | print_string += '| %s ' % col 211 | print_string += '|' 212 | if is_head: 213 | print('-'*len(print_string)) 214 | if print_cols: 215 | print(print_string) 216 | if is_head or is_final_entry: 217 | print('-'*len(print_string)) 218 | 219 | 
logging_columns_list = ['run ', 'epoch', 'train_loss', 'train_acc', 'val_acc', 'tta_val_acc', 'total_time_seconds'] 220 | def print_training_details(variables, is_final_entry): 221 | formatted = [] 222 | for col in logging_columns_list: 223 | var = variables.get(col.strip(), None) 224 | if type(var) in (int, str): 225 | res = str(var) 226 | elif type(var) is float: 227 | res = '{:0.4f}'.format(var) 228 | else: 229 | assert var is None 230 | res = '' 231 | formatted.append(res.rjust(len(col))) 232 | print_columns(formatted, is_final_entry=is_final_entry) 233 | 234 | ############################################ 235 | # Train and Eval # 236 | ############################################ 237 | 238 | def train(train_loader, epochs, label_smoothing, learning_rate, bias_scaler, momentum, weight_decay, 239 | whiten_bias_epochs, tta_level, make_net, run, verbose, lr_peak=0.23, lr_end=0.07): 240 | 241 | train_loader.epoch = 0 242 | 243 | is_warmup = run in (-1, 'warmup') 244 | if is_warmup and verbose: 245 | run = 'warmup' 246 | print_columns(logging_columns_list, is_head=True, print_cols=False) 247 | if run == 0 and verbose: 248 | print_columns(logging_columns_list, is_head=True) 249 | 250 | batch_size = train_loader.batch_size 251 | # Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much 252 | # larger the default steps will be than the underlying per-example gradients. We divide the 253 | # learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless 254 | # of the choice of momentum. 255 | kilostep_scale = 1024 * (1 + 1 / (1 - momentum)) 256 | lr = learning_rate / kilostep_scale # un-decoupled learning rate for PyTorch SGD 257 | wd = weight_decay * batch_size / kilostep_scale 258 | lr_biases = lr * bias_scaler 259 | 260 | loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing, reduction='none') 261 | test_loader = CifarLoader('cifar10', train=False, batch_size=2000) 262 | total_train_steps = ceil(len(train_loader) * epochs) 263 | 264 | model = make_net() 265 | current_steps = 0 266 | 267 | norm_biases = [p for k, p in model.named_parameters() if 'norm' in k and p.requires_grad] 268 | other_params = [p for k, p in model.named_parameters() if 'norm' not in k and p.requires_grad] 269 | param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases), 270 | dict(params=other_params, lr=lr, weight_decay=wd/lr)] 271 | optimizer = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True) 272 | 273 | def get_lr(step): 274 | warmup_steps = int(total_train_steps * lr_peak) 275 | warmdown_steps = total_train_steps - warmup_steps 276 | if step < warmup_steps: 277 | frac = step / warmup_steps 278 | return 0.2 * (1 - frac) + 1.0 * frac 279 | else: 280 | frac = (step - warmup_steps) / warmdown_steps 281 | return 1.0 * (1 - frac) + lr_end * frac 282 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr) 283 | 284 | alpha_schedule = 0.95**5 * (torch.arange(total_train_steps+1) / total_train_steps)**3 285 | lookahead_state = LookaheadState(model) 286 | 287 | # For accurately timing GPU code 288 | starter = torch.cuda.Event(enable_timing=True) 289 | ender = torch.cuda.Event(enable_timing=True) 290 | total_time_seconds = 0.0 291 | 292 | # Initialize the whitening layer using training images 293 | starter.record() 294 | train_images = train_loader.normalize(train_loader.images[:5000]) 295 | init_whitening_conv(model[0], train_images) 296 | ender.record() 297 | torch.cuda.synchronize() 298 | 
total_time_seconds += 1e-3 * starter.elapsed_time(ender) 299 | 300 | for epoch in range(ceil(epochs)): 301 | 302 | model[0].bias.requires_grad = (epoch < whiten_bias_epochs) 303 | 304 | #################### 305 | # Training # 306 | #################### 307 | 308 | starter.record() 309 | 310 | model.train() 311 | for inputs, labels in train_loader: 312 | 313 | outputs = model(inputs) 314 | loss = loss_fn(outputs, labels).sum() 315 | optimizer.zero_grad(set_to_none=True) 316 | loss.backward() 317 | optimizer.step() 318 | scheduler.step() 319 | 320 | current_steps += 1 321 | 322 | if current_steps % 5 == 0: 323 | lookahead_state.update(model, decay=alpha_schedule[current_steps].item()) 324 | 325 | if current_steps >= total_train_steps: 326 | if lookahead_state is not None: 327 | lookahead_state.update(model, decay=1.0) 328 | break 329 | 330 | if verbose and is_warmup: 331 | epoch = None 332 | print_training_details(locals(), is_final_entry=False) 333 | return 334 | 335 | ender.record() 336 | torch.cuda.synchronize() 337 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 338 | 339 | if verbose: 340 | 341 | #################### 342 | # Evaluation # 343 | #################### 344 | 345 | # Save the accuracy and loss from the last training batch of the epoch 346 | train_acc = (outputs.detach().argmax(1) == labels).float().mean().item() 347 | train_loss = loss.item() / batch_size 348 | val_acc = evaluate(model, test_loader, tta_level=0) 349 | print_training_details(locals(), is_final_entry=False) 350 | run = None # Only print the run number once 351 | 352 | if verbose: 353 | 354 | #################### 355 | # TTA Evaluation # 356 | #################### 357 | 358 | starter.record() 359 | tta_val_acc = evaluate(model, test_loader, tta_level) 360 | ender.record() 361 | torch.cuda.synchronize() 362 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 363 | 364 | epoch = 'eval' 365 | print_training_details(locals(), is_final_entry=True) 366 | 367 | return model 368 | 369 | -------------------------------------------------------------------------------- /airbench94_muon.py: -------------------------------------------------------------------------------- 1 | """ 2 | airbench94_muon.py 3 | Runs in 2.59 seconds on a 400W NVIDIA A100 using torch==2.4.1 4 | Attains 94.01 mean accuracy (n=200 trials) 5 | Descends from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py 6 | """ 7 | 8 | ############################################# 9 | # Setup # 10 | ############################################# 11 | 12 | import os 13 | import sys 14 | with open(sys.argv[0]) as f: 15 | code = f.read() 16 | import uuid 17 | from math import ceil 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | import torchvision 23 | import torchvision.transforms as T 24 | 25 | torch.backends.cudnn.benchmark = True 26 | 27 | ############################################# 28 | # Muon optimizer # 29 | ############################################# 30 | 31 | @torch.compile 32 | def zeropower_via_newtonschulz5(G, steps=3, eps=1e-7): 33 | """ 34 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 35 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 36 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 37 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 38 | on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T 39 | where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model 40 | performance at all relative to UV^T, where USV^T = G is the SVD. 41 | """ 42 | assert len(G.shape) == 2 43 | a, b, c = (3.4445, -4.7750, 2.0315) 44 | X = G.bfloat16() 45 | X /= (X.norm() + eps) # ensure top singular value <= 1 46 | if G.size(0) > G.size(1): 47 | X = X.T 48 | for _ in range(steps): 49 | A = X @ X.T 50 | B = b * A + c * A @ A 51 | X = a * X + B @ X 52 | if G.size(0) > G.size(1): 53 | X = X.T 54 | return X 55 | 56 | class Muon(torch.optim.Optimizer): 57 | def __init__(self, params, lr=1e-3, momentum=0, nesterov=False): 58 | if lr < 0.0: 59 | raise ValueError(f"Invalid learning rate: {lr}") 60 | if momentum < 0.0: 61 | raise ValueError(f"Invalid momentum value: {momentum}") 62 | if nesterov and momentum <= 0: 63 | raise ValueError("Nesterov momentum requires a momentum") 64 | defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov) 65 | super().__init__(params, defaults) 66 | 67 | def step(self): 68 | for group in self.param_groups: 69 | lr = group["lr"] 70 | momentum = group["momentum"] 71 | for p in group["params"]: 72 | g = p.grad 73 | if g is None: 74 | continue 75 | state = self.state[p] 76 | 77 | if "momentum_buffer" not in state.keys(): 78 | state["momentum_buffer"] = torch.zeros_like(g) 79 | buf = state["momentum_buffer"] 80 | buf.mul_(momentum).add_(g) 81 | g = g.add(buf, alpha=momentum) if group["nesterov"] else buf 82 | 83 | p.data.mul_(len(p.data)**0.5 / p.data.norm()) # normalize the weight 84 | update = zeropower_via_newtonschulz5(g.reshape(len(g), -1)).view(g.shape) # whiten the update 85 | p.data.add_(update, alpha=-lr) # take a step 86 | 87 | ############################################# 88 | # DataLoader # 89 | ############################################# 90 | 91 | CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465)) 92 | CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616)) 93 | 94 | def batch_flip_lr(inputs): 95 | flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1) 96 | return torch.where(flip_mask, inputs.flip(-1), inputs) 97 | 98 | def batch_crop(images, crop_size): 99 | r = (images.size(-1) - crop_size)//2 100 | shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device) 101 | images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype) 102 | # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2. 
103 | if r <= 2: 104 | for sy in range(-r, r+1): 105 | for sx in range(-r, r+1): 106 | mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx) 107 | images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size] 108 | else: 109 | images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype) 110 | for s in range(-r, r+1): 111 | mask = (shifts[:, 0] == s) 112 | images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :] 113 | for s in range(-r, r+1): 114 | mask = (shifts[:, 1] == s) 115 | images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size] 116 | return images_out 117 | 118 | class CifarLoader: 119 | 120 | def __init__(self, path, train=True, batch_size=500, aug=None): 121 | data_path = os.path.join(path, "train.pt" if train else "test.pt") 122 | if not os.path.exists(data_path): 123 | dset = torchvision.datasets.CIFAR10(path, download=True, train=train) 124 | images = torch.tensor(dset.data) 125 | labels = torch.tensor(dset.targets) 126 | torch.save({"images": images, "labels": labels, "classes": dset.classes}, data_path) 127 | 128 | data = torch.load(data_path, map_location=torch.device("cuda")) 129 | self.images, self.labels, self.classes = data["images"], data["labels"], data["classes"] 130 | # It's faster to load+process uint8 data than to load preprocessed fp16 data 131 | self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last) 132 | 133 | self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD) 134 | self.proc_images = {} # Saved results of image processing to be done on the first epoch 135 | self.epoch = 0 136 | 137 | self.aug = aug or {} 138 | for k in self.aug.keys(): 139 | assert k in ["flip", "translate"], "Unrecognized key: %s" % k 140 | 141 | self.batch_size = batch_size 142 | self.drop_last = train 143 | self.shuffle = train 144 | 145 | def __len__(self): 146 | return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size) 147 | 148 | def __iter__(self): 149 | 150 | if self.epoch == 0: 151 | images = self.proc_images["norm"] = self.normalize(self.images) 152 | # Pre-flip images in order to do every-other epoch flipping scheme 153 | if self.aug.get("flip", False): 154 | images = self.proc_images["flip"] = batch_flip_lr(images) 155 | # Pre-pad images to save time when doing random translation 156 | pad = self.aug.get("translate", 0) 157 | if pad > 0: 158 | self.proc_images["pad"] = F.pad(images, (pad,)*4, "reflect") 159 | 160 | if self.aug.get("translate", 0) > 0: 161 | images = batch_crop(self.proc_images["pad"], self.images.shape[-2]) 162 | elif self.aug.get("flip", False): 163 | images = self.proc_images["flip"] 164 | else: 165 | images = self.proc_images["norm"] 166 | # Flip all images together every other epoch. 
This increases diversity relative to random flipping 167 | if self.aug.get("flip", False): 168 | if self.epoch % 2 == 1: 169 | images = images.flip(-1) 170 | 171 | self.epoch += 1 172 | 173 | indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device) 174 | for i in range(len(self)): 175 | idxs = indices[i*self.batch_size:(i+1)*self.batch_size] 176 | yield (images[idxs], self.labels[idxs]) 177 | 178 | ############################################# 179 | # Network Definition # 180 | ############################################# 181 | 182 | # note the use of low BatchNorm stats momentum 183 | class BatchNorm(nn.BatchNorm2d): 184 | def __init__(self, num_features, momentum=0.6, eps=1e-12): 185 | super().__init__(num_features, eps=eps, momentum=1-momentum) 186 | self.weight.requires_grad = False 187 | # Note that PyTorch already initializes the weights to one and bias to zero 188 | 189 | class Conv(nn.Conv2d): 190 | def __init__(self, in_channels, out_channels): 191 | super().__init__(in_channels, out_channels, kernel_size=3, padding="same", bias=False) 192 | 193 | def reset_parameters(self): 194 | super().reset_parameters() 195 | w = self.weight.data 196 | torch.nn.init.dirac_(w[:w.size(1)]) 197 | 198 | class ConvGroup(nn.Module): 199 | def __init__(self, channels_in, channels_out): 200 | super().__init__() 201 | self.conv1 = Conv(channels_in, channels_out) 202 | self.pool = nn.MaxPool2d(2) 203 | self.norm1 = BatchNorm(channels_out) 204 | self.conv2 = Conv(channels_out, channels_out) 205 | self.norm2 = BatchNorm(channels_out) 206 | self.activ = nn.GELU() 207 | 208 | def forward(self, x): 209 | x = self.conv1(x) 210 | x = self.pool(x) 211 | x = self.norm1(x) 212 | x = self.activ(x) 213 | x = self.conv2(x) 214 | x = self.norm2(x) 215 | x = self.activ(x) 216 | return x 217 | 218 | class CifarNet(nn.Module): 219 | def __init__(self): 220 | super().__init__() 221 | widths = dict(block1=64, block2=256, block3=256) 222 | whiten_kernel_size = 2 223 | whiten_width = 2 * 3 * whiten_kernel_size**2 224 | self.whiten = nn.Conv2d(3, whiten_width, whiten_kernel_size, padding=0, bias=True) 225 | self.whiten.weight.requires_grad = False 226 | self.layers = nn.Sequential( 227 | nn.GELU(), 228 | ConvGroup(whiten_width, widths["block1"]), 229 | ConvGroup(widths["block1"], widths["block2"]), 230 | ConvGroup(widths["block2"], widths["block3"]), 231 | nn.MaxPool2d(3), 232 | ) 233 | self.head = nn.Linear(widths["block3"], 10, bias=False) 234 | for mod in self.modules(): 235 | if isinstance(mod, BatchNorm): 236 | mod.float() 237 | else: 238 | mod.half() 239 | 240 | def reset(self): 241 | for m in self.modules(): 242 | if type(m) in (nn.Conv2d, Conv, BatchNorm, nn.Linear): 243 | m.reset_parameters() 244 | w = self.head.weight.data 245 | w *= 1 / w.std() 246 | 247 | def init_whiten(self, train_images, eps=5e-4): 248 | c, (h, w) = train_images.shape[1], self.whiten.weight.shape[2:] 249 | patches = train_images.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float() 250 | patches_flat = patches.view(len(patches), -1) 251 | est_patch_covariance = (patches_flat.T @ patches_flat) / len(patches_flat) 252 | eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO="U") 253 | eigenvectors_scaled = eigenvectors.T.reshape(-1,c,h,w) / torch.sqrt(eigenvalues.view(-1,1,1,1) + eps) 254 | self.whiten.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled)) 255 | 256 | def forward(self, x, whiten_bias_grad=True): 257 | b = self.whiten.bias 258 | x = 
F.conv2d(x, self.whiten.weight, b if whiten_bias_grad else b.detach()) 259 | x = self.layers(x) 260 | x = x.view(len(x), -1) 261 | return self.head(x) / x.size(-1) 262 | 263 | ############################################ 264 | # Logging # 265 | ############################################ 266 | 267 | def print_columns(columns_list, is_head=False, is_final_entry=False): 268 | print_string = "" 269 | for col in columns_list: 270 | print_string += "| %s " % col 271 | print_string += "|" 272 | if is_head: 273 | print("-"*len(print_string)) 274 | print(print_string) 275 | if is_head or is_final_entry: 276 | print("-"*len(print_string)) 277 | 278 | logging_columns_list = ["run ", "epoch", "train_acc", "val_acc", "tta_val_acc", "time_seconds"] 279 | def print_training_details(variables, is_final_entry): 280 | formatted = [] 281 | for col in logging_columns_list: 282 | var = variables.get(col.strip(), None) 283 | if type(var) in (int, str): 284 | res = str(var) 285 | elif type(var) is float: 286 | res = "{:0.4f}".format(var) 287 | else: 288 | assert var is None 289 | res = "" 290 | formatted.append(res.rjust(len(col))) 291 | print_columns(formatted, is_final_entry=is_final_entry) 292 | 293 | ############################################ 294 | # Evaluation # 295 | ############################################ 296 | 297 | def infer(model, loader, tta_level=0): 298 | 299 | # Test-time augmentation strategy (for tta_level=2): 300 | # 1. Flip/mirror the image left-to-right (50% of the time). 301 | # 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time, 302 | # i.e. both happen 25% of the time). 303 | # 304 | # This creates 6 views per image (left/right times the two translations and no-translation), 305 | # which we evaluate and then weight according to the given probabilities. 
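    #
    # Concretely (reading off the code below, not additional configuration): the identity and
    # mirrored views each get weight 0.25, and the four translated views (two one-pixel shifts,
    # each with and without mirroring) each get weight 0.125. The weights sum to 1, so the
    # blended logits stay on the scale of a single forward pass.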
306 | 307 | def infer_basic(inputs, net): 308 | return net(inputs).clone() 309 | 310 | def infer_mirror(inputs, net): 311 | return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1)) 312 | 313 | def infer_mirror_translate(inputs, net): 314 | logits = infer_mirror(inputs, net) 315 | pad = 1 316 | padded_inputs = F.pad(inputs, (pad,)*4, "reflect") 317 | inputs_translate_list = [ 318 | padded_inputs[:, :, 0:32, 0:32], 319 | padded_inputs[:, :, 2:34, 2:34], 320 | ] 321 | logits_translate_list = [infer_mirror(inputs_translate, net) 322 | for inputs_translate in inputs_translate_list] 323 | logits_translate = torch.stack(logits_translate_list).mean(0) 324 | return 0.5 * logits + 0.5 * logits_translate 325 | 326 | model.eval() 327 | test_images = loader.normalize(loader.images) 328 | infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level] 329 | with torch.no_grad(): 330 | return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)]) 331 | 332 | def evaluate(model, loader, tta_level=0): 333 | logits = infer(model, loader, tta_level) 334 | return (logits.argmax(1) == loader.labels).float().mean().item() 335 | 336 | ############################################ 337 | # Training # 338 | ############################################ 339 | 340 | def main(run, model): 341 | 342 | batch_size = 2000 343 | bias_lr = 0.053 344 | head_lr = 0.67 345 | wd = 2e-6 * batch_size 346 | 347 | test_loader = CifarLoader("cifar10", train=False, batch_size=2000) 348 | train_loader = CifarLoader("cifar10", train=True, batch_size=batch_size, aug=dict(flip=True, translate=2)) 349 | if run == "warmup": 350 | # The only purpose of the first run is to warmup the compiled model, so we can use dummy data 351 | train_loader.labels = torch.randint(0, 10, size=(len(train_loader.labels),), device=train_loader.labels.device) 352 | total_train_steps = ceil(8 * len(train_loader)) 353 | whiten_bias_train_steps = ceil(3 * len(train_loader)) 354 | 355 | # Create optimizers and learning rate schedulers 356 | filter_params = [p for p in model.parameters() if len(p.shape) == 4 and p.requires_grad] 357 | norm_biases = [p for n, p in model.named_parameters() if "norm" in n and p.requires_grad] 358 | param_configs = [dict(params=[model.whiten.bias], lr=bias_lr, weight_decay=wd/bias_lr), 359 | dict(params=norm_biases, lr=bias_lr, weight_decay=wd/bias_lr), 360 | dict(params=[model.head.weight], lr=head_lr, weight_decay=wd/head_lr)] 361 | optimizer1 = torch.optim.SGD(param_configs, momentum=0.85, nesterov=True, fused=True) 362 | optimizer2 = Muon(filter_params, lr=0.24, momentum=0.6, nesterov=True) 363 | optimizers = [optimizer1, optimizer2] 364 | for opt in optimizers: 365 | for group in opt.param_groups: 366 | group["initial_lr"] = group["lr"] 367 | 368 | # For accurately timing GPU code 369 | starter = torch.cuda.Event(enable_timing=True) 370 | ender = torch.cuda.Event(enable_timing=True) 371 | time_seconds = 0.0 372 | def start_timer(): 373 | starter.record() 374 | def stop_timer(): 375 | ender.record() 376 | torch.cuda.synchronize() 377 | nonlocal time_seconds 378 | time_seconds += 1e-3 * starter.elapsed_time(ender) 379 | 380 | model.reset() 381 | step = 0 382 | 383 | # Initialize the whitening layer using training images 384 | start_timer() 385 | train_images = train_loader.normalize(train_loader.images[:5000]) 386 | model.init_whiten(train_images) 387 | stop_timer() 388 | 389 | for epoch in range(ceil(total_train_steps / len(train_loader))): 390 | 391 | #################### 392 | # Training # 393 | 
#################### 394 | 395 | start_timer() 396 | model.train() 397 | for inputs, labels in train_loader: 398 | outputs = model(inputs, whiten_bias_grad=(step < whiten_bias_train_steps)) 399 | F.cross_entropy(outputs, labels, label_smoothing=0.2, reduction="sum").backward() 400 | for group in optimizer1.param_groups[:1]: 401 | group["lr"] = group["initial_lr"] * (1 - step / whiten_bias_train_steps) 402 | for group in optimizer1.param_groups[1:]+optimizer2.param_groups: 403 | group["lr"] = group["initial_lr"] * (1 - step / total_train_steps) 404 | for opt in optimizers: 405 | opt.step() 406 | model.zero_grad(set_to_none=True) 407 | step += 1 408 | if step >= total_train_steps: 409 | break 410 | stop_timer() 411 | 412 | #################### 413 | # Evaluation # 414 | #################### 415 | 416 | # Save the accuracy and loss from the last training batch of the epoch 417 | train_acc = (outputs.detach().argmax(1) == labels).float().mean().item() 418 | val_acc = evaluate(model, test_loader, tta_level=0) 419 | print_training_details(locals(), is_final_entry=False) 420 | run = None # Only print the run number once 421 | 422 | #################### 423 | # TTA Evaluation # 424 | #################### 425 | 426 | start_timer() 427 | tta_val_acc = evaluate(model, test_loader, tta_level=2) 428 | stop_timer() 429 | epoch = "eval" 430 | print_training_details(locals(), is_final_entry=True) 431 | 432 | return tta_val_acc 433 | 434 | if __name__ == "__main__": 435 | 436 | # We re-use the compiled model between runs to save the non-data-dependent compilation time 437 | model = CifarNet().cuda().to(memory_format=torch.channels_last) 438 | model.compile(mode="max-autotune") 439 | 440 | print_columns(logging_columns_list, is_head=True) 441 | main("warmup", model) 442 | accs = torch.tensor([main(run, model) for run in range(200)]) 443 | print("Mean: %.4f Std: %.4f" % (accs.mean(), accs.std())) 444 | 445 | log_dir = os.path.join("logs", str(uuid.uuid4())) 446 | os.makedirs(log_dir, exist_ok=True) 447 | log_path = os.path.join(log_dir, "log.pt") 448 | torch.save(dict(code=code, accs=accs), log_path) 449 | print(os.path.abspath(log_path)) 450 | 451 | -------------------------------------------------------------------------------- /airbench96_faster.py: -------------------------------------------------------------------------------- 1 | # A variant of airbench optimized for time-to-96%. 2 | # 27.3s runtime on an A100; 3.1 PFLOPs. 3 | # Evidence: 96.00 average accuracy in n=200 runs. 
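#
# Relative to legacy/airbench96.py, the main new method here is loss-based data filtering: we first
# train a small proxy model (train_proxy below), and at each step of the main run we train only on
# the 'batch_size_masked' = 512 examples (out of each batch of 1024) to which the proxy assigned
# the highest loss.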
4 | # 5 | # We recorded the above runtime on an NVIDIA A100-SXM4-40GB with the following nvidia-smi: 6 | # NVIDIA-SMI 515.105.01 Driver Version: 515.105.01 CUDA Version: 11.7 7 | # torch.__version__ == '2.4.0+cu121' 8 | 9 | ############################################# 10 | # Setup/Hyperparameters # 11 | ############################################# 12 | 13 | import os 14 | import sys 15 | import uuid 16 | from math import ceil 17 | 18 | import torch 19 | from torch import nn 20 | import torch.nn.functional as F 21 | import torchvision 22 | import torchvision.transforms as T 23 | 24 | torch.backends.cudnn.benchmark = True 25 | 26 | hyp = { 27 | 'opt': { 28 | 'train_epochs': 45.0, 29 | 'batch_size': 1024, 30 | 'batch_size_masked': 512, 31 | 'lr': 9.0, # learning rate per 1024 examples 32 | 'momentum': 0.85, 33 | 'weight_decay': 0.012, # weight decay per 1024 examples (decoupled from learning rate) 34 | 'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases 35 | 'label_smoothing': 0.2, 36 | 'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing 37 | }, 38 | 'aug': { 39 | 'flip': True, 40 | 'translate': 4, 41 | 'cutout': 12, 42 | }, 43 | 'proxy': { 44 | 'widths': { 45 | 'block1': 32, 46 | 'block2': 64, 47 | 'block3': 64, 48 | }, 49 | 'depth': 2, 50 | 'scaling_factor': 1/9, 51 | }, 52 | 'net': { 53 | 'widths': { 54 | 'block1': 128, 55 | 'block2': 384, 56 | 'block3': 512, 57 | }, 58 | 'depth': 3, 59 | 'scaling_factor': 1/9, 60 | 'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate 61 | }, 62 | } 63 | 64 | ############################################# 65 | # DataLoader # 66 | ############################################# 67 | 68 | CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465)) 69 | CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616)) 70 | 71 | def batch_flip_lr(inputs): 72 | flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1) 73 | return torch.where(flip_mask, inputs.flip(-1), inputs) 74 | 75 | def batch_crop(images, crop_size): 76 | r = (images.size(-1) - crop_size)//2 77 | shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device) 78 | images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype) 79 | # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2. 
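    # (The first method does one masked slice-assignment per 2D shift, i.e. (2r+1)^2 of them; the
    # second decomposes the translation into a vertical shift followed by a horizontal shift, which
    # needs only 2*(2r+1) slice-assignments.)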
80 |     if r <= 2:
81 |         for sy in range(-r, r+1):
82 |             for sx in range(-r, r+1):
83 |                 mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx)
84 |                 images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size]
85 |     else:
86 |         images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype)
87 |         for s in range(-r, r+1):
88 |             mask = (shifts[:, 0] == s)
89 |             images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :]
90 |         for s in range(-r, r+1):
91 |             mask = (shifts[:, 1] == s)
92 |             images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size]
93 |     return images_out
94 | 
95 | def make_random_square_masks(inputs, size):
96 |     is_even = int(size % 2 == 0)
97 |     n,c,h,w = inputs.shape
98 | 
99 |     # sample the top-left corner of each cutout box, one coordinate per spatial dimension
100 |     corner_y = torch.randint(0, h-size+1, size=(n,), device=inputs.device)
101 |     corner_x = torch.randint(0, w-size+1, size=(n,), device=inputs.device)
102 | 
103 |     # compute each pixel's offset from the sampled corner
104 |     corner_y_dists = torch.arange(h, device=inputs.device).view(1, 1, h, 1) - corner_y.view(-1, 1, 1, 1)
105 |     corner_x_dists = torch.arange(w, device=inputs.device).view(1, 1, 1, w) - corner_x.view(-1, 1, 1, 1)
106 | 
107 |     mask_y = (corner_y_dists >= 0) * (corner_y_dists < size)
108 |     mask_x = (corner_x_dists >= 0) * (corner_x_dists < size)
109 | 
110 |     final_mask = mask_y * mask_x
111 | 
112 |     return final_mask
113 | 
114 | def batch_cutout(inputs, size):
115 |     cutout_masks = make_random_square_masks(inputs, size)
116 |     return inputs.masked_fill(cutout_masks, 0)
117 | 
118 | def set_random_state(seed, state):
119 |     if seed is None:
120 |         # If we don't get a data seed, then make sure to randomize the state using an independent generator, since
121 |         # it might have already been set by the model seed.
122 |         import random
123 |         torch.manual_seed(random.randint(0, 2**63))
124 |     else:
125 |         seed1 = 1000 * seed + state # just don't do more than 1000 epochs or else there will be overlap
126 |         torch.manual_seed(seed1)
127 | 
128 | class InfiniteCifarLoader:
129 |     """
130 |     CIFAR-10 loader which constructs, during the call to __iter__, every input that will be used during training.
131 |     The purpose is to support cross-epoch batches (in case the batch size does not divide the number of train examples),
132 |     and to support stochastic iteration counts in order to preserve perfect linearity/independence.
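
    A hypothetical usage sketch (illustrative only, though the values mirror the hyp dict above):

        loader = InfiniteCifarLoader('cifar10', train=True, batch_size=1024,
                                     aug=dict(flip=True, translate=4, cutout=12),
                                     aug_seed=0, order_seed=0)
        for indices, inputs, labels in loader:  # iterates forever; the caller must break
            ...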
133 | """ 134 | 135 | def __init__(self, path, train=True, batch_size=500, aug=None, altflip=True, subset_mask=None, aug_seed=None, order_seed=None): 136 | data_path = os.path.join(path, 'train.pt' if train else 'test.pt') 137 | if not os.path.exists(data_path): 138 | dset = torchvision.datasets.CIFAR10(path, download=True, train=train) 139 | images = torch.tensor(dset.data) 140 | labels = torch.tensor(dset.targets) 141 | torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path) 142 | 143 | data = torch.load(data_path, map_location='cuda') 144 | self.images, self.labels, self.classes = data['images'], data['labels'], data['classes'] 145 | # It's faster to load+process uint8 data than to load preprocessed fp16 data 146 | self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last) 147 | 148 | self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD) 149 | 150 | self.aug = aug or {} 151 | for k in self.aug.keys(): 152 | assert k in ['flip', 'translate', 'cutout'], 'Unrecognized key: %s' % k 153 | 154 | self.batch_size = batch_size 155 | self.altflip = altflip 156 | self.subset_mask = subset_mask if subset_mask is not None else torch.tensor([True]*len(self.images)).cuda() 157 | self.train = train 158 | self.aug_seed = aug_seed 159 | self.order_seed = order_seed 160 | 161 | def __iter__(self): 162 | 163 | # Preprocess 164 | images0 = self.normalize(self.images) 165 | # Pre-randomly flip images in order to do alternating flip later. 166 | if self.aug.get('flip', False) and self.altflip: 167 | set_random_state(self.aug_seed, 0) 168 | images0 = batch_flip_lr(images0) 169 | # Pre-pad images to save time when doing random translation 170 | pad = self.aug.get('translate', 0) 171 | if pad > 0: 172 | images0 = F.pad(images0, (pad,)*4, 'reflect') 173 | labels0 = self.labels 174 | 175 | # Iterate forever 176 | epoch = 0 177 | batch_size = self.batch_size 178 | 179 | # In the below while-loop, we will repeatedly build a batch and then yield it. 180 | num_examples = self.subset_mask.sum().item() 181 | current_pointer = num_examples 182 | batch_images = torch.empty(0, 3, 32, 32, dtype=images0.dtype, device=images0.device) 183 | batch_labels = torch.empty(0, dtype=labels0.dtype, device=labels0.device) 184 | batch_indices = torch.empty(0, dtype=labels0.dtype, device=labels0.device) 185 | 186 | while True: 187 | 188 | # Assume we need to generate more data to add to the batch. 189 | assert len(batch_images) < batch_size 190 | 191 | # If we have already exhausted the current epoch, then begin a new one. 192 | if current_pointer >= num_examples: 193 | # If we already reached the end of the last epoch then we need to generate 194 | # a new augmented epoch of data (using random crop and alternating flip). 195 | epoch += 1 196 | 197 | set_random_state(self.aug_seed, epoch) 198 | if pad > 0: 199 | images1 = batch_crop(images0, 32) 200 | if self.aug.get('flip', False): 201 | if self.altflip: 202 | images1 = images1 if epoch % 2 == 0 else images1.flip(-1) 203 | else: 204 | images1 = batch_flip_lr(images1) 205 | if self.aug.get('cutout', 0) > 0: 206 | images1 = batch_cutout(images1, self.aug['cutout']) 207 | 208 | set_random_state(self.order_seed, epoch) 209 | indices = (torch.randperm if self.train else torch.arange)(len(self.images), device=images0.device) 210 | 211 | # The effect of doing subsetting in this manner is as follows. 
If the permutation wants to show us
212 |                 # our four examples in order [3, 2, 0, 1], and the subset mask is [True, False, True, False],
213 |                 # then we will be shown the examples [2, 0]. That is, we see the subset of the ordering.
214 |                 # The purpose is to minimize the interaction between the subset mask and the randomness,
215 |                 # so that the subset causes not only a subset of the total examples to be shown, but also a subset
216 |                 # of the actual sequence of examples which is shown during training.
217 |                 indices_subset = indices[self.subset_mask[indices]]
218 |                 current_pointer = 0
219 | 
220 |             # Now we are sure to have more data in this epoch remaining.
221 |             # This epoch's remaining data is given by (images1[current_pointer:], labels0[current_pointer:])
222 |             # We add more data to the batch, up to whatever is needed to make a full batch (but it might not be enough).
223 |             remaining_size = batch_size - len(batch_images)
224 | 
225 |             # Given that we want `remaining_size` more training examples, we construct them here, using
226 |             # the remaining available examples in the epoch.
227 | 
228 |             extra_indices = indices_subset[current_pointer:current_pointer+remaining_size]
229 |             extra_images = images1[extra_indices]
230 |             extra_labels = labels0[extra_indices]
231 |             current_pointer += remaining_size
232 |             batch_indices = torch.cat([batch_indices, extra_indices])
233 |             batch_images = torch.cat([batch_images, extra_images])
234 |             batch_labels = torch.cat([batch_labels, extra_labels])
235 | 
236 |             # If we have a full batch ready then yield it and reset.
237 |             if len(batch_images) == batch_size:
238 |                 assert len(batch_images) == len(batch_labels)
239 |                 yield (batch_indices, batch_images, batch_labels)
240 |                 batch_images = torch.empty(0, 3, 32, 32, dtype=images0.dtype, device=images0.device)
241 |                 batch_labels = torch.empty(0, dtype=labels0.dtype, device=labels0.device)
242 |                 batch_indices = torch.empty(0, dtype=labels0.dtype, device=labels0.device)
243 | 
244 | ############################################
245 | #                Evaluation                #
246 | ############################################
247 | 
248 | def infer(model, loader, tta_level=0):
249 | 
250 |     # Test-time augmentation strategy (for tta_level=2):
251 |     # 1. Flip/mirror the image left-to-right (50% of the time).
252 |     # 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time,
253 |     #    i.e. both happen 25% of the time).
254 |     #
255 |     # This creates 6 views per image (left/right times the two translations and no-translation),
256 |     # which we evaluate and then weight according to the given probabilities.
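    #
    # (Concretely, as in the other scripts: the identity and mirrored views get weight 0.25 each, and
    # the four translated views get weight 0.125 each; see infer_mirror_translate below.)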
257 | 258 | def infer_basic(inputs, net): 259 | return net(inputs).clone() 260 | 261 | def infer_mirror(inputs, net): 262 | return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1)) 263 | 264 | def infer_mirror_translate(inputs, net): 265 | logits = infer_mirror(inputs, net) 266 | pad = 1 267 | padded_inputs = F.pad(inputs, (pad,)*4, 'reflect') 268 | inputs_translate_list = [ 269 | padded_inputs[:, :, 0:32, 0:32], 270 | padded_inputs[:, :, 2:34, 2:34], 271 | ] 272 | logits_translate_list = [infer_mirror(inputs_translate, net) 273 | for inputs_translate in inputs_translate_list] 274 | logits_translate = torch.stack(logits_translate_list).mean(0) 275 | return 0.5 * logits + 0.5 * logits_translate 276 | 277 | model.eval() 278 | test_images = loader.normalize(loader.images) 279 | infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level] 280 | with torch.no_grad(): 281 | return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)]) 282 | 283 | def evaluate(model, loader, tta_level=0): 284 | logits = infer(model, loader, tta_level) 285 | return (logits.argmax(1) == loader.labels).float().mean().item() 286 | 287 | ############################################# 288 | # Network Components # 289 | ############################################# 290 | 291 | class Flatten(nn.Module): 292 | def forward(self, x): 293 | return x.view(x.size(0), -1) 294 | 295 | class Mul(nn.Module): 296 | def __init__(self, scale): 297 | super().__init__() 298 | self.scale = scale 299 | def forward(self, x): 300 | return x * self.scale 301 | 302 | class BatchNorm(nn.BatchNorm2d): 303 | def __init__(self, num_features, eps=1e-12, 304 | weight=False, bias=True): 305 | super().__init__(num_features, eps=eps) 306 | self.weight.requires_grad = weight 307 | self.bias.requires_grad = bias 308 | # Note that PyTorch already initializes the weights to one and bias to zero 309 | 310 | class Conv(nn.Conv2d): 311 | def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False): 312 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 313 | 314 | def reset_parameters(self): 315 | super().reset_parameters() 316 | if self.bias is not None: 317 | self.bias.data.zero_() 318 | w = self.weight.data 319 | torch.nn.init.dirac_(w[:w.size(1)]) 320 | 321 | class ConvGroup(nn.Module): 322 | def __init__(self, channels_in, channels_out, depth): 323 | super().__init__() 324 | assert depth in (2, 3) 325 | self.depth = depth 326 | self.conv1 = Conv(channels_in, channels_out) 327 | self.pool = nn.MaxPool2d(2) 328 | self.norm1 = BatchNorm(channels_out) 329 | self.conv2 = Conv(channels_out, channels_out) 330 | self.norm2 = BatchNorm(channels_out) 331 | if depth == 3: 332 | self.conv3 = Conv(channels_out, channels_out) 333 | self.norm3 = BatchNorm(channels_out) 334 | self.activ = nn.GELU() 335 | 336 | def forward(self, x): 337 | x = self.conv1(x) 338 | x = self.pool(x) 339 | x = self.norm1(x) 340 | x = self.activ(x) 341 | if self.depth == 3: 342 | x0 = x 343 | x = self.conv2(x) 344 | x = self.norm2(x) 345 | x = self.activ(x) 346 | if self.depth == 3: 347 | x = self.conv3(x) 348 | x = self.norm3(x) 349 | x = x + x0 350 | x = self.activ(x) 351 | return x 352 | 353 | ############################################# 354 | # Network Definition # 355 | ############################################# 356 | 357 | def make_net(hyp): 358 | widths = hyp['widths'] 359 | scaling_factor = hyp['scaling_factor'] 360 | depth = hyp['depth'] 361 | whiten_kernel_size = 2 362 | 
whiten_width = 2 * 3 * whiten_kernel_size**2 363 | net = nn.Sequential( 364 | Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True), 365 | nn.GELU(), 366 | ConvGroup(whiten_width, widths['block1'], depth), 367 | ConvGroup(widths['block1'], widths['block2'], depth), 368 | ConvGroup(widths['block2'], widths['block3'], depth), 369 | nn.MaxPool2d(3), 370 | Flatten(), 371 | nn.Linear(widths['block3'], 10, bias=False), 372 | Mul(scaling_factor), 373 | ) 374 | net[0].weight.requires_grad = False 375 | net = net.half().cuda() 376 | net = net.to(memory_format=torch.channels_last) 377 | for mod in net.modules(): 378 | if isinstance(mod, BatchNorm): 379 | mod.float() 380 | return net 381 | 382 | def reinit_net(model): 383 | for m in model.modules(): 384 | if type(m) in (Conv, BatchNorm, nn.Linear): 385 | m.reset_parameters() 386 | 387 | ############################################# 388 | # Whitening Conv Initialization # 389 | ############################################# 390 | 391 | def get_patches(x, patch_shape): 392 | c, (h, w) = x.shape[1], patch_shape 393 | return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float() 394 | 395 | def get_whitening_parameters(patches): 396 | n,c,h,w = patches.shape 397 | patches_flat = patches.view(n, -1) 398 | est_patch_covariance = (patches_flat.T @ patches_flat) / n 399 | eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U') 400 | return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w,c,h,w).flip(0) 401 | 402 | def init_whitening_conv(layer, train_set, eps=5e-4): 403 | patches = get_patches(train_set, patch_shape=layer.weight.data.shape[2:]) 404 | eigenvalues, eigenvectors = get_whitening_parameters(patches) 405 | eigenvectors_scaled = eigenvectors / torch.sqrt(eigenvalues + eps) 406 | layer.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled)) 407 | 408 | ############################################ 409 | # Lookahead # 410 | ############################################ 411 | 412 | class LookaheadState: 413 | def __init__(self, net): 414 | self.net_ema = {k: v.clone() for k, v in net.state_dict().items()} 415 | 416 | def update(self, net, decay): 417 | for ema_param, net_param in zip(self.net_ema.values(), net.state_dict().values()): 418 | if net_param.dtype in (torch.half, torch.float): 419 | ema_param.lerp_(net_param, 1-decay) 420 | net_param.copy_(ema_param) 421 | 422 | ############################################ 423 | # Logging # 424 | ############################################ 425 | 426 | def print_columns(columns_list, is_head=False, is_final_entry=False): 427 | print_string = '' 428 | for col in columns_list: 429 | print_string += '| %s ' % col 430 | print_string += '|' 431 | if is_head: 432 | print('-'*len(print_string)) 433 | print(print_string) 434 | if is_head or is_final_entry: 435 | print('-'*len(print_string)) 436 | 437 | logging_columns_list = ['run ', 'epoch', 'train_loss', 'train_acc', 'val_acc', 'tta_val_acc', 'total_time_seconds'] 438 | def print_training_details(variables, is_final_entry): 439 | formatted = [] 440 | for col in logging_columns_list: 441 | var = variables.get(col.strip(), None) 442 | if type(var) in (int, str): 443 | res = str(var) 444 | elif type(var) is float: 445 | res = '{:0.4f}'.format(var) 446 | else: 447 | assert var is None 448 | res = '' 449 | formatted.append(res.rjust(len(col))) 450 | print_columns(formatted, is_final_entry=is_final_entry) 451 | 452 | ############################################ 453 | # Training # 454 | 
############################################
455 | 
456 | def train_proxy(hyp, model, data_seed):
457 | 
458 |     batch_size = hyp['opt']['batch_size']
459 |     epochs = hyp['opt']['train_epochs']
460 |     momentum = hyp['opt']['momentum']
461 |     kilostep_scale = 1024 * (1 + 1 / (1 - momentum))
462 |     lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD
463 |     wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale
464 |     lr_biases = lr * hyp['opt']['bias_scaler']
465 | 
466 |     loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none')
467 |     train_loader = InfiniteCifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug'],
468 |                                        aug_seed=data_seed, order_seed=data_seed)
469 |     steps_per_epoch = len(train_loader.images) // batch_size
470 |     total_train_steps = ceil(steps_per_epoch * epochs)
471 | 
472 |     set_random_state(None, 0)
473 |     reinit_net(model)
474 |     print('Proxy parameters:', sum(p.numel() for p in model.parameters()))
475 |     current_steps = 0
476 | 
477 |     norm_biases = [p for k, p in model.named_parameters() if 'norm' in k and p.requires_grad]
478 |     other_params = [p for k, p in model.named_parameters() if 'norm' not in k and p.requires_grad]
479 |     param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases),
480 |                      dict(params=other_params, lr=lr, weight_decay=wd/lr)]
481 |     optimizer = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True)
482 | 
483 |     def get_lr(step):
484 |         warmup_steps = int(total_train_steps * 0.1)
485 |         warmdown_steps = total_train_steps - warmup_steps
486 |         if step < warmup_steps:
487 |             frac = step / warmup_steps
488 |             return 0.2 * (1 - frac) + 1.0 * frac
489 |         else:
490 |             frac = (total_train_steps - step) / warmdown_steps
491 |             return frac
492 |     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)
493 | 
494 |     # Initialize the whitening layer using training images
495 |     train_images = train_loader.normalize(train_loader.images[:5000])
496 |     init_whitening_conv(model._orig_mod[0], train_images)
497 | 
498 |     masks = []
499 | 
500 |     for indices, inputs, labels in train_loader:
501 | 
502 |         if current_steps % steps_per_epoch == 0:
503 |             epoch = current_steps // steps_per_epoch
504 |             model.train()
505 | 
506 |         # Do the backward pass only on every fourth batch; the other batches just compute the losses used to build the masks
507 |         if current_steps % 4 == 0:
508 |             outputs = model(inputs)
509 |             loss1 = loss_fn(outputs, labels)
510 |             mask = torch.zeros(len(inputs)).cuda().bool()
511 |             mask[loss1.argsort()[-hyp['opt']['batch_size_masked']:]] = True
512 |             masks.append(mask)
513 |             loss = (loss1 * mask.float()).sum()
514 |             optimizer.zero_grad(set_to_none=True)
515 |             loss.backward()
516 |             optimizer.step()
517 |         else:
518 |             with torch.no_grad():
519 |                 outputs = model(inputs)
520 |                 loss1 = loss_fn(outputs, labels)
521 |                 mask = torch.zeros(len(inputs)).cuda().bool()
522 |                 mask[loss1.argsort()[-hyp['opt']['batch_size_masked']:]] = True
523 |                 masks.append(mask)
524 | 
525 |         scheduler.step()
526 | 
527 |         current_steps += 1
528 |         if current_steps == total_train_steps:
529 |             break
530 | 
531 |     return masks
532 | 
533 | def main(run, hyp, model_proxy, model_trainbias, model_freezebias):
534 | 
535 |     batch_size = hyp['opt']['batch_size']
536 |     epochs = hyp['opt']['train_epochs']
537 |     momentum = hyp['opt']['momentum']
538 |     # Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much
539 |     # larger the default steps will be than the underlying per-example gradients.
We divide the 540 | # learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless 541 | # of the choice of momentum. 542 | kilostep_scale = 1024 * (1 + 1 / (1 - momentum)) 543 | lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD 544 | wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale 545 | lr_biases = lr * hyp['opt']['bias_scaler'] 546 | 547 | set_random_state(None, 0) 548 | import random 549 | data_seed = random.randint(0, 2**50) 550 | 551 | loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none') 552 | test_loader = InfiniteCifarLoader('cifar10', train=False, batch_size=2000) 553 | train_loader = InfiniteCifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug'], 554 | aug_seed=data_seed, order_seed=data_seed) 555 | steps_per_epoch = len(train_loader.images) // batch_size 556 | total_train_steps = ceil(steps_per_epoch * epochs) 557 | 558 | set_random_state(None, 0) 559 | reinit_net(model_trainbias) 560 | print('Main model parameters:', sum(p.numel() for p in model_trainbias.parameters())) 561 | current_steps = 0 562 | 563 | norm_biases = [p for k, p in model_trainbias.named_parameters() if 'norm' in k] 564 | other_params = [p for k, p in model_trainbias.named_parameters() if 'norm' not in k] 565 | param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases), 566 | dict(params=other_params, lr=lr, weight_decay=wd/lr)] 567 | optimizer_trainbias = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True) 568 | 569 | norm_biases = [p for k, p in model_freezebias.named_parameters() if 'norm' in k] 570 | other_params = [p for k, p in model_freezebias.named_parameters() if 'norm' not in k] 571 | param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases), 572 | dict(params=other_params, lr=lr, weight_decay=wd/lr)] 573 | optimizer_freezebias = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True) 574 | 575 | def get_lr(step): 576 | warmup_steps = int(total_train_steps * 0.1) 577 | warmdown_steps = total_train_steps - warmup_steps 578 | if step < warmup_steps: 579 | frac = step / warmup_steps 580 | return 0.2 * (1 - frac) + 1.0 * frac 581 | else: 582 | frac = (total_train_steps - step) / warmdown_steps 583 | return frac 584 | scheduler_trainbias = torch.optim.lr_scheduler.LambdaLR(optimizer_trainbias, get_lr) 585 | scheduler_freezebias = torch.optim.lr_scheduler.LambdaLR(optimizer_freezebias, get_lr) 586 | 587 | alpha_schedule = 0.95**5 * (torch.arange(total_train_steps+1) / total_train_steps)**3 588 | lookahead_state = LookaheadState(model_trainbias) 589 | 590 | # For accurately timing GPU code 591 | starter = torch.cuda.Event(enable_timing=True) 592 | ender = torch.cuda.Event(enable_timing=True) 593 | total_time_seconds = 0.0 594 | 595 | # Initialize the whitening layer using training images 596 | starter.record() 597 | train_images = train_loader.normalize(train_loader.images[:5000]) 598 | init_whitening_conv(model_trainbias._orig_mod[0], train_images) 599 | ender.record() 600 | torch.cuda.synchronize() 601 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 602 | 603 | # Do a small proxy run to collect masks for use in fullsize run 604 | print('Training small proxy...') 605 | starter.record() 606 | masks = iter(train_proxy(hyp, model_proxy, data_seed)) 607 | ender.record() 608 | torch.cuda.synchronize() 609 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 610 | 611 | for indices, 
inputs, labels in train_loader:
612 | 
613 |         # After training the whiten bias for some epochs, swap in the compiled model with frozen bias
614 |         if current_steps == 0:
615 |             model = model_trainbias
616 |             optimizer = optimizer_trainbias
617 |             scheduler = scheduler_trainbias
618 |         elif current_steps == hyp['opt']['whiten_bias_epochs'] * steps_per_epoch:
619 |             model = model_freezebias
620 |             optimizer = optimizer_freezebias
621 |             scheduler = scheduler_freezebias
622 |             model.load_state_dict(model_trainbias.state_dict())
623 |             optimizer.load_state_dict(optimizer_trainbias.state_dict())
624 |             scheduler.load_state_dict(scheduler_trainbias.state_dict())
625 | 
626 |         ####################
627 |         #     Training     #
628 |         ####################
629 | 
630 |         if current_steps % steps_per_epoch == 0:
631 |             epoch = current_steps // steps_per_epoch
632 |             starter.record()
633 |             model.train()
634 | 
635 |         mask = next(masks)
636 |         inputs = inputs[mask]
637 |         labels = labels[mask]
638 |         outputs = model(inputs)
639 |         loss = loss_fn(outputs, labels).sum()
640 | 
641 |         optimizer.zero_grad(set_to_none=True)
642 |         loss.backward()
643 |         optimizer.step()
644 |         scheduler.step()
645 | 
646 |         current_steps += 1
647 |         if current_steps % 5 == 0:
648 |             lookahead_state.update(model, decay=alpha_schedule[current_steps].item())
649 |         if current_steps == total_train_steps:
650 |             if lookahead_state is not None:
651 |                 lookahead_state.update(model, decay=1.0)
652 | 
653 |         if (current_steps % steps_per_epoch == 0) or (current_steps == total_train_steps):
654 |             ender.record()
655 |             torch.cuda.synchronize()
656 |             total_time_seconds += 1e-3 * starter.elapsed_time(ender)
657 | 
658 |             ####################
659 |             #    Evaluation    #
660 |             ####################
661 | 
662 |             # Save the accuracy and loss from the last training batch of the epoch
663 |             train_acc = (outputs.detach().argmax(1) == labels).float().mean().item()
664 |             train_loss = loss.item() / batch_size
665 |             val_acc = evaluate(model, test_loader, tta_level=0)
666 |             print_training_details(locals(), is_final_entry=False)
667 |             run = None # Only print the run number once
668 | 
669 |         if current_steps == total_train_steps:
670 |             break
671 | 
672 |     ####################
673 |     #  TTA Evaluation  #
674 |     ####################
675 | 
676 |     starter.record()
677 |     tta_val_acc = evaluate(model, test_loader, tta_level=hyp['net']['tta_level'])
678 |     ender.record()
679 |     torch.cuda.synchronize()
680 |     total_time_seconds += 1e-3 * starter.elapsed_time(ender)
681 | 
682 |     epoch = 'eval'
683 |     print_training_details(locals(), is_final_entry=True)
684 | 
685 |     return tta_val_acc
686 | 
687 | if __name__ == "__main__":
688 |     with open(sys.argv[0]) as f:
689 |         code = f.read()
690 | 
691 |     model_proxy = make_net(hyp['proxy'])
692 |     model_proxy[0].bias.requires_grad = False
693 |     model_trainbias = make_net(hyp['net'])
694 |     model_freezebias = make_net(hyp['net'])
695 |     model_freezebias[0].bias.requires_grad = False
696 |     model_proxy = torch.compile(model_proxy, mode='max-autotune')
697 |     model_trainbias = torch.compile(model_trainbias, mode='max-autotune')
698 |     model_freezebias = torch.compile(model_freezebias, mode='max-autotune')
699 | 
700 |     print_columns(logging_columns_list, is_head=True)
701 |     accs = torch.tensor([main(run, hyp, model_proxy, model_trainbias, model_freezebias)
702 |                          for run in range(200)])
703 |     print('Mean: %.4f Std: %.4f' % (accs.mean(), accs.std()))
704 | 
705 |     log = {'code': code, 'accs': accs}
706 |     log_dir = os.path.join('logs', str(uuid.uuid4()))
707 |     os.makedirs(log_dir, exist_ok=True)
708 |     log_path = os.path.join(log_dir, 'log.pt')
709 |     print(os.path.abspath(log_path))
710 |     torch.save(log, log_path)
711 | 
712 | 
--------------------------------------------------------------------------------
/img/airbench94_intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/cifar10-airbench/f1c599c4af24aca803e3e65a6a7e52a502e09d0e/img/airbench94_intro.png
--------------------------------------------------------------------------------
/img/alternating_flip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KellerJordan/cifar10-airbench/f1c599c4af24aca803e3e65a6a7e52a502e09d0e/img/alternating_flip.png
--------------------------------------------------------------------------------
/legacy/airbench94.py:
--------------------------------------------------------------------------------
1 | # Uncompiled variant of airbench94_compiled.py
2 | # 3.83s runtime on an A100; 0.36 PFLOPs.
3 | # Evidence: 94.01 average accuracy in n=1000 runs.
4 | #
5 | # We recorded the runtime of 3.83 seconds on an NVIDIA A100-SXM4-80GB with the following nvidia-smi:
6 | # NVIDIA-SMI 515.105.01 Driver Version: 515.105.01 CUDA Version: 11.7
7 | # torch.__version__ == '2.1.2+cu118'
8 | 
9 | #############################################
10 | #           Setup/Hyperparameters           #
11 | #############################################
12 | 
13 | import os
14 | import sys
15 | import uuid
16 | from math import ceil
17 | 
18 | import torch
19 | from torch import nn
20 | import torch.nn.functional as F
21 | import torchvision
22 | import torchvision.transforms as T
23 | 
24 | torch.backends.cudnn.benchmark = True
25 | 
26 | # We express the main training hyperparameters (batch size, learning rate, momentum, and weight decay)
27 | # in decoupled form, so that each one can be tuned independently. This accomplishes the following:
28 | # * Assuming time-constant gradients, the average step size is decoupled from everything but the lr.
29 | # * The size of the weight decay update is decoupled from everything but the wd.
30 | # In contrast, normally when we increase the (Nesterov) momentum, this also scales up the step size
31 | # proportionally to 1 + 1 / (1 - momentum), meaning we cannot change momentum without having to re-tune
32 | # the learning rate. Similarly, normally when we increase the learning rate this also increases the size
33 | # of the weight decay, requiring a proportional decrease in the wd to maintain the same decay strength.
34 | #
35 | # The practical impact is that hyperparameter tuning is faster, since this parametrization allows each
36 | # one to be tuned independently. See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/.
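#
# As a concrete worked example (reading off the conversion done in main() below, with the hyp values
# lr=11.5, momentum=0.85, weight_decay=0.0153, batch_size=1024):
#     kilostep_scale   = 1024 * (1 + 1 / (1 - 0.85))  ~= 7851
#     SGD lr           = 11.5 / 7851                  ~= 0.00146
#     wd               = 0.0153 * 1024 / 7851         ~= 0.00200
#     SGD weight_decay = wd / lr                      ~= 1.362
# Since PyTorch SGD scales the decay term by lr, the decay applied per step is lr * (wd / lr) = wd,
# independent of the learning rate, as claimed above.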
37 | 38 | hyp = { 39 | 'opt': { 40 | 'train_epochs': 9.9, 41 | 'batch_size': 1024, 42 | 'lr': 11.5, # learning rate per 1024 examples 43 | 'momentum': 0.85, 44 | 'weight_decay': 0.0153, # weight decay per 1024 examples (decoupled from learning rate) 45 | 'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases 46 | 'label_smoothing': 0.2, 47 | 'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing 48 | }, 49 | 'aug': { 50 | 'flip': True, 51 | 'translate': 2, 52 | }, 53 | 'net': { 54 | 'widths': { 55 | 'block1': 64, 56 | 'block2': 256, 57 | 'block3': 256, 58 | }, 59 | 'batchnorm_momentum': 0.6, 60 | 'scaling_factor': 1/9, 61 | 'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate 62 | }, 63 | } 64 | 65 | ############################################# 66 | # DataLoader # 67 | ############################################# 68 | 69 | CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465)) 70 | CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616)) 71 | 72 | def batch_flip_lr(inputs): 73 | flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1) 74 | return torch.where(flip_mask, inputs.flip(-1), inputs) 75 | 76 | def batch_crop(images, crop_size): 77 | r = (images.size(-1) - crop_size)//2 78 | shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device) 79 | images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype) 80 | # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2. 81 | if r <= 2: 82 | for sy in range(-r, r+1): 83 | for sx in range(-r, r+1): 84 | mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx) 85 | images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size] 86 | else: 87 | images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype) 88 | for s in range(-r, r+1): 89 | mask = (shifts[:, 0] == s) 90 | images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :] 91 | for s in range(-r, r+1): 92 | mask = (shifts[:, 1] == s) 93 | images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size] 94 | return images_out 95 | 96 | class CifarLoader: 97 | 98 | def __init__(self, path, train=True, batch_size=500, aug=None, drop_last=None, shuffle=None, gpu=0): 99 | data_path = os.path.join(path, 'train.pt' if train else 'test.pt') 100 | if not os.path.exists(data_path): 101 | dset = torchvision.datasets.CIFAR10(path, download=True, train=train) 102 | images = torch.tensor(dset.data) 103 | labels = torch.tensor(dset.targets) 104 | torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path) 105 | 106 | data = torch.load(data_path, map_location=torch.device(gpu)) 107 | self.images, self.labels, self.classes = data['images'], data['labels'], data['classes'] 108 | # It's faster to load+process uint8 data than to load preprocessed fp16 data 109 | self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last) 110 | 111 | self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD) 112 | self.proc_images = {} # Saved results of image processing to be done on the first epoch 113 | self.epoch = 0 114 | 115 | self.aug = aug or {} 116 | for k in self.aug.keys(): 117 | assert k in ['flip', 'translate'], 'Unrecognized key: %s' % k 118 | 119 | self.batch_size = batch_size 120 | self.drop_last = train if drop_last is None else drop_last 121 | self.shuffle = 
train if shuffle is None else shuffle 122 | 123 | def __len__(self): 124 | return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size) 125 | 126 | def __iter__(self): 127 | 128 | if self.epoch == 0: 129 | images = self.proc_images['norm'] = self.normalize(self.images) 130 | # Pre-flip images in order to do every-other epoch flipping scheme 131 | if self.aug.get('flip', False): 132 | images = self.proc_images['flip'] = batch_flip_lr(images) 133 | # Pre-pad images to save time when doing random translation 134 | pad = self.aug.get('translate', 0) 135 | if pad > 0: 136 | self.proc_images['pad'] = F.pad(images, (pad,)*4, 'reflect') 137 | 138 | if self.aug.get('translate', 0) > 0: 139 | images = batch_crop(self.proc_images['pad'], self.images.shape[-2]) 140 | elif self.aug.get('flip', False): 141 | images = self.proc_images['flip'] 142 | else: 143 | images = self.proc_images['norm'] 144 | # Flip all images together every other epoch. This increases diversity relative to random flipping 145 | if self.aug.get('flip', False): 146 | if self.epoch % 2 == 1: 147 | images = images.flip(-1) 148 | 149 | self.epoch += 1 150 | 151 | indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device) 152 | for i in range(len(self)): 153 | idxs = indices[i*self.batch_size:(i+1)*self.batch_size] 154 | yield (images[idxs], self.labels[idxs]) 155 | 156 | ############################################# 157 | # Network Components # 158 | ############################################# 159 | 160 | class Flatten(nn.Module): 161 | def forward(self, x): 162 | return x.view(x.size(0), -1) 163 | 164 | class Mul(nn.Module): 165 | def __init__(self, scale): 166 | super().__init__() 167 | self.scale = scale 168 | def forward(self, x): 169 | return x * self.scale 170 | 171 | class BatchNorm(nn.BatchNorm2d): 172 | def __init__(self, num_features, momentum, eps=1e-12, 173 | weight=False, bias=True): 174 | super().__init__(num_features, eps=eps, momentum=1-momentum) 175 | self.weight.requires_grad = weight 176 | self.bias.requires_grad = bias 177 | # Note that PyTorch already initializes the weights to one and bias to zero 178 | 179 | class Conv(nn.Conv2d): 180 | def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False): 181 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 182 | 183 | def reset_parameters(self): 184 | super().reset_parameters() 185 | if self.bias is not None: 186 | self.bias.data.zero_() 187 | w = self.weight.data 188 | torch.nn.init.dirac_(w[:w.size(1)]) 189 | 190 | class ConvGroup(nn.Module): 191 | def __init__(self, channels_in, channels_out, batchnorm_momentum): 192 | super().__init__() 193 | self.conv1 = Conv(channels_in, channels_out) 194 | self.pool = nn.MaxPool2d(2) 195 | self.norm1 = BatchNorm(channels_out, batchnorm_momentum) 196 | self.conv2 = Conv(channels_out, channels_out) 197 | self.norm2 = BatchNorm(channels_out, batchnorm_momentum) 198 | self.activ = nn.GELU() 199 | 200 | def forward(self, x): 201 | x = self.conv1(x) 202 | x = self.pool(x) 203 | x = self.norm1(x) 204 | x = self.activ(x) 205 | x = self.conv2(x) 206 | x = self.norm2(x) 207 | x = self.activ(x) 208 | return x 209 | 210 | ############################################# 211 | # Network Definition # 212 | ############################################# 213 | 214 | def make_net(): 215 | widths = hyp['net']['widths'] 216 | batchnorm_momentum = hyp['net']['batchnorm_momentum'] 217 | 
whiten_kernel_size = 2 218 | whiten_width = 2 * 3 * whiten_kernel_size**2 219 | net = nn.Sequential( 220 | Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True), 221 | nn.GELU(), 222 | ConvGroup(whiten_width, widths['block1'], batchnorm_momentum), 223 | ConvGroup(widths['block1'], widths['block2'], batchnorm_momentum), 224 | ConvGroup(widths['block2'], widths['block3'], batchnorm_momentum), 225 | nn.MaxPool2d(3), 226 | Flatten(), 227 | nn.Linear(widths['block3'], 10, bias=False), 228 | Mul(hyp['net']['scaling_factor']), 229 | ) 230 | net[0].weight.requires_grad = False 231 | net = net.half().cuda() 232 | net = net.to(memory_format=torch.channels_last) 233 | for mod in net.modules(): 234 | if isinstance(mod, BatchNorm): 235 | mod.float() 236 | return net 237 | 238 | ############################################# 239 | # Whitening Conv Initialization # 240 | ############################################# 241 | 242 | def get_patches(x, patch_shape): 243 | c, (h, w) = x.shape[1], patch_shape 244 | return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float() 245 | 246 | def get_whitening_parameters(patches): 247 | n,c,h,w = patches.shape 248 | patches_flat = patches.view(n, -1) 249 | est_patch_covariance = (patches_flat.T @ patches_flat) / n 250 | eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U') 251 | return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w,c,h,w).flip(0) 252 | 253 | def init_whitening_conv(layer, train_set, eps=5e-4): 254 | patches = get_patches(train_set, patch_shape=layer.weight.data.shape[2:]) 255 | eigenvalues, eigenvectors = get_whitening_parameters(patches) 256 | eigenvectors_scaled = eigenvectors / torch.sqrt(eigenvalues + eps) 257 | layer.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled)) 258 | 259 | ############################################ 260 | # Lookahead # 261 | ############################################ 262 | 263 | class LookaheadState: 264 | def __init__(self, net): 265 | self.net_ema = {k: v.clone() for k, v in net.state_dict().items()} 266 | 267 | def update(self, net, decay): 268 | for ema_param, net_param in zip(self.net_ema.values(), net.state_dict().values()): 269 | if net_param.dtype in (torch.half, torch.float): 270 | ema_param.lerp_(net_param, 1-decay) 271 | net_param.copy_(ema_param) 272 | 273 | ############################################ 274 | # Logging # 275 | ############################################ 276 | 277 | def print_columns(columns_list, is_head=False, is_final_entry=False): 278 | print_string = '' 279 | for col in columns_list: 280 | print_string += '| %s ' % col 281 | print_string += '|' 282 | if is_head: 283 | print('-'*len(print_string)) 284 | print(print_string) 285 | if is_head or is_final_entry: 286 | print('-'*len(print_string)) 287 | 288 | logging_columns_list = ['run ', 'epoch', 'train_loss', 'train_acc', 'val_acc', 'tta_val_acc', 'total_time_seconds'] 289 | def print_training_details(variables, is_final_entry): 290 | formatted = [] 291 | for col in logging_columns_list: 292 | var = variables.get(col.strip(), None) 293 | if type(var) in (int, str): 294 | res = str(var) 295 | elif type(var) is float: 296 | res = '{:0.4f}'.format(var) 297 | else: 298 | assert var is None 299 | res = '' 300 | formatted.append(res.rjust(len(col))) 301 | print_columns(formatted, is_final_entry=is_final_entry) 302 | 303 | ############################################ 304 | # Evaluation # 305 | ############################################ 306 | 307 | def 
infer(model, loader, tta_level=0): 308 | 309 | # Test-time augmentation strategy (for tta_level=2): 310 | # 1. Flip/mirror the image left-to-right (50% of the time). 311 | # 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time, 312 | # i.e. both happen 25% of the time). 313 | # 314 | # This creates 6 views per image (left/right times the two translations and no-translation), 315 | # which we evaluate and then weight according to the given probabilities. 316 | 317 | def infer_basic(inputs, net): 318 | return net(inputs).clone() 319 | 320 | def infer_mirror(inputs, net): 321 | return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1)) 322 | 323 | def infer_mirror_translate(inputs, net): 324 | logits = infer_mirror(inputs, net) 325 | pad = 1 326 | padded_inputs = F.pad(inputs, (pad,)*4, 'reflect') 327 | inputs_translate_list = [ 328 | padded_inputs[:, :, 0:32, 0:32], 329 | padded_inputs[:, :, 2:34, 2:34], 330 | ] 331 | logits_translate_list = [infer_mirror(inputs_translate, net) 332 | for inputs_translate in inputs_translate_list] 333 | logits_translate = torch.stack(logits_translate_list).mean(0) 334 | return 0.5 * logits + 0.5 * logits_translate 335 | 336 | model.eval() 337 | test_images = loader.normalize(loader.images) 338 | infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level] 339 | with torch.no_grad(): 340 | return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)]) 341 | 342 | def evaluate(model, loader, tta_level=0): 343 | logits = infer(model, loader, tta_level) 344 | return (logits.argmax(1) == loader.labels).float().mean().item() 345 | 346 | ############################################ 347 | # Training # 348 | ############################################ 349 | 350 | def main(run): 351 | 352 | batch_size = hyp['opt']['batch_size'] 353 | epochs = hyp['opt']['train_epochs'] 354 | momentum = hyp['opt']['momentum'] 355 | # Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much 356 | # larger the default steps will be than the underlying per-example gradients. We divide the 357 | # learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless 358 | # of the choice of momentum. 
359 | kilostep_scale = 1024 * (1 + 1 / (1 - momentum)) 360 | lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD 361 | wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale 362 | lr_biases = lr * hyp['opt']['bias_scaler'] 363 | 364 | loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none') 365 | test_loader = CifarLoader('cifar10', train=False, batch_size=2000) 366 | train_loader = CifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug']) 367 | if run == 'warmup': 368 | # The only purpose of the first run is to warmup, so we can use dummy data 369 | train_loader.labels = torch.randint(0, 10, size=(len(train_loader.labels),), device=train_loader.labels.device) 370 | total_train_steps = ceil(len(train_loader) * epochs) 371 | 372 | model = make_net() 373 | current_steps = 0 374 | 375 | norm_biases = [p for k, p in model.named_parameters() if 'norm' in k and p.requires_grad] 376 | other_params = [p for k, p in model.named_parameters() if 'norm' not in k and p.requires_grad] 377 | param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases), 378 | dict(params=other_params, lr=lr, weight_decay=wd/lr)] 379 | optimizer = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True) 380 | 381 | def get_lr(step): 382 | warmup_steps = int(total_train_steps * 0.23) 383 | warmdown_steps = total_train_steps - warmup_steps 384 | if step < warmup_steps: 385 | frac = step / warmup_steps 386 | return 0.2 * (1 - frac) + 1.0 * frac 387 | else: 388 | frac = (step - warmup_steps) / warmdown_steps 389 | return 1.0 * (1 - frac) + 0.07 * frac 390 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr) 391 | 392 | alpha_schedule = 0.95**5 * (torch.arange(total_train_steps+1) / total_train_steps)**3 393 | lookahead_state = LookaheadState(model) 394 | 395 | # For accurately timing GPU code 396 | starter = torch.cuda.Event(enable_timing=True) 397 | ender = torch.cuda.Event(enable_timing=True) 398 | total_time_seconds = 0.0 399 | 400 | # Initialize the whitening layer using training images 401 | starter.record() 402 | train_images = train_loader.normalize(train_loader.images[:5000]) 403 | init_whitening_conv(model[0], train_images) 404 | ender.record() 405 | torch.cuda.synchronize() 406 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 407 | 408 | for epoch in range(ceil(epochs)): 409 | 410 | model[0].bias.requires_grad = (epoch < hyp['opt']['whiten_bias_epochs']) 411 | 412 | #################### 413 | # Training # 414 | #################### 415 | 416 | starter.record() 417 | 418 | model.train() 419 | for inputs, labels in train_loader: 420 | 421 | outputs = model(inputs) 422 | loss = loss_fn(outputs, labels).sum() 423 | optimizer.zero_grad(set_to_none=True) 424 | loss.backward() 425 | optimizer.step() 426 | scheduler.step() 427 | 428 | current_steps += 1 429 | 430 | if current_steps % 5 == 0: 431 | lookahead_state.update(model, decay=alpha_schedule[current_steps].item()) 432 | 433 | if current_steps >= total_train_steps: 434 | if lookahead_state is not None: 435 | lookahead_state.update(model, decay=1.0) 436 | break 437 | 438 | ender.record() 439 | torch.cuda.synchronize() 440 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 441 | 442 | #################### 443 | # Evaluation # 444 | #################### 445 | 446 | # Save the accuracy and loss from the last training batch of the epoch 447 | train_acc = (outputs.detach().argmax(1) == labels).float().mean().item() 448 | 
train_loss = loss.item() / batch_size
449 |         val_acc = evaluate(model, test_loader, tta_level=0)
450 |         print_training_details(locals(), is_final_entry=False)
451 |         run = None # Only print the run number once
452 | 
453 |     ####################
454 |     #  TTA Evaluation  #
455 |     ####################
456 | 
457 |     starter.record()
458 |     tta_val_acc = evaluate(model, test_loader, tta_level=hyp['net']['tta_level'])
459 |     ender.record()
460 |     torch.cuda.synchronize()
461 |     total_time_seconds += 1e-3 * starter.elapsed_time(ender)
462 | 
463 |     epoch = 'eval'
464 |     print_training_details(locals(), is_final_entry=True)
465 | 
466 |     return tta_val_acc
467 | 
468 | if __name__ == "__main__":
469 |     with open(sys.argv[0]) as f:
470 |         code = f.read()
471 | 
472 |     print_columns(logging_columns_list, is_head=True)
473 |     #main('warmup')
474 |     accs = torch.tensor([main(run) for run in range(25)])
475 |     print('Mean: %.4f Std: %.4f' % (accs.mean(), accs.std()))
476 | 
477 |     log = {'code': code, 'accs': accs}
478 |     log_dir = os.path.join('logs', str(uuid.uuid4()))
479 |     os.makedirs(log_dir, exist_ok=True)
480 |     log_path = os.path.join(log_dir, 'log.pt')
481 |     print(os.path.abspath(log_path))
482 |     torch.save(log, log_path)
483 | 
484 | 
--------------------------------------------------------------------------------
/legacy/airbench94_compiled.py:
--------------------------------------------------------------------------------
1 | # airbench94_compiled.py
2 | #
3 | # This script is designed to reach 94% accuracy on the CIFAR-10 test-set in the shortest possible time
4 | # after first seeing the training set. It runs in 3.09 seconds on a single NVIDIA A100.
5 | #
6 | # It contains the following methods:
7 | #
8 | # 1. The network architecture is an 8-layer convnet with whitening and identity initialization.
9 | #    * Following Page (2018), the first convolution is initialized as a frozen patch-whitening layer
10 | #      using statistics from the training images. Additionally, the logit output is downscaled and
11 | #      BatchNorm affine weights are disabled.
12 | #    * Following hlb-CIFAR10, the whitening layer has patch size 2, precedes an activation, and is
13 | #      concatenated with its negation to ensure completeness. The six remaining convolutional layers
14 | #      lack residual connections and are initialized as identity transforms wherever possible. The
15 | #      8-layer architecture is also following hlb-CIFAR10, with reduced width in the final layer.
16 | #    * We add a learnable bias to the whitening layer, which reduces the number of steps to 94% by
17 | #      5-10%. It converges quickly so we save time by freezing it after 3 epochs.
18 | # 2. The training data augmentation is horizontal flipping and random two-pixel translation. The
19 | #    horizontal flipping uses a novel method. At epoch one, images are randomly flipped as usual.
20 | #    At epoch two we flip exactly those images which weren't flipped in the first epoch. Epoch three
21 | #    flips the same images as epoch one, four the same as two, and so on. This decreases the number
22 | #    of steps to 94% accuracy by around 10% compared to standard random flipping.
23 | # 3. Test images are augmented with horizontal flipping, and one-pixel translation to the upper-
24 | #    left and lower-right, for a total of six forward passes per test image.
25 | # 4. Following Page (2018) we use Nesterov SGD with a triangular learning rate schedule and increased
26 | #    learning rate for BatchNorm biases. And following hlb-CIFAR10 we use a lookahead-like scheme with
And following hlb-CIFAR10 we use a lookahead-like scheme with 27 | # slow decay rate at the end of training. 28 | # 5. We use GPU-accelerated dataloading and augmentation. 29 | # 6. We use torch.compile with mode='max-autotune'. 30 | # 31 | # To confirm that the mean accuracy is above 94%, we ran a test of n=1000 runs, which yielded an 32 | # average accuracy of 94.01% (p<0.0001 for the true mean being below 94%, via t-test). 33 | # 34 | # The runtime of 3.09 seconds was recorded on an NVIDIA A100-SXM4-40GB with the following nvidia-smi: 35 | # NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 36 | # torch.__version__ == '2.4.0+cu121' 37 | # 38 | # Note that the first time this script is run, compilation takes several minutes. See airbench94.py for 39 | # a script with much less warmup time. 40 | # 41 | # This script is descended from hlb-CIFAR10 [1], which is descended from [2]. The latter was the winning 42 | # submission to the Stanford DAWNbench competition for CIFAR-10 in 2018, with a time of 26 seconds to 43 | # 94% accuracy on an NVIDIA V100. 44 | # 45 | # Version 0.7.0 of hlb-CIFAR10 [1] uses 587 TFLOPs and runs in 6.2 seconds. The final training script 46 | # from David Page's series "How to Train Your ResNet" [2] uses 1,148 TFLOPs and runs in 14.9 seconds 47 | # on an A100. And a standard 200-epoch ResNet18 training uses ~30,000 TFLOPs and runs in minutes. 48 | # 49 | # This script trains an 8-layer convnet with 2M parameters and 0.28 GFLOPs per forward pass. The entire 50 | # training run uses 358 TFLOPs, which could theoretically take 1.15 A100-seconds at perfect GPU utilization. 51 | # 52 | # 1. tysam-code. "CIFAR-10 hyperlightspeedbench." https://github.com/tysam-code/hlb-CIFAR10. Jan 01 (2024). 53 | # 2. Page, David. "How to train your resnet." Myrtle, https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/. Sept 24 (2018). 54 | 55 | ############################################# 56 | # Setup/Hyperparameters # 57 | ############################################# 58 | 59 | import os 60 | import sys 61 | import uuid 62 | from math import ceil 63 | 64 | import torch 65 | from torch import nn 66 | import torch.nn.functional as F 67 | import torchvision 68 | import torchvision.transforms as T 69 | 70 | torch.backends.cudnn.benchmark = True 71 | 72 | # We express the main training hyperparameters (batch size, learning rate, momentum, and weight decay) 73 | # in decoupled form, so that each one can be tuned independently. This accomplishes the following: 74 | # * Assuming time-constant gradients, the average step size is decoupled from everything but the lr. 75 | # * The size of the weight decay update is decoupled from everything but the wd. 76 | # In contrast, normally when we increase the (Nesterov) momentum, this also scales up the step size 77 | # proportionally to 1 + 1 / (1 - momentum), meaning we cannot change momentum without having to re-tune 78 | # the learning rate. Similarly, normally when we increase the learning rate this also increases the size 79 | # of the weight decay, requiring a proportional decrease in the wd to maintain the same decay strength. 80 | # 81 | # The practical impact is that hyperparameter tuning is faster, since this parametrization allows each 82 | # one to be tuned independently. See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/.
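To make the decoupling concrete, here is a small standalone sketch (an editorial illustration using this script's lr of 11.5, not part of the original file): it verifies that the effective step size is unchanged as momentum varies.
```
# Decoupling sanity check: the effective step size per 1024 examples stays at 11.5
# for any momentum, because kilostep_scale absorbs the 1 + 1/(1 - momentum)
# step-size inflation of Nesterov SGD described above.
for momentum in (0.85, 0.90, 0.95):
    kilostep_scale = 1024 * (1 + 1 / (1 - momentum))
    lr = 11.5 / kilostep_scale  # raw learning rate handed to PyTorch SGD
    print(momentum, lr * (1 + 1 / (1 - momentum)) * 1024)  # -> 11.5 every time
```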
83 | 84 | hyp = { 85 | 'opt': { 86 | 'train_epochs': 9.9, 87 | 'batch_size': 1024, 88 | 'lr': 11.5, # learning rate per 1024 examples 89 | 'momentum': 0.85, 90 | 'weight_decay': 0.0153, # weight decay per 1024 examples (decoupled from learning rate) 91 | 'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases 92 | 'label_smoothing': 0.2, 93 | 'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing 94 | }, 95 | 'aug': { 96 | 'flip': True, 97 | 'translate': 2, 98 | }, 99 | 'net': { 100 | 'widths': { 101 | 'block1': 64, 102 | 'block2': 256, 103 | 'block3': 256, 104 | }, 105 | 'batchnorm_momentum': 0.6, 106 | 'scaling_factor': 1/9, 107 | 'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate 108 | }, 109 | } 110 | 111 | ############################################# 112 | # DataLoader # 113 | ############################################# 114 | 115 | CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465)) 116 | CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616)) 117 | 118 | def batch_flip_lr(inputs): 119 | flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1) 120 | return torch.where(flip_mask, inputs.flip(-1), inputs) 121 | 122 | def batch_crop(images, crop_size): 123 | r = (images.size(-1) - crop_size)//2 124 | shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device) 125 | images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype) 126 | # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2. 127 | if r <= 2: 128 | for sy in range(-r, r+1): 129 | for sx in range(-r, r+1): 130 | mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx) 131 | images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size] 132 | else: 133 | images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype) 134 | for s in range(-r, r+1): 135 | mask = (shifts[:, 0] == s) 136 | images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :] 137 | for s in range(-r, r+1): 138 | mask = (shifts[:, 1] == s) 139 | images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size] 140 | return images_out 141 | 142 | class CifarLoader: 143 | 144 | def __init__(self, path, train=True, batch_size=500, aug=None, drop_last=None, shuffle=None, gpu=0): 145 | data_path = os.path.join(path, 'train.pt' if train else 'test.pt') 146 | if not os.path.exists(data_path): 147 | dset = torchvision.datasets.CIFAR10(path, download=True, train=train) 148 | images = torch.tensor(dset.data) 149 | labels = torch.tensor(dset.targets) 150 | torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path) 151 | 152 | data = torch.load(data_path, map_location=torch.device(gpu)) 153 | self.images, self.labels, self.classes = data['images'], data['labels'], data['classes'] 154 | # It's faster to load+process uint8 data than to load preprocessed fp16 data 155 | self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last) 156 | 157 | self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD) 158 | self.proc_images = {} # Saved results of image processing to be done on the first epoch 159 | self.epoch = 0 160 | 161 | self.aug = aug or {} 162 | for k in self.aug.keys(): 163 | assert k in ['flip', 'translate'], 'Unrecognized key: %s' % k 164 | 165 | self.batch_size = batch_size 166 | self.drop_last = train if drop_last is 
None else drop_last 167 | self.shuffle = train if shuffle is None else shuffle 168 | 169 | def __len__(self): 170 | return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size) 171 | 172 | def __iter__(self): 173 | 174 | if self.epoch == 0: 175 | images = self.proc_images['norm'] = self.normalize(self.images) 176 | # Pre-flip images in order to do every-other epoch flipping scheme 177 | if self.aug.get('flip', False): 178 | images = self.proc_images['flip'] = batch_flip_lr(images) 179 | # Pre-pad images to save time when doing random translation 180 | pad = self.aug.get('translate', 0) 181 | if pad > 0: 182 | self.proc_images['pad'] = F.pad(images, (pad,)*4, 'reflect') 183 | 184 | if self.aug.get('translate', 0) > 0: 185 | images = batch_crop(self.proc_images['pad'], self.images.shape[-2]) 186 | elif self.aug.get('flip', False): 187 | images = self.proc_images['flip'] 188 | else: 189 | images = self.proc_images['norm'] 190 | # Flip all images together every other epoch. This increases diversity relative to random flipping 191 | if self.aug.get('flip', False): 192 | if self.epoch % 2 == 1: 193 | images = images.flip(-1) 194 | 195 | self.epoch += 1 196 | 197 | indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device) 198 | for i in range(len(self)): 199 | idxs = indices[i*self.batch_size:(i+1)*self.batch_size] 200 | yield (images[idxs], self.labels[idxs]) 201 | 202 | ############################################# 203 | # Network Components # 204 | ############################################# 205 | 206 | class Flatten(nn.Module): 207 | def forward(self, x): 208 | return x.view(x.size(0), -1) 209 | 210 | class Mul(nn.Module): 211 | def __init__(self, scale): 212 | super().__init__() 213 | self.scale = scale 214 | def forward(self, x): 215 | return x * self.scale 216 | 217 | class BatchNorm(nn.BatchNorm2d): 218 | def __init__(self, num_features, momentum, eps=1e-12, 219 | weight=False, bias=True): 220 | super().__init__(num_features, eps=eps, momentum=1-momentum) 221 | self.weight.requires_grad = weight 222 | self.bias.requires_grad = bias 223 | # Note that PyTorch already initializes the weights to one and bias to zero 224 | 225 | class Conv(nn.Conv2d): 226 | def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False): 227 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 228 | 229 | def reset_parameters(self): 230 | super().reset_parameters() 231 | if self.bias is not None: 232 | self.bias.data.zero_() 233 | w = self.weight.data 234 | torch.nn.init.dirac_(w[:w.size(1)]) 235 | 236 | class ConvGroup(nn.Module): 237 | def __init__(self, channels_in, channels_out, batchnorm_momentum): 238 | super().__init__() 239 | self.conv1 = Conv(channels_in, channels_out) 240 | self.pool = nn.MaxPool2d(2) 241 | self.norm1 = BatchNorm(channels_out, batchnorm_momentum) 242 | self.conv2 = Conv(channels_out, channels_out) 243 | self.norm2 = BatchNorm(channels_out, batchnorm_momentum) 244 | self.activ = nn.GELU() 245 | 246 | def forward(self, x): 247 | x = self.conv1(x) 248 | x = self.pool(x) 249 | x = self.norm1(x) 250 | x = self.activ(x) 251 | x = self.conv2(x) 252 | x = self.norm2(x) 253 | x = self.activ(x) 254 | return x 255 | 256 | ############################################# 257 | # Network Definition # 258 | ############################################# 259 | 260 | def make_net(): 261 | widths = hyp['net']['widths'] 262 | batchnorm_momentum = 
hyp['net']['batchnorm_momentum'] 263 | whiten_kernel_size = 2 264 | whiten_width = 2 * 3 * whiten_kernel_size**2 265 | net = nn.Sequential( 266 | Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True), 267 | nn.GELU(), 268 | ConvGroup(whiten_width, widths['block1'], batchnorm_momentum), 269 | ConvGroup(widths['block1'], widths['block2'], batchnorm_momentum), 270 | ConvGroup(widths['block2'], widths['block3'], batchnorm_momentum), 271 | nn.MaxPool2d(3), 272 | Flatten(), 273 | nn.Linear(widths['block3'], 10, bias=False), 274 | Mul(hyp['net']['scaling_factor']), 275 | ) 276 | net[0].weight.requires_grad = False 277 | net = net.half().cuda() 278 | net = net.to(memory_format=torch.channels_last) 279 | for mod in net.modules(): 280 | if isinstance(mod, BatchNorm): 281 | mod.float() 282 | return net 283 | 284 | def reinit_net(model): 285 | for m in model.modules(): 286 | if type(m) in (Conv, BatchNorm, nn.Linear): 287 | m.reset_parameters() 288 | 289 | ############################################# 290 | # Whitening Conv Initialization # 291 | ############################################# 292 | 293 | def get_patches(x, patch_shape): 294 | c, (h, w) = x.shape[1], patch_shape 295 | return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float() 296 | 297 | def get_whitening_parameters(patches): 298 | n,c,h,w = patches.shape 299 | patches_flat = patches.view(n, -1) 300 | est_patch_covariance = (patches_flat.T @ patches_flat) / n 301 | eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U') 302 | return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w,c,h,w).flip(0) 303 | 304 | def init_whitening_conv(layer, train_set, eps=5e-4): 305 | patches = get_patches(train_set, patch_shape=layer.weight.data.shape[2:]) 306 | eigenvalues, eigenvectors = get_whitening_parameters(patches) 307 | eigenvectors_scaled = eigenvectors / torch.sqrt(eigenvalues + eps) 308 | layer.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled)) 309 | 310 | ############################################ 311 | # Lookahead # 312 | ############################################ 313 | 314 | class LookaheadState: 315 | 316 | def __init__(self, net): 317 | self.net_ema = {k: v.clone() for k, v in net.state_dict().items()} 318 | 319 | def update(self, net, decay): 320 | for ema_param, net_param in zip(self.net_ema.values(), net.state_dict().values()): 321 | if net_param.dtype in (torch.half, torch.float): 322 | ema_param.lerp_(net_param, 1-decay) 323 | net_param.copy_(ema_param) 324 | 325 | ############################################ 326 | # Logging # 327 | ############################################ 328 | 329 | def print_columns(columns_list, is_head=False, is_final_entry=False): 330 | print_string = '' 331 | for col in columns_list: 332 | print_string += '| %s ' % col 333 | print_string += '|' 334 | if is_head: 335 | print('-'*len(print_string)) 336 | print(print_string) 337 | if is_head or is_final_entry: 338 | print('-'*len(print_string)) 339 | 340 | logging_columns_list = ['run ', 'epoch', 'train_loss', 'train_acc', 'val_acc', 'tta_val_acc', 'total_time_seconds'] 341 | def print_training_details(variables, is_final_entry): 342 | formatted = [] 343 | for col in logging_columns_list: 344 | var = variables.get(col.strip(), None) 345 | if type(var) in (int, str): 346 | res = str(var) 347 | elif type(var) is float: 348 | res = '{:0.4f}'.format(var) 349 | else: 350 | assert var is None 351 | res = '' 352 | formatted.append(res.rjust(len(col))) 353 | 
print_columns(formatted, is_final_entry=is_final_entry) 354 | 355 | ############################################ 356 | # Evaluation # 357 | ############################################ 358 | 359 | def infer(model, loader, tta_level=0): 360 | 361 | # Test-time augmentation strategy (for tta_level=2): 362 | # 1. Flip/mirror the image left-to-right (50% of the time). 363 | # 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time, 364 | # i.e. both happen 25% of the time). 365 | # 366 | # This creates 6 views per image (left/right times the two translations and no-translation), 367 | # which we evaluate and then weight according to the given probabilities. 368 | 369 | def infer_basic(inputs, net): 370 | return net(inputs).clone() 371 | 372 | def infer_mirror(inputs, net): 373 | return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1)) 374 | 375 | def infer_mirror_translate(inputs, net): 376 | logits = infer_mirror(inputs, net) 377 | pad = 1 378 | padded_inputs = F.pad(inputs, (pad,)*4, 'reflect') 379 | inputs_translate_list = [ 380 | padded_inputs[:, :, 0:32, 0:32], 381 | padded_inputs[:, :, 2:34, 2:34], 382 | ] 383 | logits_translate_list = [infer_mirror(inputs_translate, net) 384 | for inputs_translate in inputs_translate_list] 385 | logits_translate = torch.stack(logits_translate_list).mean(0) 386 | return 0.5 * logits + 0.5 * logits_translate 387 | 388 | model.eval() 389 | test_images = loader.normalize(loader.images) 390 | infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level] 391 | with torch.no_grad(): 392 | return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)]) 393 | 394 | def evaluate(model, loader, tta_level=0): 395 | logits = infer(model, loader, tta_level) 396 | return (logits.argmax(1) == loader.labels).float().mean().item() 397 | 398 | ############################################ 399 | # Training # 400 | ############################################ 401 | 402 | def main(run, model_trainbias, model_freezebias): 403 | 404 | batch_size = hyp['opt']['batch_size'] 405 | epochs = hyp['opt']['train_epochs'] 406 | momentum = hyp['opt']['momentum'] 407 | # Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much 408 | # larger the default steps will be than the underlying per-example gradients. We divide the 409 | # learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless 410 | # of the choice of momentum. 
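# (Added illustration, not in the original file: with momentum = 0.85 this ratio is
#  1 + 1/(1 - 0.85) ≈ 7.67, so kilostep_scale = 1024 * 7.67 ≈ 7850.7 and the raw SGD
#  learning rate below becomes 11.5 / 7850.7 ≈ 1.46e-3.)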
411 | kilostep_scale = 1024 * (1 + 1 / (1 - momentum)) 412 | lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD 413 | wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale 414 | lr_biases = lr * hyp['opt']['bias_scaler'] 415 | 416 | loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none') 417 | 418 | test_loader = CifarLoader('cifar10', train=False, batch_size=2000) 419 | train_loader = CifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug']) 420 | if run == 'warmup': 421 | # The only purpose of the first run is to warmup the compiled model, so we can use dummy data 422 | train_loader.labels = torch.randint(0, 10, size=(len(train_loader.labels),), device=train_loader.labels.device) 423 | total_train_steps = ceil(len(train_loader) * epochs) 424 | 425 | # Reinitialize the network from scratch - nothing is reused from previous runs besides the PyTorch compilation 426 | reinit_net(model_trainbias) 427 | current_steps = 0 428 | 429 | norm_biases = [p for k, p in model_trainbias.named_parameters() if 'norm' in k] 430 | other_params = [p for k, p in model_trainbias.named_parameters() if 'norm' not in k] 431 | param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases), 432 | dict(params=other_params, lr=lr, weight_decay=wd/lr)] 433 | optimizer_trainbias = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True) 434 | 435 | norm_biases = [p for k, p in model_freezebias.named_parameters() if 'norm' in k] 436 | other_params = [p for k, p in model_freezebias.named_parameters() if 'norm' not in k] 437 | param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases), 438 | dict(params=other_params, lr=lr, weight_decay=wd/lr)] 439 | optimizer_freezebias = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True) 440 | 441 | def get_lr(step): 442 | warmup_steps = int(total_train_steps * 0.23) 443 | warmdown_steps = total_train_steps - warmup_steps 444 | if step < warmup_steps: 445 | frac = step / warmup_steps 446 | return 0.2 * (1 - frac) + 1.0 * frac 447 | else: 448 | frac = (step - warmup_steps) / warmdown_steps 449 | return 1.0 * (1 - frac) + 0.07 * frac 450 | scheduler_trainbias = torch.optim.lr_scheduler.LambdaLR(optimizer_trainbias, get_lr) 451 | scheduler_freezebias = torch.optim.lr_scheduler.LambdaLR(optimizer_freezebias, get_lr) 452 | 453 | alpha_schedule = 0.95**5 * (torch.arange(total_train_steps+1) / total_train_steps)**3 454 | lookahead_state = LookaheadState(model_trainbias) 455 | 456 | # For accurately timing GPU code 457 | starter = torch.cuda.Event(enable_timing=True) 458 | ender = torch.cuda.Event(enable_timing=True) 459 | total_time_seconds = 0.0 460 | 461 | # Initialize the whitening layer using training images 462 | starter.record() 463 | train_images = train_loader.normalize(train_loader.images[:5000]) 464 | init_whitening_conv(model_trainbias._orig_mod[0], train_images) 465 | ender.record() 466 | torch.cuda.synchronize() 467 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 468 | 469 | for epoch in range(ceil(epochs)): 470 | 471 | # After training the whiten bias for some epochs, swap in the compiled model with frozen bias 472 | if epoch == 0: 473 | model = model_trainbias 474 | optimizer = optimizer_trainbias 475 | scheduler = scheduler_trainbias 476 | elif epoch == hyp['opt']['whiten_bias_epochs']: 477 | model = model_freezebias 478 | optimizer = optimizer_freezebias 479 | scheduler = scheduler_freezebias 480 | 
model.load_state_dict(model_trainbias.state_dict()) 481 | optimizer.load_state_dict(optimizer_trainbias.state_dict()) 482 | scheduler.load_state_dict(scheduler_trainbias.state_dict()) 483 | 484 | #################### 485 | # Training # 486 | #################### 487 | 488 | starter.record() 489 | 490 | model.train() 491 | for inputs, labels in train_loader: 492 | 493 | outputs = model(inputs) 494 | loss = loss_fn(outputs, labels).sum() 495 | optimizer.zero_grad(set_to_none=True) 496 | loss.backward() 497 | optimizer.step() 498 | scheduler.step() 499 | 500 | current_steps += 1 501 | 502 | if current_steps % 5 == 0: 503 | lookahead_state.update(model, decay=alpha_schedule[current_steps].item()) 504 | 505 | if current_steps >= total_train_steps: 506 | if lookahead_state is not None: 507 | lookahead_state.update(model, decay=1.0) 508 | break 509 | 510 | ender.record() 511 | torch.cuda.synchronize() 512 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 513 | 514 | #################### 515 | # Evaluation # 516 | #################### 517 | 518 | # Save the accuracy and loss from the last training batch of the epoch 519 | train_acc = (outputs.detach().argmax(1) == labels).float().mean().item() 520 | train_loss = loss.item() / batch_size 521 | val_acc = evaluate(model, test_loader, tta_level=0) 522 | print_training_details(locals(), is_final_entry=False) 523 | run = None # Only print the run number once 524 | 525 | #################### 526 | # TTA Evaluation # 527 | #################### 528 | 529 | starter.record() 530 | tta_val_acc = evaluate(model, test_loader, tta_level=hyp['net']['tta_level']) 531 | ender.record() 532 | torch.cuda.synchronize() 533 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 534 | 535 | epoch = 'eval' 536 | print_training_details(locals(), is_final_entry=True) 537 | 538 | return tta_val_acc 539 | 540 | if __name__ == "__main__": 541 | with open(sys.argv[0]) as f: 542 | code = f.read() 543 | 544 | # These two compiled models are first warmed up, and then reinitialized every run. No learned 545 | # weights are reused between runs. To implement freezing of the whitening-layer bias parameter 546 | # midway through training, we use two compiled models, one with trainable and the other with 547 | # frozen whitening bias. This is faster than the naive approach of setting requires_grad=False 548 | # on the whitening bias midway through training on a single compiled model. 549 | model_trainbias = make_net() 550 | model_freezebias = make_net() 551 | model_freezebias[0].bias.requires_grad = False 552 | model_trainbias = torch.compile(model_trainbias, mode='max-autotune') 553 | model_freezebias = torch.compile(model_freezebias, mode='max-autotune') 554 | 555 | print_columns(logging_columns_list, is_head=True) 556 | main('warmup', model_trainbias, model_freezebias) 557 | accs = torch.tensor([main(run, model_trainbias, model_freezebias) for run in range(25)]) 558 | print('Mean: %.4f Std: %.4f' % (accs.mean(), accs.std())) 559 | 560 | log = {'code': code, 'accs': accs} 561 | log_dir = os.path.join('logs', str(uuid.uuid4())) 562 | os.makedirs(log_dir, exist_ok=True) 563 | log_path = os.path.join(log_dir, 'log.pt') 564 | print(os.path.abspath(log_path)) 565 | torch.save(log, os.path.join(log_dir, 'log.pt')) 566 | 567 | -------------------------------------------------------------------------------- /legacy/airbench95.py: -------------------------------------------------------------------------------- 1 | # A variant of airbench optimized for time-to-95%. 
2 | # 10.4s runtime on an A100; 1.39 PFLOPs. 3 | # Evidence: 95.01 average accuracy in n=200 runs. 4 | # 5 | # We recorded the runtime of 10.4 seconds on an NVIDIA A100-SXM4-80GB with the following nvidia-smi: 6 | # NVIDIA-SMI 515.105.01 Driver Version: 515.105.01 CUDA Version: 11.7 7 | # torch.__version__ == '2.1.2+cu118' 8 | # 9 | # Changes relative to airbench: 10 | # - Increased width and reduced learning rate. 11 | # - Increased training duration to 15 epochs. 12 | # 13 | # If random flip is used instead of alternating flip, then the mean accuracy decays to 94.95 in n=100 runs. 14 | # With random flip and 16 epochs instead of 15, we get 94.97 in n=100 runs. 15 | # With random flip and 17 epochs, we get 95.01 in n=100 runs. 16 | 17 | ############################################# 18 | # Setup/Hyperparameters # 19 | ############################################# 20 | 21 | import os 22 | import sys 23 | import uuid 24 | from math import ceil 25 | 26 | import torch 27 | from torch import nn 28 | import torch.nn.functional as F 29 | import torchvision 30 | import torchvision.transforms as T 31 | 32 | torch.backends.cudnn.benchmark = True 33 | 34 | # We express the main training hyperparameters (batch size, learning rate, momentum, and weight decay) 35 | # in decoupled form, so that each one can be tuned independently. This accomplishes the following: 36 | # * Assuming time-constant gradients, the average step size is decoupled from everything but the lr. 37 | # * The size of the weight decay update is decoupled from everything but the wd. 38 | # In contrast, normally when we increase the (Nesterov) momentum, this also scales up the step size 39 | # proportionally to 1 + 1 / (1 - momentum), meaning we cannot change momentum without having to re-tune 40 | # the learning rate. Similarly, normally when we increase the learning rate this also increases the size 41 | # of the weight decay, requiring a proportional decrease in the wd to maintain the same decay strength. 42 | # 43 | # The practical impact is that hyperparameter tuning is faster, since this parametrization allows each 44 | # one to be tuned independently. See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/.
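A related detail these scripts rely on: PyTorch SGD couples weight decay to the learning rate, since it implements decay by adding `weight_decay * param` to the gradient before the `lr` multiply. The param groups later in the script therefore pass `weight_decay=wd/lr`, which makes the actual shrinkage per step equal `wd` regardless of `lr`. A minimal editorial sketch of this, with illustrative values:
```
import torch

# One zero-gradient SGD step shrinks the weight by lr * weight_decay * w.
# Passing weight_decay = wd/lr makes that shrinkage exactly wd * w for any lr.
wd = 0.001
for lr in (0.01, 0.1):
    w = torch.nn.Parameter(torch.ones(1))
    opt = torch.optim.SGD([dict(params=[w], lr=lr, weight_decay=wd/lr)])
    w.grad = torch.zeros_like(w)
    opt.step()
    print(lr, 1.0 - w.item())  # -> ~0.001 for both learning rates
```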
45 | 46 | hyp = { 47 | 'opt': { 48 | 'train_epochs': 15, 49 | 'batch_size': 1024, 50 | 'lr': 10.0, # learning rate per 1024 examples 51 | 'momentum': 0.85, 52 | 'weight_decay': 0.0153, # weight decay per 1024 examples (decoupled from learning rate) 53 | 'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases 54 | 'label_smoothing': 0.2, 55 | 'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing 56 | }, 57 | 'aug': { 58 | 'flip': True, 59 | 'translate': 2, 60 | }, 61 | 'net': { 62 | 'widths': { 63 | 'block1': 128, 64 | 'block2': 384, 65 | 'block3': 384, 66 | }, 67 | 'batchnorm_momentum': 0.6, 68 | 'scaling_factor': 1/9, 69 | 'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate 70 | }, 71 | } 72 | 73 | ############################################# 74 | # DataLoader # 75 | ############################################# 76 | 77 | CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465)) 78 | CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616)) 79 | 80 | def batch_flip_lr(inputs): 81 | flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1) 82 | return torch.where(flip_mask, inputs.flip(-1), inputs) 83 | 84 | def batch_crop(images, crop_size): 85 | r = (images.size(-1) - crop_size)//2 86 | shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device) 87 | images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype) 88 | # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2. 89 | if r <= 2: 90 | for sy in range(-r, r+1): 91 | for sx in range(-r, r+1): 92 | mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx) 93 | images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size] 94 | else: 95 | images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype) 96 | for s in range(-r, r+1): 97 | mask = (shifts[:, 0] == s) 98 | images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :] 99 | for s in range(-r, r+1): 100 | mask = (shifts[:, 1] == s) 101 | images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size] 102 | return images_out 103 | 104 | class CifarLoader: 105 | 106 | def __init__(self, path, train=True, batch_size=500, aug=None, drop_last=None, shuffle=None, gpu=0): 107 | data_path = os.path.join(path, 'train.pt' if train else 'test.pt') 108 | if not os.path.exists(data_path): 109 | dset = torchvision.datasets.CIFAR10(path, download=True, train=train) 110 | images = torch.tensor(dset.data) 111 | labels = torch.tensor(dset.targets) 112 | torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path) 113 | 114 | data = torch.load(data_path, map_location=torch.device(gpu)) 115 | self.images, self.labels, self.classes = data['images'], data['labels'], data['classes'] 116 | # It's faster to load+process uint8 data than to load preprocessed fp16 data 117 | self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last) 118 | 119 | self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD) 120 | self.proc_images = {} # Saved results of image processing to be done on the first epoch 121 | self.epoch = 0 122 | 123 | self.aug = aug or {} 124 | for k in self.aug.keys(): 125 | assert k in ['flip', 'translate'], 'Unrecognized key: %s' % k 126 | 127 | self.batch_size = batch_size 128 | self.drop_last = train if drop_last is None else drop_last 129 | 
self.shuffle = train if shuffle is None else shuffle 130 | 131 | def __len__(self): 132 | return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size) 133 | 134 | def __iter__(self): 135 | 136 | if self.epoch == 0: 137 | images = self.proc_images['norm'] = self.normalize(self.images) 138 | # Pre-flip images in order to do every-other epoch flipping scheme 139 | if self.aug.get('flip', False): 140 | images = self.proc_images['flip'] = batch_flip_lr(images) 141 | # Pre-pad images to save time when doing random translation 142 | pad = self.aug.get('translate', 0) 143 | if pad > 0: 144 | self.proc_images['pad'] = F.pad(images, (pad,)*4, 'reflect') 145 | 146 | if self.aug.get('translate', 0) > 0: 147 | images = batch_crop(self.proc_images['pad'], self.images.shape[-2]) 148 | elif self.aug.get('flip', False): 149 | images = self.proc_images['flip'] 150 | else: 151 | images = self.proc_images['norm'] 152 | # Flip all images together every other epoch. This increases diversity relative to random flipping 153 | if self.aug.get('flip', False): 154 | if self.epoch % 2 == 1: 155 | images = images.flip(-1) 156 | 157 | self.epoch += 1 158 | 159 | indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device) 160 | for i in range(len(self)): 161 | idxs = indices[i*self.batch_size:(i+1)*self.batch_size] 162 | yield (images[idxs], self.labels[idxs]) 163 | 164 | ############################################# 165 | # Network Components # 166 | ############################################# 167 | 168 | class Flatten(nn.Module): 169 | def forward(self, x): 170 | return x.view(x.size(0), -1) 171 | 172 | class Mul(nn.Module): 173 | def __init__(self, scale): 174 | super().__init__() 175 | self.scale = scale 176 | def forward(self, x): 177 | return x * self.scale 178 | 179 | class BatchNorm(nn.BatchNorm2d): 180 | def __init__(self, num_features, momentum, eps=1e-12, 181 | weight=False, bias=True): 182 | super().__init__(num_features, eps=eps, momentum=1-momentum) 183 | self.weight.requires_grad = weight 184 | self.bias.requires_grad = bias 185 | # Note that PyTorch already initializes the weights to one and bias to zero 186 | 187 | class Conv(nn.Conv2d): 188 | def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False): 189 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 190 | 191 | def reset_parameters(self): 192 | super().reset_parameters() 193 | if self.bias is not None: 194 | self.bias.data.zero_() 195 | w = self.weight.data 196 | torch.nn.init.dirac_(w[:w.size(1)]) 197 | 198 | class ConvGroup(nn.Module): 199 | def __init__(self, channels_in, channels_out, batchnorm_momentum): 200 | super().__init__() 201 | self.conv1 = Conv(channels_in, channels_out) 202 | self.pool = nn.MaxPool2d(2) 203 | self.norm1 = BatchNorm(channels_out, batchnorm_momentum) 204 | self.conv2 = Conv(channels_out, channels_out) 205 | self.norm2 = BatchNorm(channels_out, batchnorm_momentum) 206 | self.activ = nn.GELU() 207 | 208 | def forward(self, x): 209 | x = self.conv1(x) 210 | x = self.pool(x) 211 | x = self.norm1(x) 212 | x = self.activ(x) 213 | x = self.conv2(x) 214 | x = self.norm2(x) 215 | x = self.activ(x) 216 | return x 217 | 218 | ############################################# 219 | # Network Definition # 220 | ############################################# 221 | 222 | def make_net(): 223 | widths = hyp['net']['widths'] 224 | batchnorm_momentum = hyp['net']['batchnorm_momentum'] 
225 | whiten_kernel_size = 2 226 | whiten_width = 2 * 3 * whiten_kernel_size**2 227 | net = nn.Sequential( 228 | Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True), 229 | nn.GELU(), 230 | ConvGroup(whiten_width, widths['block1'], batchnorm_momentum), 231 | ConvGroup(widths['block1'], widths['block2'], batchnorm_momentum), 232 | ConvGroup(widths['block2'], widths['block3'], batchnorm_momentum), 233 | nn.MaxPool2d(3), 234 | Flatten(), 235 | nn.Linear(widths['block3'], 10, bias=False), 236 | Mul(hyp['net']['scaling_factor']), 237 | ) 238 | net[0].weight.requires_grad = False 239 | net = net.half().cuda() 240 | net = net.to(memory_format=torch.channels_last) 241 | for mod in net.modules(): 242 | if isinstance(mod, BatchNorm): 243 | mod.float() 244 | return net 245 | 246 | ############################################# 247 | # Whitening Conv Initialization # 248 | ############################################# 249 | 250 | def get_patches(x, patch_shape): 251 | c, (h, w) = x.shape[1], patch_shape 252 | return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float() 253 | 254 | def get_whitening_parameters(patches): 255 | n,c,h,w = patches.shape 256 | patches_flat = patches.view(n, -1) 257 | est_patch_covariance = (patches_flat.T @ patches_flat) / n 258 | eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U') 259 | return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w,c,h,w).flip(0) 260 | 261 | def init_whitening_conv(layer, train_set, eps=5e-4): 262 | patches = get_patches(train_set, patch_shape=layer.weight.data.shape[2:]) 263 | eigenvalues, eigenvectors = get_whitening_parameters(patches) 264 | eigenvectors_scaled = eigenvectors / torch.sqrt(eigenvalues + eps) 265 | layer.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled)) 266 | 267 | ############################################ 268 | # Lookahead # 269 | ############################################ 270 | 271 | class LookaheadState: 272 | def __init__(self, net): 273 | self.net_ema = {k: v.clone() for k, v in net.state_dict().items()} 274 | 275 | def update(self, net, decay): 276 | for ema_param, net_param in zip(self.net_ema.values(), net.state_dict().values()): 277 | if net_param.dtype in (torch.half, torch.float): 278 | ema_param.lerp_(net_param, 1-decay) 279 | net_param.copy_(ema_param) 280 | 281 | ############################################ 282 | # Logging # 283 | ############################################ 284 | 285 | def print_columns(columns_list, is_head=False, is_final_entry=False): 286 | print_string = '' 287 | for col in columns_list: 288 | print_string += '| %s ' % col 289 | print_string += '|' 290 | if is_head: 291 | print('-'*len(print_string)) 292 | print(print_string) 293 | if is_head or is_final_entry: 294 | print('-'*len(print_string)) 295 | 296 | logging_columns_list = ['run ', 'epoch', 'train_loss', 'train_acc', 'val_acc', 'tta_val_acc', 'total_time_seconds'] 297 | def print_training_details(variables, is_final_entry): 298 | formatted = [] 299 | for col in logging_columns_list: 300 | var = variables.get(col.strip(), None) 301 | if type(var) in (int, str): 302 | res = str(var) 303 | elif type(var) is float: 304 | res = '{:0.4f}'.format(var) 305 | else: 306 | assert var is None 307 | res = '' 308 | formatted.append(res.rjust(len(col))) 309 | print_columns(formatted, is_final_entry=is_final_entry) 310 | 311 | ############################################ 312 | # Evaluation # 313 | ############################################ 314 | 315 | 
def infer(model, loader, tta_level=0): 316 | 317 | # Test-time augmentation strategy (for tta_level=2): 318 | # 1. Flip/mirror the image left-to-right (50% of the time). 319 | # 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time, 320 | # i.e. both happen 25% of the time). 321 | # 322 | # This creates 6 views per image (left/right times the two translations and no-translation), 323 | # which we evaluate and then weight according to the given probabilities. 324 | 325 | def infer_basic(inputs, net): 326 | return net(inputs).clone() 327 | 328 | def infer_mirror(inputs, net): 329 | return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1)) 330 | 331 | def infer_mirror_translate(inputs, net): 332 | logits = infer_mirror(inputs, net) 333 | pad = 1 334 | padded_inputs = F.pad(inputs, (pad,)*4, 'reflect') 335 | inputs_translate_list = [ 336 | padded_inputs[:, :, 0:32, 0:32], 337 | padded_inputs[:, :, 2:34, 2:34], 338 | ] 339 | logits_translate_list = [infer_mirror(inputs_translate, net) 340 | for inputs_translate in inputs_translate_list] 341 | logits_translate = torch.stack(logits_translate_list).mean(0) 342 | return 0.5 * logits + 0.5 * logits_translate 343 | 344 | model.eval() 345 | test_images = loader.normalize(loader.images) 346 | infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level] 347 | with torch.no_grad(): 348 | return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)]) 349 | 350 | def evaluate(model, loader, tta_level=0): 351 | logits = infer(model, loader, tta_level) 352 | return (logits.argmax(1) == loader.labels).float().mean().item() 353 | 354 | ############################################ 355 | # Training # 356 | ############################################ 357 | 358 | def main(run): 359 | 360 | batch_size = hyp['opt']['batch_size'] 361 | epochs = hyp['opt']['train_epochs'] 362 | momentum = hyp['opt']['momentum'] 363 | # Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much 364 | # larger the default steps will be than the underlying per-example gradients. We divide the 365 | # learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless 366 | # of the choice of momentum. 
367 | kilostep_scale = 1024 * (1 + 1 / (1 - momentum)) 368 | lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD 369 | wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale 370 | lr_biases = lr * hyp['opt']['bias_scaler'] 371 | 372 | loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none') 373 | test_loader = CifarLoader('cifar10', train=False, batch_size=2000) 374 | train_loader = CifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug']) 375 | if run == 'warmup': 376 | # The only purpose of the first run is to warmup, so we can use dummy data 377 | train_loader.labels = torch.randint(0, 10, size=(len(train_loader.labels),), device=train_loader.labels.device) 378 | total_train_steps = ceil(len(train_loader) * epochs) 379 | 380 | model = make_net() 381 | current_steps = 0 382 | 383 | norm_biases = [p for k, p in model.named_parameters() if 'norm' in k and p.requires_grad] 384 | other_params = [p for k, p in model.named_parameters() if 'norm' not in k and p.requires_grad] 385 | param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases), 386 | dict(params=other_params, lr=lr, weight_decay=wd/lr)] 387 | optimizer = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True) 388 | 389 | def get_lr(step): 390 | warmup_steps = int(total_train_steps * 0.23) 391 | warmdown_steps = total_train_steps - warmup_steps 392 | if step < warmup_steps: 393 | frac = step / warmup_steps 394 | return 0.2 * (1 - frac) + 1.0 * frac 395 | else: 396 | frac = (step - warmup_steps) / warmdown_steps 397 | return 1.0 * (1 - frac) + 0.07 * frac 398 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr) 399 | 400 | alpha_schedule = 0.95**5 * (torch.arange(total_train_steps+1) / total_train_steps)**3 401 | lookahead_state = LookaheadState(model) 402 | 403 | # For accurately timing GPU code 404 | starter = torch.cuda.Event(enable_timing=True) 405 | ender = torch.cuda.Event(enable_timing=True) 406 | total_time_seconds = 0.0 407 | 408 | # Initialize the whitening layer using training images 409 | starter.record() 410 | train_images = train_loader.normalize(train_loader.images[:5000]) 411 | init_whitening_conv(model[0], train_images) 412 | ender.record() 413 | torch.cuda.synchronize() 414 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 415 | 416 | for epoch in range(ceil(epochs)): 417 | 418 | model[0].bias.requires_grad = (epoch < hyp['opt']['whiten_bias_epochs']) 419 | 420 | #################### 421 | # Training # 422 | #################### 423 | 424 | starter.record() 425 | 426 | model.train() 427 | for inputs, labels in train_loader: 428 | 429 | outputs = model(inputs) 430 | loss = loss_fn(outputs, labels).sum() 431 | optimizer.zero_grad(set_to_none=True) 432 | loss.backward() 433 | optimizer.step() 434 | scheduler.step() 435 | 436 | current_steps += 1 437 | 438 | if current_steps % 5 == 0: 439 | lookahead_state.update(model, decay=alpha_schedule[current_steps].item()) 440 | 441 | if current_steps >= total_train_steps: 442 | if lookahead_state is not None: 443 | lookahead_state.update(model, decay=1.0) 444 | break 445 | 446 | ender.record() 447 | torch.cuda.synchronize() 448 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 449 | 450 | #################### 451 | # Evaluation # 452 | #################### 453 | 454 | # Save the accuracy and loss from the last training batch of the epoch 455 | train_acc = (outputs.detach().argmax(1) == labels).float().mean().item() 456 | 
train_loss = loss.item() / batch_size 457 | val_acc = evaluate(model, test_loader, tta_level=0) 458 | print_training_details(locals(), is_final_entry=False) 459 | run = None # Only print the run number once 460 | 461 | #################### 462 | # TTA Evaluation # 463 | #################### 464 | 465 | starter.record() 466 | tta_val_acc = evaluate(model, test_loader, tta_level=hyp['net']['tta_level']) 467 | ender.record() 468 | torch.cuda.synchronize() 469 | total_time_seconds += 1e-3 * starter.elapsed_time(ender) 470 | 471 | epoch = 'eval' 472 | print_training_details(locals(), is_final_entry=True) 473 | 474 | return tta_val_acc 475 | 476 | if __name__ == "__main__": 477 | with open(sys.argv[0]) as f: 478 | code = f.read() 479 | 480 | print_columns(logging_columns_list, is_head=True) 481 | #main('warmup') 482 | accs = torch.tensor([main(run) for run in range(25)]) 483 | print('Mean: %.4f Std: %.4f' % (accs.mean(), accs.std())) 484 | 485 | log = {'code': code, 'accs': accs} 486 | log_dir = os.path.join('logs', str(uuid.uuid4())) 487 | os.makedirs(log_dir, exist_ok=True) 488 | log_path = os.path.join(log_dir, 'log.pt') 489 | print(os.path.abspath(log_path)) 490 | torch.save(log, os.path.join(log_dir, 'log.pt')) 491 | 492 | -------------------------------------------------------------------------------- /legacy/airbench96.py: -------------------------------------------------------------------------------- 1 | # A variant of airbench optimized for time-to-96%. 2 | # 34.7s runtime on an A100; 4.91 PFLOPs. 3 | # Evidence: 96.03 average accuracy in n=400 runs. 4 | # 5 | # We recorded the runtime of 34.7 seconds on an NVIDIA A100-SXM4-80GB with the following nvidia-smi: 6 | # NVIDIA-SMI 515.105.01 Driver Version: 515.105.01 CUDA Version: 11.7 7 | # torch.__version__ == '2.1.2+cu118' 8 | # 9 | # Changes relative to airbench: 10 | # - Increased network width and reduced learning rate & weight decay. 11 | # - Reduced the warmup duration and let the learning rate decay go all the way to zero at the end of training. 12 | # - Added an extra layer to each ConvBlock. The network now contains 10 conv layers. 13 | # - Added residual connections over the last two conv layers in each ConvBlock. 14 | # - Added 12-pixel cutout data augmentation and increased random-translation strength from 2 to 4 pixels. 15 | # - Increased training duration to 37 epochs. 16 | 17 | ############################################# 18 | # Setup/Hyperparameters # 19 | ############################################# 20 | 21 | import os 22 | import sys 23 | import uuid 24 | from math import ceil 25 | 26 | import torch 27 | from torch import nn 28 | import torch.nn.functional as F 29 | import torchvision 30 | import torchvision.transforms as T 31 | 32 | torch.backends.cudnn.benchmark = True 33 | 34 | # We express the main training hyperparameters (batch size, learning rate, momentum, and weight decay) 35 | # in decoupled form, so that each one can be tuned independently. This accomplishes the following: 36 | # * Assuming time-constant gradients, the average step size is decoupled from everything but the lr. 37 | # * The size of the weight decay update is decoupled from everything but the wd. 38 | # In contrast, normally when we increase the (Nesterov) momentum, this also scales up the step size 39 | # proportionally to 1 + 1 / (1 - momentum), meaning we cannot change momentum without having to re-tune 40 | # the learning rate.
Similarly, normally when we increase the learning rate this also increases the size 41 | # of the weight decay, requiring a proportional decrease in the wd to maintain the same decay strength. 42 | # 43 | # The practical impact is that hyperparameter tuning is faster, since this parametrization allows each 44 | # one to be tuned independently. See https://myrtle.ai/learn/how-to-train-your-resnet-5-hyperparameters/. 45 | 46 | hyp = { 47 | 'opt': { 48 | 'train_epochs': 37.0, 49 | 'batch_size': 1024, 50 | 'lr': 9.0, # learning rate per 1024 examples 51 | 'momentum': 0.85, 52 | 'weight_decay': 0.012, # weight decay per 1024 examples (decoupled from learning rate) 53 | 'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases 54 | 'label_smoothing': 0.2, 55 | 'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing 56 | }, 57 | 'aug': { 58 | 'flip': True, 59 | 'translate': 4, 60 | 'cutout': 12, 61 | }, 62 | 'net': { 63 | 'widths': { 64 | 'block1': 128, 65 | 'block2': 384, 66 | 'block3': 512, 67 | }, 68 | 'scaling_factor': 1/9, 69 | 'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate 70 | }, 71 | } 72 | 73 | ############################################# 74 | # DataLoader # 75 | ############################################# 76 | 77 | CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465)) 78 | CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616)) 79 | 80 | def batch_flip_lr(inputs): 81 | flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1) 82 | return torch.where(flip_mask, inputs.flip(-1), inputs) 83 | 84 | def batch_crop(images, crop_size): 85 | r = (images.size(-1) - crop_size)//2 86 | shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device) 87 | images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype) 88 | # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2. 
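# (Added cost note, not in the original file: the first branch below performs one
#  masked copy per (sy, sx) pair, i.e. (2r+1)**2 copies, while the second factorizes
#  the 2-D shift into a vertical pass then a horizontal pass, i.e. 2*(2r+1) copies.
#  With translate=4 in this script, r = 4 at call time: 81 versus 18 copies.)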
89 | if r <= 2: 90 | for sy in range(-r, r+1): 91 | for sx in range(-r, r+1): 92 | mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx) 93 | images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size] 94 | else: 95 | images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype) 96 | for s in range(-r, r+1): 97 | mask = (shifts[:, 0] == s) 98 | images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :] 99 | for s in range(-r, r+1): 100 | mask = (shifts[:, 1] == s) 101 | images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size] 102 | return images_out 103 | 104 | def make_random_square_masks(inputs, size): 105 | is_even = int(size % 2 == 0) # (unused) 106 | n,c,h,w = inputs.shape 107 | 108 | # seed the top-left corner of each cutout square, one coordinate per dimension 109 | corner_y = torch.randint(0, h-size+1, size=(n,), device=inputs.device) 110 | corner_x = torch.randint(0, w-size+1, size=(n,), device=inputs.device) 111 | 112 | # measure each pixel's offset from the seeded corner along each axis 113 | corner_y_dists = torch.arange(h, device=inputs.device).view(1, 1, h, 1) - corner_y.view(-1, 1, 1, 1) 114 | corner_x_dists = torch.arange(w, device=inputs.device).view(1, 1, 1, w) - corner_x.view(-1, 1, 1, 1) 115 | 116 | mask_y = (corner_y_dists >= 0) * (corner_y_dists < size) 117 | mask_x = (corner_x_dists >= 0) * (corner_x_dists < size) 118 | 119 | final_mask = mask_y * mask_x 120 | 121 | return final_mask 122 | 123 | def batch_cutout(inputs, size): 124 | cutout_masks = make_random_square_masks(inputs, size) 125 | return inputs.masked_fill(cutout_masks, 0) 126 | 127 | class CifarLoader: 128 | 129 | def __init__(self, path, train=True, batch_size=500, aug=None, drop_last=None, shuffle=None, gpu=0): 130 | data_path = os.path.join(path, 'train.pt' if train else 'test.pt') 131 | if not os.path.exists(data_path): 132 | dset = torchvision.datasets.CIFAR10(path, download=True, train=train) 133 | images = torch.tensor(dset.data) 134 | labels = torch.tensor(dset.targets) 135 | torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path) 136 | 137 | data = torch.load(data_path, map_location=torch.device(gpu)) 138 | self.images, self.labels, self.classes = data['images'], data['labels'], data['classes'] 139 | # It's faster to load+process uint8 data than to load preprocessed fp16 data 140 | self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last) 141 | 142 | self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD) 143 | self.proc_images = {} # Saved results of image processing to be done on the first epoch 144 | self.epoch = 0 145 | 146 | self.aug = aug or {} 147 | for k in self.aug.keys(): 148 | assert k in ['flip', 'translate', 'cutout'], 'Unrecognized key: %s' % k 149 | 150 | self.batch_size = batch_size 151 | self.drop_last = train if drop_last is None else drop_last 152 | self.shuffle = train if shuffle is None else shuffle 153 | 154 | def __len__(self): 155 | return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size) 156 | 157 | def __iter__(self): 158 | 159 | if self.epoch == 0: 160 | images = self.proc_images['norm'] = self.normalize(self.images) 161 | # Pre-flip images in order to do every-other epoch flipping scheme 162 | if self.aug.get('flip', False): 163 | images = self.proc_images['flip'] = batch_flip_lr(images) 164 | # Pre-pad images to save time when doing random translation 165 | pad = self.aug.get('translate', 0) 166 | if pad > 0: 167 | 
self.proc_images['pad'] = F.pad(images, (pad,)*4, 'reflect') 168 | 169 | if self.aug.get('translate', 0) > 0: 170 | images = batch_crop(self.proc_images['pad'], self.images.shape[-2]) 171 | elif self.aug.get('flip', False): 172 | images = self.proc_images['flip'] 173 | else: 174 | images = self.proc_images['norm'] 175 | # Flip all images together every other epoch. This increases diversity relative to random flipping 176 | if self.aug.get('flip', False): 177 | if self.epoch % 2 == 1: 178 | images = images.flip(-1) 179 | if self.aug.get('cutout', 0) > 0: 180 | images = batch_cutout(images, self.aug['cutout']) 181 | 182 | self.epoch += 1 183 | 184 | indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device) 185 | for i in range(len(self)): 186 | idxs = indices[i*self.batch_size:(i+1)*self.batch_size] 187 | yield (images[idxs], self.labels[idxs]) 188 | 189 | ############################################# 190 | # Network Components # 191 | ############################################# 192 | 193 | class Flatten(nn.Module): 194 | def forward(self, x): 195 | return x.view(x.size(0), -1) 196 | 197 | class Mul(nn.Module): 198 | def __init__(self, scale): 199 | super().__init__() 200 | self.scale = scale 201 | def forward(self, x): 202 | return x * self.scale 203 | 204 | class BatchNorm(nn.BatchNorm2d): 205 | def __init__(self, num_features, eps=1e-12, 206 | weight=False, bias=True): 207 | super().__init__(num_features, eps=eps) 208 | self.weight.requires_grad = weight 209 | self.bias.requires_grad = bias 210 | # Note that PyTorch already initializes the weights to one and bias to zero 211 | 212 | class Conv(nn.Conv2d): 213 | def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False): 214 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 215 | 216 | def reset_parameters(self): 217 | super().reset_parameters() 218 | if self.bias is not None: 219 | self.bias.data.zero_() 220 | w = self.weight.data 221 | torch.nn.init.dirac_(w[:w.size(1)]) 222 | 223 | class ConvGroup(nn.Module): 224 | def __init__(self, channels_in, channels_out): 225 | super().__init__() 226 | self.conv1 = Conv(channels_in, channels_out) 227 | self.pool = nn.MaxPool2d(2) 228 | self.norm1 = BatchNorm(channels_out) 229 | self.conv2 = Conv(channels_out, channels_out) 230 | self.norm2 = BatchNorm(channels_out) 231 | self.conv3 = Conv(channels_out, channels_out) 232 | self.norm3 = BatchNorm(channels_out) 233 | self.activ = nn.GELU() 234 | 235 | def forward(self, x): 236 | x = self.conv1(x) 237 | x = self.pool(x) 238 | x = self.norm1(x) 239 | x = self.activ(x) 240 | x0 = x 241 | x = self.conv2(x) 242 | x = self.norm2(x) 243 | x = self.activ(x) 244 | x = self.conv3(x) 245 | x = self.norm3(x) 246 | x = x + x0 247 | x = self.activ(x) 248 | return x 249 | 250 | ############################################# 251 | # Network Definition # 252 | ############################################# 253 | 254 | def make_net(): 255 | widths = hyp['net']['widths'] 256 | whiten_kernel_size = 2 257 | whiten_width = 2 * 3 * whiten_kernel_size**2 258 | net = nn.Sequential( 259 | Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True), 260 | nn.GELU(), 261 | ConvGroup(whiten_width, widths['block1']), 262 | ConvGroup(widths['block1'], widths['block2']), 263 | ConvGroup(widths['block2'], widths['block3']), 264 | nn.MaxPool2d(3), 265 | Flatten(), 266 | nn.Linear(widths['block3'], 10, bias=False), 267 | Mul(hyp['net']['scaling_factor']), 268 | ) 
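# (Added shape note, not in the original file: a 32x32 input becomes 31x31 after the
#  2x2 whitening conv with no padding; the MaxPool2d(2) in each of the three ConvGroups
#  takes it to 15, 7, then 3; and the final MaxPool2d(3) reduces 3x3 to 1x1, so Flatten
#  feeds exactly widths['block3'] features to the linear head.)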
269 | net[0].weight.requires_grad = False 270 | net = net.half().cuda() 271 | net = net.to(memory_format=torch.channels_last) 272 | for mod in net.modules(): 273 | if isinstance(mod, BatchNorm): 274 | mod.float() 275 | return net 276 | 277 | ############################################# 278 | # Whitening Conv Initialization # 279 | ############################################# 280 | 281 | def get_patches(x, patch_shape): 282 | c, (h, w) = x.shape[1], patch_shape 283 | return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float() 284 | 285 | def get_whitening_parameters(patches): 286 | n,c,h,w = patches.shape 287 | patches_flat = patches.view(n, -1) 288 | est_patch_covariance = (patches_flat.T @ patches_flat) / n 289 | eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U') 290 | return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w,c,h,w).flip(0) 291 | 292 | def init_whitening_conv(layer, train_set, eps=5e-4): 293 | patches = get_patches(train_set, patch_shape=layer.weight.data.shape[2:]) 294 | eigenvalues, eigenvectors = get_whitening_parameters(patches) 295 | eigenvectors_scaled = eigenvectors / torch.sqrt(eigenvalues + eps) 296 | layer.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled)) 297 | 298 | ############################################ 299 | # Lookahead # 300 | ############################################ 301 | 302 | class LookaheadState: 303 | def __init__(self, net): 304 | self.net_ema = {k: v.clone() for k, v in net.state_dict().items()} 305 | 306 | def update(self, net, decay): 307 | for ema_param, net_param in zip(self.net_ema.values(), net.state_dict().values()): 308 | if net_param.dtype in (torch.half, torch.float): 309 | ema_param.lerp_(net_param, 1-decay) 310 | net_param.copy_(ema_param) 311 | 312 | ############################################ 313 | # Logging # 314 | ############################################ 315 | 316 | def print_columns(columns_list, is_head=False, is_final_entry=False): 317 | print_string = '' 318 | for col in columns_list: 319 | print_string += '| %s ' % col 320 | print_string += '|' 321 | if is_head: 322 | print('-'*len(print_string)) 323 | print(print_string) 324 | if is_head or is_final_entry: 325 | print('-'*len(print_string)) 326 | 327 | logging_columns_list = ['run ', 'epoch', 'train_loss', 'train_acc', 'val_acc', 'tta_val_acc', 'total_time_seconds'] 328 | def print_training_details(variables, is_final_entry): 329 | formatted = [] 330 | for col in logging_columns_list: 331 | var = variables.get(col.strip(), None) 332 | if type(var) in (int, str): 333 | res = str(var) 334 | elif type(var) is float: 335 | res = '{:0.4f}'.format(var) 336 | else: 337 | assert var is None 338 | res = '' 339 | formatted.append(res.rjust(len(col))) 340 | print_columns(formatted, is_final_entry=is_final_entry) 341 | 342 | ############################################ 343 | # Evaluation # 344 | ############################################ 345 | 346 | def infer(model, loader, tta_level=0): 347 | 348 | # Test-time augmentation strategy (for tta_level=2): 349 | # 1. Flip/mirror the image left-to-right (50% of the time). 350 | # 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time, 351 | # i.e. both happen 25% of the time). 352 | # 353 | # This creates 6 views per image (left/right times the two translations and no-translation), 354 | # which we evaluate and then weight according to the given probabilities. 
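# (Added arithmetic, not in the original file: for tta_level=2 the untranslated view
#  carries weight 0.5 and each translated view 0.25; mirroring halves each, so the six
#  logits are combined with weights 0.25, 0.25, 0.125, 0.125, 0.125, 0.125, summing to 1.)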
355 | 
356 |     def infer_basic(inputs, net):
357 |         return net(inputs).clone()
358 | 
359 |     def infer_mirror(inputs, net):
360 |         return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1))
361 | 
362 |     def infer_mirror_translate(inputs, net):
363 |         logits = infer_mirror(inputs, net)
364 |         pad = 1
365 |         padded_inputs = F.pad(inputs, (pad,)*4, 'reflect')
366 |         inputs_translate_list = [
367 |             padded_inputs[:, :, 0:32, 0:32],
368 |             padded_inputs[:, :, 2:34, 2:34],
369 |         ]
370 |         logits_translate_list = [infer_mirror(inputs_translate, net)
371 |                                  for inputs_translate in inputs_translate_list]
372 |         logits_translate = torch.stack(logits_translate_list).mean(0)
373 |         return 0.5 * logits + 0.5 * logits_translate
374 | 
375 |     model.eval()
376 |     test_images = loader.normalize(loader.images)
377 |     infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level]
378 |     with torch.no_grad():
379 |         return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)])
380 | 
381 | def evaluate(model, loader, tta_level=0):
382 |     logits = infer(model, loader, tta_level)
383 |     return (logits.argmax(1) == loader.labels).float().mean().item()
384 | 
385 | ############################################
386 | #                 Training                 #
387 | ############################################
388 | 
389 | def main(run):
390 | 
391 |     batch_size = hyp['opt']['batch_size']
392 |     epochs = hyp['opt']['train_epochs']
393 |     momentum = hyp['opt']['momentum']
394 |     # Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much
395 |     # larger the default steps will be than the underlying per-example gradients. We divide the
396 |     # learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless
397 |     # of the choice of momentum.
398 |     kilostep_scale = 1024 * (1 + 1 / (1 - momentum))
399 |     lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD
400 |     wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale
401 |     lr_biases = lr * hyp['opt']['bias_scaler']
402 | 
403 |     loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none')
404 |     test_loader = CifarLoader('cifar10', train=False, batch_size=2000)
405 |     train_loader = CifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug'])
406 |     if run == 'warmup':
407 |         # The only purpose of the first run is to warm up, so we can use dummy data
408 |         train_loader.labels = torch.randint(0, 10, size=(len(train_loader.labels),), device=train_loader.labels.device)
409 |     total_train_steps = ceil(len(train_loader) * epochs)
410 | 
411 |     model = make_net()
412 |     current_steps = 0
413 | 
414 |     norm_biases = [p for k, p in model.named_parameters() if 'norm' in k and p.requires_grad]
415 |     other_params = [p for k, p in model.named_parameters() if 'norm' not in k and p.requires_grad]
416 |     param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases),
417 |                      dict(params=other_params, lr=lr, weight_decay=wd/lr)]
418 |     optimizer = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True)
419 | 
420 |     def get_lr(step):
421 |         warmup_steps = int(total_train_steps * 0.1)
422 |         warmdown_steps = total_train_steps - warmup_steps
423 |         if step < warmup_steps:
424 |             frac = step / warmup_steps
425 |             return 0.2 * (1 - frac) + 1.0 * frac
426 |         else:
427 |             frac = (total_train_steps - step) / warmdown_steps
428 |             return frac
429 |     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)
430 | 
431 |     alpha_schedule = 0.95**5 * (torch.arange(total_train_steps+1) / total_train_steps)**3
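    # (An explanatory note: 0.95**5 is about 0.774, so the lookahead decay ramps cubically
    # from 0 up to about 0.774 over training. LookaheadState.update lerps the EMA toward the
    # live weights by 1-decay and copies the result back, so the slow weights track the fast
    # weights almost exactly at first, and retain about 77% of their state per update by the end.)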
432 |     lookahead_state = LookaheadState(model)
433 | 
434 |     # For accurately timing GPU code
435 |     starter = torch.cuda.Event(enable_timing=True)
436 |     ender = torch.cuda.Event(enable_timing=True)
437 |     total_time_seconds = 0.0
438 | 
439 |     # Initialize the whitening layer using training images
440 |     starter.record()
441 |     train_images = train_loader.normalize(train_loader.images[:5000])
442 |     init_whitening_conv(model[0], train_images)
443 |     ender.record()
444 |     torch.cuda.synchronize()
445 |     total_time_seconds += 1e-3 * starter.elapsed_time(ender)
446 | 
447 |     for epoch in range(ceil(epochs)):
448 | 
449 |         model[0].bias.requires_grad = (epoch < hyp['opt']['whiten_bias_epochs'])
450 | 
451 |         ####################
452 |         #     Training     #
453 |         ####################
454 | 
455 |         starter.record()
456 | 
457 |         model.train()
458 |         for inputs, labels in train_loader:
459 | 
460 |             outputs = model(inputs)
461 |             loss = loss_fn(outputs, labels).sum()
462 |             optimizer.zero_grad(set_to_none=True)
463 |             loss.backward()
464 |             optimizer.step()
465 |             scheduler.step()
466 | 
467 |             current_steps += 1
468 | 
469 |             if current_steps % 5 == 0:
470 |                 lookahead_state.update(model, decay=alpha_schedule[current_steps].item())
471 | 
472 |             if current_steps >= total_train_steps:
473 |                 if lookahead_state is not None:
474 |                     lookahead_state.update(model, decay=1.0)
475 |                 break
476 | 
477 |         ender.record()
478 |         torch.cuda.synchronize()
479 |         total_time_seconds += 1e-3 * starter.elapsed_time(ender)
480 | 
481 |         ####################
482 |         #    Evaluation    #
483 |         ####################
484 | 
485 |         # Save the accuracy and loss from the last training batch of the epoch
486 |         train_acc = (outputs.detach().argmax(1) == labels).float().mean().item()
487 |         train_loss = loss.item() / batch_size
488 |         val_acc = evaluate(model, test_loader, tta_level=0)
489 |         print_training_details(locals(), is_final_entry=False)
490 |         run = None # Only print the run number once
491 | 
492 |     ####################
493 |     #  TTA Evaluation  #
494 |     ####################
495 | 
496 |     starter.record()
497 |     tta_val_acc = evaluate(model, test_loader, tta_level=hyp['net']['tta_level'])
498 |     ender.record()
499 |     torch.cuda.synchronize()
500 |     total_time_seconds += 1e-3 * starter.elapsed_time(ender)
501 | 
502 |     epoch = 'eval'
503 |     print_training_details(locals(), is_final_entry=True)
504 | 
505 |     return tta_val_acc
506 | 
507 | if __name__ == "__main__":
508 |     with open(sys.argv[0]) as f:
509 |         code = f.read()
510 | 
511 |     print_columns(logging_columns_list, is_head=True)
512 |     #main('warmup')
513 |     accs = torch.tensor([main(run) for run in range(25)])
514 |     print('Mean: %.4f    Std: %.4f' % (accs.mean(), accs.std()))
515 | 
516 |     log = {'code': code, 'accs': accs}
517 |     log_dir = os.path.join('logs', str(uuid.uuid4()))
518 |     os.makedirs(log_dir, exist_ok=True)
519 |     log_path = os.path.join(log_dir, 'log.pt')
520 |     print(os.path.abspath(log_path))
521 |     torch.save(log, os.path.join(log_dir, 'log.pt'))
522 | 
--------------------------------------------------------------------------------
/research/airbench94_muon_simple.py:
--------------------------------------------------------------------------------
  1 | from math import ceil
  2 | import torch
  3 | torch.backends.cudnn.benchmark = True
  4 | from torch import nn
  5 | import torch.nn.functional as F
  6 | import airbench
  7 | 
  8 | @torch.compile
  9 | def zeropower_via_newtonschulz5(G, steps=3, eps=1e-7):
 10 |     assert len(G.shape) == 2
 11 |     a, b, c = (3.4445, -4.7750, 2.0315)
 12 |     X = G.bfloat16()
 13 |     X /= (X.norm() + eps) # ensure top singular value <= 1
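    # (A brief note: the Frobenius norm upper-bounds the spectral norm, so after this
    # division every singular value of X lies in (0, 1], which is the range over which
    # the quintic iteration below drives singular values toward 1.)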
 14 |     if G.size(0) > G.size(1):
 15 |         X = X.T
 16 |     for _ in range(steps):
 17 |         A = X @ X.T
 18 |         B = b * A + c * A @ A
 19 |         X = a * X + B @ X
 20 |     if G.size(0) > G.size(1):
 21 |         X = X.T
 22 |     return X
 23 | 
 24 | class Muon(torch.optim.Optimizer):
 25 |     def __init__(self, params, lr=1e-3, momentum=0, nesterov=False):
 26 |         defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov)
 27 |         super().__init__(params, defaults)
 28 | 
 29 |     def step(self):
 30 |         for group in self.param_groups:
 31 |             lr = group["lr"]
 32 |             momentum = group["momentum"]
 33 |             for p in group["params"]:
 34 |                 g = p.grad
 35 |                 if g is None:
 36 |                     continue
 37 |                 state = self.state[p]
 38 | 
 39 |                 if "momentum_buffer" not in state.keys():
 40 |                     state["momentum_buffer"] = torch.zeros_like(g)
 41 |                 buf = state["momentum_buffer"]
 42 |                 buf.mul_(momentum).add_(g)
 43 |                 g = g.add(buf, alpha=momentum) if group["nesterov"] else buf
 44 | 
 45 |                 p.data.mul_(len(p.data)**0.5 / p.data.norm()) # normalize the weight
 46 |                 update = zeropower_via_newtonschulz5(g.reshape(len(g), -1)).view(g.shape) # whiten the update
 47 |                 p.data.add_(update, alpha=-lr) # take a step
 48 | 
 49 | # note the use of low BatchNorm stats momentum
 50 | class BatchNorm(nn.BatchNorm2d):
 51 |     def __init__(self, num_features, momentum=0.6, eps=1e-12):
 52 |         super().__init__(num_features, eps=eps, momentum=1-momentum)
 53 |         self.weight.requires_grad = False
 54 |         # Note that PyTorch already initializes the weights to one and bias to zero
 55 | 
 56 | class Conv(nn.Conv2d):
 57 |     def __init__(self, in_channels, out_channels):
 58 |         super().__init__(in_channels, out_channels, kernel_size=3, padding="same", bias=False)
 59 | 
 60 |     def reset_parameters(self):
 61 |         super().reset_parameters()
 62 |         w = self.weight.data
 63 |         torch.nn.init.dirac_(w[:w.size(1)])
 64 | 
 65 | class ConvGroup(nn.Module):
 66 |     def __init__(self, channels_in, channels_out):
 67 |         super().__init__()
 68 |         self.conv1 = Conv(channels_in, channels_out)
 69 |         self.pool = nn.MaxPool2d(2)
 70 |         self.norm1 = BatchNorm(channels_out)
 71 |         self.conv2 = Conv(channels_out, channels_out)
 72 |         self.norm2 = BatchNorm(channels_out)
 73 |         self.activ = nn.GELU()
 74 | 
 75 |     def forward(self, x):
 76 |         x = self.conv1(x)
 77 |         x = self.pool(x)
 78 |         x = self.norm1(x)
 79 |         x = self.activ(x)
 80 |         x = self.conv2(x)
 81 |         x = self.norm2(x)
 82 |         x = self.activ(x)
 83 |         return x
 84 | 
 85 | class CifarNet(nn.Module):
 86 |     def __init__(self):
 87 |         super().__init__()
 88 |         widths = dict(block1=64, block2=256, block3=256)
 89 |         whiten_kernel_size = 2
 90 |         whiten_width = 2 * 3 * whiten_kernel_size**2
 91 |         self.whiten = nn.Conv2d(3, whiten_width, whiten_kernel_size, padding=0, bias=True)
 92 |         self.whiten.weight.requires_grad = False
 93 |         self.layers = nn.Sequential(
 94 |             nn.GELU(),
 95 |             ConvGroup(whiten_width, widths["block1"]),
 96 |             ConvGroup(widths["block1"], widths["block2"]),
 97 |             ConvGroup(widths["block2"], widths["block3"]),
 98 |             nn.MaxPool2d(3),
 99 |         )
100 |         self.head = nn.Linear(widths["block3"], 10, bias=False)
101 |         for mod in self.modules():
102 |             if isinstance(mod, BatchNorm):
103 |                 mod.float()
104 |             else:
105 |                 mod.half()
106 | 
107 |     def reset(self):
108 |         for m in self.modules():
109 |             if type(m) in (nn.Conv2d, Conv, BatchNorm, nn.Linear):
110 |                 m.reset_parameters()
111 |         w = self.head.weight.data
112 |         w *= 1 / w.std()
113 | 
114 |     def init_whiten(self, train_images, eps=5e-4):
115 |         c, (h, w) = train_images.shape[1], self.whiten.weight.shape[2:]
116 |         patches = train_images.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float()
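        # (An illustrative note: the two unfold calls enumerate every h x w patch at stride 1,
        # so the 5000 32x32 images passed in from main() below yield 5000*31*31 = 4,805,000
        # patches of shape (3, 2, 2) for the covariance estimate that follows.)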
117 |         patches_flat = patches.view(len(patches), -1)
118 |         est_patch_covariance = (patches_flat.T @ patches_flat) / len(patches_flat)
119 |         eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO="U")
120 |         eigenvectors_scaled = eigenvectors.T.reshape(-1,c,h,w) / torch.sqrt(eigenvalues.view(-1,1,1,1) + eps)
121 |         self.whiten.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled))
122 | 
123 |     def forward(self, x, whiten_bias_grad=True):
124 |         b = self.whiten.bias
125 |         x = F.conv2d(x, self.whiten.weight, b if whiten_bias_grad else b.detach())
126 |         x = self.layers(x)
127 |         x = x.view(len(x), -1)
128 |         return self.head(x) / x.size(-1)
129 | 
130 | def main():
131 | 
132 |     model = CifarNet().cuda().to(memory_format=torch.channels_last)
133 | 
134 |     batch_size = 2000
135 |     bias_lr = 0.053
136 |     head_lr = 0.67
137 |     wd = 2e-6 * batch_size
138 | 
139 |     test_loader = airbench.CifarLoader("cifar10", train=False, batch_size=2000)
140 |     train_loader = airbench.CifarLoader("cifar10", train=True, batch_size=batch_size,
141 |                                         aug=dict(flip=True, translate=2), altflip=True)
142 |     total_train_steps = ceil(8 * len(train_loader))
143 |     whiten_bias_train_steps = ceil(3 * len(train_loader))
144 | 
145 |     # Create optimizers and learning rate schedulers
146 |     filter_params = [p for p in model.parameters() if len(p.shape) == 4 and p.requires_grad]
147 |     norm_biases = [p for n, p in model.named_parameters() if "norm" in n and p.requires_grad]
148 |     param_configs = [dict(params=[model.whiten.bias], lr=bias_lr, weight_decay=wd/bias_lr),
149 |                      dict(params=norm_biases, lr=bias_lr, weight_decay=wd/bias_lr),
150 |                      dict(params=[model.head.weight], lr=head_lr, weight_decay=wd/head_lr)]
151 |     optimizer1 = torch.optim.SGD(param_configs, momentum=0.85, nesterov=True, fused=True)
152 |     optimizer2 = Muon(filter_params, lr=0.24, momentum=0.6, nesterov=True)
153 |     optimizers = [optimizer1, optimizer2]
154 |     for opt in optimizers:
155 |         for group in opt.param_groups:
156 |             group["initial_lr"] = group["lr"]
157 | 
158 |     model.reset()
159 |     step = 0
160 | 
161 |     # Initialize the whitening layer using training images
162 |     train_images = train_loader.normalize(train_loader.images[:5000])
163 |     model.init_whiten(train_images)
164 | 
165 |     for epoch in range(ceil(total_train_steps / len(train_loader))):
166 | 
167 |         model.train()
168 |         for inputs, labels in train_loader:
169 |             outputs = model(inputs, whiten_bias_grad=(step < whiten_bias_train_steps))
170 |             F.cross_entropy(outputs, labels, label_smoothing=0.2, reduction="sum").backward()
171 |             for group in optimizer1.param_groups[:1]:
172 |                 group["lr"] = group["initial_lr"] * (1 - step / whiten_bias_train_steps)
173 |             for group in optimizer1.param_groups[1:]+optimizer2.param_groups:
174 |                 group["lr"] = group["initial_lr"] * (1 - step / total_train_steps)
175 |             for opt in optimizers:
176 |                 opt.step()
177 |             model.zero_grad(set_to_none=True)
178 |             step += 1
179 |             if step >= total_train_steps:
180 |                 break
181 | 
182 |     tta_val_acc = airbench.evaluate(model, test_loader, tta_level=2)
183 |     print(f"{tta_val_acc:.4f}")
184 |     return tta_val_acc
185 | 
186 | if __name__ == "__main__":
187 |     accs = torch.tensor([main() for run in range(25)])
188 |     print("Mean: %.4f    Std: %.4f" % (accs.mean(), accs.std()))
189 | 
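# A quick way to sanity-check the orthogonalization above (an illustrative sketch,
# not part of the benchmark): Newton-Schulz should map a random matrix to one whose
# singular values all sit near 1. Something like:
#
#     G = torch.randn(64, 256, device="cuda")
#     X = zeropower_via_newtonschulz5(G, steps=5)
#     print(torch.linalg.svdvals(X.float()))  # expect values clustered around 1.0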
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup, find_packages
 2 | 
 3 | setup(
 4 |     name='airbench',
 5 |     version='0.1.9',
 6 |     author='Keller Jordan',
 7 |     author_email='kjordan4077@gmail.com',
 8 |     description='Utilities and baselines for fast neural network training on CIFAR-10',
 9 |     long_description=open('README.md').read(),
10 |     long_description_content_type='text/markdown',
11 |     url='https://github.com/KellerJordan/cifar10-airbench',
12 |     packages=find_packages(),
13 |     classifiers=[
14 |         'Development Status :: 5 - Production/Stable',
15 |         'Intended Audience :: Developers',
16 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
17 |         "Topic :: Scientific/Engineering :: Image Recognition",
18 |         "Topic :: Scientific/Engineering :: Information Analysis",
19 |         'License :: OSI Approved :: MIT License',
20 |         'Programming Language :: Python :: 3',
21 |         'Programming Language :: Python :: 3.7',
22 |         'Programming Language :: Python :: 3.8',
23 |         'Programming Language :: Python :: 3.9',
24 |         'Programming Language :: Python :: 3.10',
25 |         'Programming Language :: Python :: 3.11',
26 |     ],
27 | )
28 | 
--------------------------------------------------------------------------------
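# Installation note (a typical setuptools workflow, inferred from the setup.py above;
# the package name 'airbench' is assumed from its name field):
#
#     pip install airbench    # from PyPI, if published there
#     pip install -e .        # editable install from a local clone of the repo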