├── LICENSE ├── README.md ├── lecun1989.png ├── modern.py ├── prepro.py ├── repro.py └── vis.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # lecun1989-repro 3 | 4 | ![teaser](lecun1989.png) 5 | 6 | This code tries to reproduce the 1989 Yann LeCun et al. paper: [Backpropagation Applied to Handwritten Zip Code Recognition](http://yann.lecun.com/exdb/publis/pdf/lecun-89e.pdf). To my knowledge this is the earliest real-world application of a neural net trained with backpropagation (now 33 years ago). 
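The network from the paper is tiny by modern standards. Its unit, connection and parameter counts, listed in the notes below, can be recomputed from the layer shapes alone. The following is a minimal sketch, not code from this repo. `Layer` and `totals` are hypothetical helpers. The sketch assumes the paper-style convention that each unit's bias counts as one extra input line, and that weights are shared within a feature plane but biases are not.

```python
# Sketch: recompute the layer arithmetic of the 1989 net from its layer shapes.
# `Layer` and `totals` are hypothetical helpers, not part of this repo.
from dataclasses import dataclass


@dataclass(frozen=True)
class Layer:
    name: str
    units: int    # output units; each unit has its own (unshared) bias
    fan_in: int   # weighted input lines per unit, excluding the bias
    weights: int  # distinct weights after weight sharing within feature planes

    @property
    def macs(self) -> int:
        # one multiply-accumulate per weighted input line per unit
        return self.units * self.fan_in

    @property
    def connections(self) -> int:
        # paper-style count: the bias is one extra input line per unit
        return self.units * (self.fan_in + 1)

    @property
    def params(self) -> int:
        # shared weights plus one bias per unit
        return self.weights + self.units


LAYERS = [
    # H1: 12 planes of 8x8; each unit sees a 5x5 patch of the single input plane
    Layer("H1", units=12 * 8 * 8, fan_in=5 * 5 * 1, weights=12 * 5 * 5 * 1),
    # H2: 12 planes of 4x4; each unit sees 5x5 patches of 8 of the 12 H1 planes
    Layer("H2", units=12 * 4 * 4, fan_in=5 * 5 * 8, weights=12 * 5 * 5 * 8),
    # H3: 30 units fully connected to all 192 H2 units (no sharing)
    Layer("H3", units=30, fan_in=12 * 4 * 4, weights=30 * 12 * 4 * 4),
    # output: 10 units fully connected to H3
    Layer("out", units=10, fan_in=30, weights=10 * 30),
]


def totals(layers, input_units=16 * 16):
    """Sum per-layer counts; the 16x16 input pixels count as units too."""
    return {
        "units": input_units + sum(layer.units for layer in layers),
        "connections": sum(layer.connections for layer in layers),
        "params": sum(layer.params for layer in layers),
        "macs": sum(layer.macs for layer in layers),
    }


if __name__ == "__main__":
    for layer in LAYERS:
        print(f"{layer.name:>3}: units {layer.units:4d}  "
              f"connections {layer.connections:6d}  params {layer.params:5d}  "
              f"MACs {layer.macs:6d}")
    print(totals(LAYERS))
```

If those assumptions hold, the totals reproduce the 1256 units, 64,660 connections and 9760 parameters from the notes. The MAC total of 63,660 should equal what `Net.__init__` in `repro.py` accumulates in `self.macs`. The gap of exactly 1000 to the connection count is one bias line per non-input unit, which is the same number `self.acts` records.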
7 | 8 | #### run 9 | 10 | Since we don't have the exact dataset that was used in the paper, we take MNIST and randomly pick examples from it to generate an approximation of the dataset, which contains only 7291 training and 2007 testing digits, each only 16x16 pixels (standard MNIST is 28x28). 11 | 12 | ``` 13 | $ python prepro.py 14 | ``` 15 | 16 | Now we can attempt to reproduce the paper. The original network trained for 3 days, but my (Apple Silicon M1) MacBook Air 33 years later chunks through it in about 90 seconds (non-emulated arm64 but CPU only; I don't believe PyTorch and Apple M1 are best friends just yet, but that is still about a 3000X speedup). So now that we've run prepro we can run repro! (haha): 17 | 18 | ``` 19 | $ python repro.py 20 | ``` 21 | 22 | Running this prints (on the 23rd, final pass): 23 | 24 | ``` 25 | eval: split train. loss 4.073383e-03. error 0.62%. misses: 45 26 | eval: split test . loss 2.838382e-02. error 4.09%. misses: 82 27 | ``` 28 | 29 | This is close to, but not quite the same as, what the paper reports. To match the paper exactly we'd expect the following instead: 30 | 31 | ``` 32 | eval: split train. loss 2.5e-3. error 0.14%. misses: 10 33 | eval: split test . loss 1.8e-2. error 5.00%. misses: 102 34 | ``` 35 | 36 | I expect that the majority of this discrepancy comes from the training dataset itself. We've only simulated the original dataset using what we have today 33 years later (MNIST). There are a number of other details that are not specified in the paper, so I also had to do some guessing (see notes below). For example, the specific sparse connectivity structure between layers H1 and H2 is not described; the paper just says that the inputs are "chosen according to a scheme that will not be discussed here".
Similarly, the paper uses a "special version of Newton's algorithm that uses a positive, diagonal approximation of Hessian", but I only used simple SGD in this implementation because it is significantly simpler and, according to the paper, "this algorithm is not believed to bring a tremendous increase in learning speed". Anyway, we are getting numbers of a similar order of magnitude... 37 | 38 | #### notes 39 | 40 | My notes from the paper: 41 | 42 | - 7291 digits are used for training 43 | - 2007 digits are used for testing 44 | - each image is 16x16 pixels grayscale (not binary) 45 | - images are scaled to range [-1, 1] 46 | - network has three hidden layers H1, H2, H3 47 | - H1 is 5x5 stride 2 conv with 12 planes. Constant padding of -1. 48 | - not "standard": units do not share biases! (including in the same feature plane) 49 | - H1 has 768 units (8\*8\*12), 19,968 connections (768\*26), 1,068 parameters (768 biases + 25\*12 weights) 50 | - not "standard": H2 units all draw input from 5x5 stride 2 conv but each only connecting to a different 8 out of the 12 planes 51 | - H2 contains 192 units (4\*4\*12), 38,592 connections (192 units * 201 input lines), 2,592 parameters (12 * 200 weights + 192 biases) 52 | - H3 has 30 units fully connected to H2. So 5790 connections (30 * 192 + 30) 53 | - output layer has 10 units fully connected to H3. So 310 weights (30 * 10 + 10) 54 | - total: 1256 units, 64,660 connections, 9760 parameters 55 | - tanh activations on all units (including output units!) 56 | - weights of output chosen to be in quasi-linear regime 57 | - cost function: mean squared error 58 | - weight init: random values in U[-2.4/F, 2.4/F] where F is the fan-in.
"tends to keep total inputs in operating range of sigmoid" 59 | - training 60 | - patterns presented in constant order 61 | - SGD on single example at a time 62 | - use special version of Newton's algorithm that uses a positive, diagonal approximation of Hessian 63 | - trained for 23 passes over data, measuring train+test error after each pass. Total 167,693 presentations (23 * 7291) 64 | - final error: 2.5e-3 train, 1.8e-2 test 65 | - percent misclassification: 0.14% on train (10 mistakes), 5.0% on test (102 mistakes). 66 | - compute: 67 | - run on SUN-4/260 workstation 68 | - digital signal co-processor: 69 | - 256 kbytes of local memory 70 | - peak performance of 12.5M MAC/s on fp32 (i.e. 25 MFLOPS) 71 | - trained for 3 days 72 | - throughput of 10-12 digits/s, "limited mainly by the normalization step" 73 | - throughput of 30 digits/s on normalized digits 74 | - "we have successfully applied backpropagation learning to a large, real-world task" 75 | 76 | **Open questions:** 77 | 78 | - The 12 -> 8 connections from H1 to H2 are not described in this paper... I will assume a sensible block structure connectivity 79 | - It is not clear what exactly the "MSE loss" is. Was the scaling factor of 1/2 included to simplify the gradient calculation? Will assume no. 80 | - What is the learning rate? I will run a sweep to determine the best one manually. 81 | - Was any learning rate decay used? Not mentioned; I am assuming no. 82 | - Was any weight decay used? Not mentioned; assuming no. 83 | - Is there a bug in the pdf, i.e. should the weight init divide by the square root of the fan-in? The pdf's formatting is a bit messed up. Assuming yes. 84 | - The paper does not say, but what exactly are the targets? Assuming they are +1/-1 for pos/neg, as the output units have tanh too... 85 | 86 | One more note on the weight init conundrum. For example, the "Kaiming init" is: 87 | 88 | ``` 89 | a = gain * sqrt(3 / fan_in) 90 | w ~ U(-a, a) 91 | ``` 92 | 93 | For tanh neurons the recommended gain is 5/3.
So we would have `a = sqrt(3) * 5 / 3 * sqrt(1 / fan_in) = 2.89 * sqrt(1 / fan_in)`, which is close to what the paper does (gain 2.4). So if the original work in fact did use a sqrt and the pdf is just formatted wrong, then the (modern) Kaiming init and the originally used init are pretty close. 94 | 95 | #### todos 96 | 97 | - modernize the network using knowledge from 33 years of time travel. 98 | - potentially include my janky hyperparameter sweeping code for tuning the learning rate 99 | -------------------------------------------------------------------------------- /lecun1989.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karpathy/lecun1989-repro/8553f52c8d0a51a4bbbabf98e005cef28a22a36b/lecun1989.png -------------------------------------------------------------------------------- /modern.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | repro.py gives: 4 | 23 5 | eval: split train. loss 4.073383e-03. error 0.62%. misses: 45 6 | eval: split test . loss 2.838382e-02. error 4.09%. misses: 82 7 | 8 | we can try to use our knowledge from 33 years later to improve on this, 9 | while keeping the model size the same. 10 | 11 | Change 1: replace tanh on last layer with FC and use softmax. Had to 12 | lower the learning rate to 0.01 as well. This improves the optimization 13 | quite a lot; we now crush the training set: 14 | 23 15 | eval: split train. loss 9.536698e-06. error 0.00%. misses: 0 16 | eval: split test . loss 9.536698e-06. error 4.38%. misses: 87 17 | 18 | Change 2: change from SGD to AdamW with LR 3e-4 because I find it 19 | to be significantly more stable and it requires little to no tuning. Also 20 | double epochs to 46. I decay the LR to 1e-4 over the course of training. 21 | These changes make it highly likely that optimization is not the culprit 22 | of bad performance.
We also seem to improve the test set a bit: 23 | 46 24 | eval: split train. loss 0.000000e+00. error 0.00%. misses: 0 25 | eval: split test . loss 0.000000e+00. error 3.59%. misses: 72 26 | 27 | Change 3: since we are overfitting we can introduce data augmentation, 28 | e.g. let's add a random shift of at most 1 pixel in both x/y directions. Also 29 | because we are augmenting we again want to bump up training time, e.g. 30 | to 60 epochs: 31 | 60 32 | eval: split train. loss 8.780676e-04. error 1.70%. misses: 123 33 | eval: split test . loss 8.780676e-04. error 2.19%. misses: 43 34 | 35 | Change 4: we want to add dropout at the layer with most parameters (H3), 36 | but in addition we also have to shift the activation function to relu so 37 | that dropout makes sense. We also bring up the number of passes to 80: 38 | 80 39 | eval: split train. loss 2.601336e-03. error 1.47%. misses: 106 40 | eval: split test . loss 2.601336e-03. error 1.59%. misses: 32 41 | 42 | To be continued... 43 | """ 44 | 45 | import os 46 | import json 47 | import argparse 48 | 49 | import numpy as np 50 | import torch 51 | import torch.nn as nn 52 | import torch.nn.functional as F 53 | import torch.optim as optim 54 | from tensorboardX import SummaryWriter # pip install tensorboardX 55 | 56 | # ----------------------------------------------------------------------------- 57 | 58 | class Net(nn.Module): 59 | """ 1989 LeCun ConvNet per description in the paper """ 60 | 61 | def __init__(self): 62 | super().__init__() 63 | 64 | # initialization as described in the paper to my best ability, but it doesn't look right...
65 | winit = lambda fan_in, *shape: (torch.rand(*shape) - 0.5) * 2 * 2.4 / fan_in**0.5 66 | macs = 0 # keep track of MACs (multiply accumulates) 67 | acts = 0 # keep track of number of activations 68 | 69 | # H1 layer parameters and their initialization 70 | self.H1w = nn.Parameter(winit(5*5*1, 12, 1, 5, 5)) 71 | self.H1b = nn.Parameter(torch.zeros(12, 8, 8)) # presumably init to zero for biases 72 | macs += (5*5*1) * (8*8) * 12 73 | acts += (8*8) * 12 74 | 75 | # H2 layer parameters and their initialization 76 | """ 77 | H2 neurons all connect to only 8 of the 12 input planes, with an unspecified pattern 78 | I am going to assume the most sensible block pattern where 4 planes at a time connect 79 | to differently overlapping groups of 8/12 input planes. We will implement this with 3 80 | separate convolutions that we concatenate the results of. 81 | """ 82 | self.H2w = nn.Parameter(winit(5*5*8, 12, 8, 5, 5)) 83 | self.H2b = nn.Parameter(torch.zeros(12, 4, 4)) # presumably init to zero for biases 84 | macs += (5*5*8) * (4*4) * 12 85 | acts += (4*4) * 12 86 | 87 | # H3 is a fully connected layer 88 | self.H3w = nn.Parameter(winit(4*4*12, 4*4*12, 30)) 89 | self.H3b = nn.Parameter(torch.zeros(30)) 90 | macs += (4*4*12) * 30 91 | acts += 30 92 | 93 | # output layer is also fully connected layer 94 | self.outw = nn.Parameter(winit(30, 30, 10)) 95 | self.outb = nn.Parameter(torch.zeros(10)) 96 | macs += 30 * 10 97 | acts += 10 98 | 99 | self.macs = macs 100 | self.acts = acts 101 | 102 | def forward(self, x): 103 | 104 | # poor man's data augmentation by 1 pixel along x/y directions 105 | if self.training: 106 | shift_x, shift_y = np.random.randint(-1, 2, size=2) 107 | x = torch.roll(x, (shift_x, shift_y), (2, 3)) 108 | 109 | # x has shape (1, 1, 16, 16) 110 | x = F.pad(x, (2, 2, 2, 2), 'constant', -1.0) # pad by two using constant -1 for background 111 | x = F.conv2d(x, self.H1w, stride=2) + self.H1b 112 | x = torch.relu(x) 113 | 114 | # x is now shape (1, 12, 8, 8) 115 
| x = F.pad(x, (2, 2, 2, 2), 'constant', -1.0) # pad by two using constant -1 for background 116 | slice1 = F.conv2d(x[:, 0:8], self.H2w[0:4], stride=2) # first 4 planes look at first 8 input planes 117 | slice2 = F.conv2d(x[:, 4:12], self.H2w[4:8], stride=2) # next 4 planes look at last 8 input planes 118 | slice3 = F.conv2d(torch.cat((x[:, 0:4], x[:, 8:12]), dim=1), self.H2w[8:12], stride=2) # last 4 planes are cross 119 | x = torch.cat((slice1, slice2, slice3), dim=1) + self.H2b 120 | x = torch.relu(x) 121 | x = F.dropout(x, p=0.25, training=self.training) 122 | 123 | # x is now shape (1, 12, 4, 4) 124 | x = x.flatten(start_dim=1) # (1, 12*4*4) 125 | x = x @ self.H3w + self.H3b 126 | x = torch.relu(x) 127 | 128 | # x is now shape (1, 30) 129 | x = x @ self.outw + self.outb 130 | 131 | # x is finally shape (1, 10) 132 | return x 133 | 134 | # ----------------------------------------------------------------------------- 135 | 136 | if __name__ == '__main__': 137 | 138 | parser = argparse.ArgumentParser(description="Train a 2022 but mini ConvNet on digits") 139 | parser.add_argument('--learning-rate', '-l', type=float, default=3e-4, help="Learning rate") 140 | parser.add_argument('--output-dir' , '-o', type=str, default='out/modern', help="output directory for training logs") 141 | args = parser.parse_args() 142 | print(vars(args)) 143 | 144 | # init rng 145 | torch.manual_seed(1337) 146 | np.random.seed(1337) 147 | torch.use_deterministic_algorithms(True) 148 | 149 | # set up logging 150 | os.makedirs(args.output_dir, exist_ok=True) 151 | with open(os.path.join(args.output_dir, 'args.json'), 'w') as f: 152 | json.dump(vars(args), f, indent=2) 153 | writer = SummaryWriter(args.output_dir) 154 | 155 | # init a model 156 | model = Net() 157 | print("model stats:") 158 | print("# params: ", sum(p.numel() for p in model.parameters())) # in paper total is 9,760 159 | print("# MACs: ", model.macs) 160 | print("# activations: ", model.acts) 161 | 162 | # init data 163 | 
Xtr, Ytr = torch.load('train1989.pt') 164 | Xte, Yte = torch.load('test1989.pt') 165 | 166 | # init optimizer 167 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) 168 | 169 | def eval_split(split): 170 | # eval the full train/test set, batched implementation for efficiency 171 | model.eval() 172 | X, Y = (Xtr, Ytr) if split == 'train' else (Xte, Yte) 173 | Yhat = model(X) 174 | loss = F.cross_entropy(Yhat, Y.argmax(dim=1)) 175 | err = torch.mean((Y.argmax(dim=1) != Yhat.argmax(dim=1)).float()) 176 | print(f"eval: split {split:5s}. loss {loss.item():e}. error {err.item()*100:.2f}%. misses: {int(err.item()*Y.size(0))}") 177 | writer.add_scalar(f'error/{split}', err.item()*100, pass_num) 178 | writer.add_scalar(f'loss/{split}', loss.item(), pass_num) 179 | 180 | # train 181 | for pass_num in range(80): 182 | 183 | # learning rate decay: linear from lr down to lr/3 over the 80 passes 184 | alpha = pass_num / 79 185 | for g in optimizer.param_groups: 186 | g['lr'] = (1 - alpha) * args.learning_rate + alpha * (args.learning_rate / 3) 187 | 188 | # perform one epoch of training 189 | model.train() 190 | for step_num in range(Xtr.size(0)): 191 | 192 | # fetch a single example into a batch of 1 193 | x, y = Xtr[[step_num]], Ytr[[step_num]] 194 | 195 | # forward the model and the loss 196 | yhat = model(x) 197 | loss = F.cross_entropy(yhat, y.argmax(dim=1)) 198 | 199 | # calculate the gradient and update the parameters 200 | optimizer.zero_grad(set_to_none=True) 201 | loss.backward() 202 | optimizer.step() 203 | 204 | # after each epoch evaluate the train and test error / metrics 205 | print(pass_num + 1) 206 | eval_split('train') 207 | eval_split('test') 208 | 209 | # save final model to file 210 | torch.save(model.state_dict(), os.path.join(args.output_dir, 'model.pt')) 211 | -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess today's MNIST dataset into
1989 version's size/format (approximately) 3 | http://yann.lecun.com/exdb/publis/pdf/lecun-89e.pdf 4 | 5 | Some relevant notes for this part: 6 | - 7291 digits are used for training 7 | - 2007 digits are used for testing 8 | - each image is 16x16 pixels grayscale (not binary) 9 | - images are scaled to range [-1, 1] 10 | - paper doesn't say exactly, but reading between the lines I assume label targets to be {-1, 1} 11 | """ 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torchvision import datasets 17 | 18 | # ----------------------------------------------------------------------------- 19 | 20 | torch.manual_seed(1337) 21 | np.random.seed(1337) 22 | 23 | for split in ('train', 'test'): # fixed order: a set's iteration order can vary between runs and change which split draws from the RNG first 24 | 25 | data = datasets.MNIST('./data', train=split=='train', download=True) 26 | 27 | n = 7291 if split == 'train' else 2007 28 | rp = np.random.permutation(len(data))[:n] 29 | 30 | X = torch.full((n, 1, 16, 16), 0.0, dtype=torch.float32) 31 | Y = torch.full((n, 10), -1.0, dtype=torch.float32) 32 | for i, ix in enumerate(rp): 33 | I, yint = data[int(ix)] 34 | # PIL image -> numpy -> torch tensor -> [-1, 1] fp32 35 | xi = torch.from_numpy(np.array(I, dtype=np.float32)) / 127.5 - 1.0 36 | # add a fake batch dimension and a channel dimension of 1 or F.interpolate won't be happy 37 | xi = xi[None, None, ...] 38 | # resize to (16, 16) images with bilinear interpolation 39 | xi = F.interpolate(xi, (16, 16), mode='bilinear') 40 | X[i] = xi[0] # store 41 | 42 | # set the correct class to have target of +1.0 43 | Y[i, yint] = 1.0 44 | 45 | torch.save((X, Y), split + '1989.pt') 46 | -------------------------------------------------------------------------------- /repro.py: -------------------------------------------------------------------------------- 1 | """ 2 | Running this script eventually gives: 3 | 23 4 | eval: split train. loss 4.073383e-03. error 0.62%. misses: 45 5 | eval: split test . loss 2.838382e-02. error 4.09%.
misses: 82 6 | """ 7 | 8 | import os 9 | import json 10 | import argparse 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | from tensorboardX import SummaryWriter # pip install tensorboardX 18 | 19 | # ----------------------------------------------------------------------------- 20 | 21 | class Net(nn.Module): 22 | """ 1989 LeCun ConvNet per description in the paper """ 23 | 24 | def __init__(self): 25 | super().__init__() 26 | 27 | # initialization as described in the paper to my best ability, but it doesn't look right... 28 | winit = lambda fan_in, *shape: (torch.rand(*shape) - 0.5) * 2 * 2.4 / fan_in**0.5 29 | macs = 0 # keep track of MACs (multiply accumulates) 30 | acts = 0 # keep track of number of activations 31 | 32 | # H1 layer parameters and their initialization 33 | self.H1w = nn.Parameter(winit(5*5*1, 12, 1, 5, 5)) 34 | self.H1b = nn.Parameter(torch.zeros(12, 8, 8)) # presumably init to zero for biases 35 | assert self.H1w.nelement() + self.H1b.nelement() == 1068 36 | macs += (5*5*1) * (8*8) * 12 37 | acts += (8*8) * 12 38 | 39 | # H2 layer parameters and their initialization 40 | """ 41 | H2 neurons all connect to only 8 of the 12 input planes, with an unspecified pattern 42 | I am going to assume the most sensible block pattern where 4 planes at a time connect 43 | to differently overlapping groups of 8/12 input planes. We will implement this with 3 44 | separate convolutions that we concatenate the results of. 
45 | """ 46 | self.H2w = nn.Parameter(winit(5*5*8, 12, 8, 5, 5)) 47 | self.H2b = nn.Parameter(torch.zeros(12, 4, 4)) # presumably init to zero for biases 48 | assert self.H2w.nelement() + self.H2b.nelement() == 2592 49 | macs += (5*5*8) * (4*4) * 12 50 | acts += (4*4) * 12 51 | 52 | # H3 is a fully connected layer 53 | self.H3w = nn.Parameter(winit(4*4*12, 4*4*12, 30)) 54 | self.H3b = nn.Parameter(torch.zeros(30)) 55 | assert self.H3w.nelement() + self.H3b.nelement() == 5790 56 | macs += (4*4*12) * 30 57 | acts += 30 58 | 59 | # output layer is also fully connected layer 60 | self.outw = nn.Parameter(winit(30, 30, 10)) 61 | self.outb = nn.Parameter(-torch.ones(10)) # 9/10 targets are -1, so makes sense to init slightly towards it 62 | assert self.outw.nelement() + self.outb.nelement() == 310 63 | macs += 30 * 10 64 | acts += 10 65 | 66 | self.macs = macs 67 | self.acts = acts 68 | 69 | def forward(self, x): 70 | 71 | # x has shape (1, 1, 16, 16) 72 | x = F.pad(x, (2, 2, 2, 2), 'constant', -1.0) # pad by two using constant -1 for background 73 | x = F.conv2d(x, self.H1w, stride=2) + self.H1b 74 | x = torch.tanh(x) 75 | 76 | # x is now shape (1, 12, 8, 8) 77 | x = F.pad(x, (2, 2, 2, 2), 'constant', -1.0) # pad by two using constant -1 for background 78 | slice1 = F.conv2d(x[:, 0:8], self.H2w[0:4], stride=2) # first 4 planes look at first 8 input planes 79 | slice2 = F.conv2d(x[:, 4:12], self.H2w[4:8], stride=2) # next 4 planes look at last 8 input planes 80 | slice3 = F.conv2d(torch.cat((x[:, 0:4], x[:, 8:12]), dim=1), self.H2w[8:12], stride=2) # last 4 planes are cross 81 | x = torch.cat((slice1, slice2, slice3), dim=1) + self.H2b 82 | x = torch.tanh(x) 83 | 84 | # x is now shape (1, 12, 4, 4) 85 | x = x.flatten(start_dim=1) # (1, 12*4*4) 86 | x = x @ self.H3w + self.H3b 87 | x = torch.tanh(x) 88 | 89 | # x is now shape (1, 30) 90 | x = x @ self.outw + self.outb 91 | x = torch.tanh(x) 92 | 93 | # x is finally shape (1, 10) 94 | return x 95 | 96 | # 
----------------------------------------------------------------------------- 97 | 98 | if __name__ == '__main__': 99 | 100 | parser = argparse.ArgumentParser(description="Train a 1989 LeCun ConvNet on digits") 101 | parser.add_argument('--learning-rate', '-l', type=float, default=0.03, help="SGD learning rate") 102 | parser.add_argument('--output-dir' , '-o', type=str, default='out/base', help="output directory for training logs") 103 | args = parser.parse_args() 104 | print(vars(args)) 105 | 106 | # init rng 107 | torch.manual_seed(1337) 108 | np.random.seed(1337) 109 | torch.use_deterministic_algorithms(True) 110 | 111 | # set up logging 112 | os.makedirs(args.output_dir, exist_ok=True) 113 | with open(os.path.join(args.output_dir, 'args.json'), 'w') as f: 114 | json.dump(vars(args), f, indent=2) 115 | writer = SummaryWriter(args.output_dir) 116 | 117 | # init a model 118 | model = Net() 119 | print("model stats:") 120 | print("# params: ", sum(p.numel() for p in model.parameters())) # in paper total is 9,760 121 | print("# MACs: ", model.macs) 122 | print("# activations: ", model.acts) 123 | 124 | # init data 125 | Xtr, Ytr = torch.load('train1989.pt') 126 | Xte, Yte = torch.load('test1989.pt') 127 | 128 | # init optimizer 129 | optimizer = optim.SGD(model.parameters(), lr=args.learning_rate) 130 | 131 | def eval_split(split): 132 | # eval the full train/test set, batched implementation for efficiency 133 | model.eval() 134 | X, Y = (Xtr, Ytr) if split == 'train' else (Xte, Yte) 135 | Yhat = model(X) 136 | loss = torch.mean((Y - Yhat)**2) 137 | err = torch.mean((Y.argmax(dim=1) != Yhat.argmax(dim=1)).float()) 138 | print(f"eval: split {split:5s}. loss {loss.item():e}. error {err.item()*100:.2f}%. 
misses: {int(err.item()*Y.size(0))}") 139 | writer.add_scalar(f'error/{split}', err.item()*100, pass_num) 140 | writer.add_scalar(f'loss/{split}', loss.item(), pass_num) 141 | 142 | # train 143 | for pass_num in range(23): 144 | 145 | # perform one epoch of training 146 | model.train() 147 | for step_num in range(Xtr.size(0)): 148 | 149 | # fetch a single example into a batch of 1 150 | x, y = Xtr[[step_num]], Ytr[[step_num]] 151 | 152 | # forward the model and the loss 153 | yhat = model(x) 154 | loss = torch.mean((y - yhat)**2) 155 | 156 | # calculate the gradient and update the parameters 157 | optimizer.zero_grad(set_to_none=True) 158 | loss.backward() 159 | optimizer.step() 160 | 161 | # after each epoch evaluate the train and test error / metrics 162 | print(pass_num + 1) 163 | eval_split('train') 164 | eval_split('test') 165 | 166 | # save final model to file 167 | torch.save(model.state_dict(), os.path.join(args.output_dir, 'model.pt')) 168 | -------------------------------------------------------------------------------- /vis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 72, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# imports\n", 10 | "import os\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.nn.functional as F\n", 16 | "\n", 17 | "import matplotlib\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "%matplotlib inline" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 75, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# load model\n", 29 | "from repro import Net\n", 30 | "model_dir = 'out/base'\n", 31 | "model = Net()\n", 32 | "model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pt')))\n", 33 | "model.eval();" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 76, 39 | "metadata":
{}, 40 | "outputs": [], 41 | "source": [ 42 | "# load data\n", 43 | "Xtr, Ytr = torch.load('train1989.pt')\n", 44 | "Xte, Yte = torch.load('test1989.pt')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 85, 50 | "metadata": { 51 | "scrolled": false 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "\n", 56 | "def grid_mistakes(X, Y):\n", 57 | " \n", 58 | " plt.figure(figsize=(14, 4))\n", 59 | " ishow, nshow = 0, 14\n", 60 | " for ix in range(X.size(0)):\n", 61 | " x, y = X[[ix]], Y[[ix]]\n", 62 | " yhat = model(x)\n", 63 | " yi = y.argmax().item()\n", 64 | " yhati = yhat.argmax().item()\n", 65 | " if yi != yhati:\n", 66 | " plt.subplot(2, 7, ishow+1)\n", 67 | " plt.imshow(x[0,0], cmap='gray')\n", 68 | " plt.title(f'gt={yi}, pred={yhati}')\n", 69 | " plt.axis('off')\n", 70 | " ishow += 1\n", 71 | " if ishow >= nshow:\n", 72 | " break\n" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 86, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "data": { 82 | "image/png": "\n", 83 | "text/plain": [ 84 | "
" 85 | ] 86 | }, 87 | "metadata": { 88 | "needs_background": "light" 89 | }, 90 | "output_type": "display_data" 91 | } 92 | ], 93 | "source": [ 94 | "grid_mistakes(Xtr, Ytr) # training set mistakes" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 87, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "data": { 104 | "image/png": "\n", 105 | "text/plain": [ 106 | "
" 107 | ] 108 | }, 109 | "metadata": { 110 | "needs_background": "light" 111 | }, 112 | "output_type": "display_data" 113 | } 114 | ], 115 | "source": [ 116 | "grid_mistakes(Xte, Yte) # test set mistakes" 117 | ] 118 | } 119 | ], 120 | "metadata": { 121 | "kernelspec": { 122 | "display_name": "Python 3 (ipykernel)", 123 | "language": "python", 124 | "name": "python3" 125 | }, 126 | "language_info": { 127 | "codemirror_mode": { 128 | "name": "ipython", 129 | "version": 3 130 | }, 131 | "file_extension": ".py", 132 | "mimetype": "text/x-python", 133 | "name": "python", 134 | "nbconvert_exporter": "python", 135 | "pygments_lexer": "ipython3", 136 | "version": "3.8.12" 137 | } 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 4 141 | } 142 | --------------------------------------------------------------------------------