├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── docs ├── CIFAR-10 Experiments.ipynb ├── MNIST Experiments.ipynb ├── SFDDD Experiments.ipynb └── imgs │ ├── SFDDD.png │ ├── cifar.png │ └── mnistres.png ├── requirements.txt ├── res_net.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # OSX cache files 2 | .DS_Store 3 | 4 | # images 5 | *.png 6 | 7 | # models 8 | *model 9 | 10 | # checkpoints 11 | checkpoints/ 12 | 13 | # datasets 14 | cifar-10-batches-py/ 15 | gtsrb/ 16 | MNIST/ 17 | *_DATASET_PATH 18 | 19 | # notebooks 20 | noteboook/ 21 | 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | export_cp.py 31 | 32 | # Distribution / packaging 33 | .Python 34 | env/ 35 | build/ 36 | develop-eggs/ 37 | dist/ 38 | downloads/ 39 | eggs/ 40 | .eggs/ 41 | lib/ 42 | lib64/ 43 | parts/ 44 | sdist/ 45 | var/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *,cover 69 | .hypothesis/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | #Ipython Notebook 85 | .ipynb_checkpoints 86 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Daniele Ciriello 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Residual Learning for Image Recognition 2 | 3 | Implementation of ["Deep Residual Learning for Image Recognition", Kaiming 4 | He, Xiangyu Zhang, Shaoqing Ren, Jian Sun](http://arxiv.org/abs/1512.03385) in [PyFunt](https://github.com/dnlcrl/PyFunt) (a simple Python + Numpy DL framework). 5 | 6 | Also inspired by [this implementation in Lua + Torch](https://github.com/gcr/torch-residual-networks). 7 | 8 | The network operates on minibatches of data that have shape (N, C, H, W) 9 | consisting of N images, each with height H and width W and with C input 10 | channels. It has, like in the reference paper, (6*n)+2 layers, 11 | composed as below: 12 | 13 | (image_dim: 3, 32, 32; F=16) 14 | (input_dim: N, *image_dim) 15 | INPUT 16 | | 17 | v 18 | +-------------------+ 19 | |conv[F, *image_dim]| (out_shape: N, 16, 32, 32) 20 | +-------------------+ 21 | | 22 | v 23 | +-------------------------+ 24 | |n * res_block[F, F, 3, 3]| (out_shape: N, 16, 32, 32) 25 | +-------------------------+ 26 | | 27 | v 28 | +-------------------------+ 29 | |res_block[2*F, F, 3, 3] | (out_shape: N, 32, 16, 16) 30 | +-------------------------+ 31 | | 32 | v 33 | +---------------------------------+ 34 | |(n-1) * res_block[2*F, 2*F, 3, 3]| (out_shape: N, 32, 16, 16) 35 | +---------------------------------+ 36 | | 37 | v 38 | +-------------------------+ 39 | |res_block[4*F, 2*F, 3, 3]| (out_shape: N, 64, 8, 8) 40 | +-------------------------+ 41 | | 42 | v 43 | +---------------------------------+ 44 | |(n-1) * res_block[4*F, 4*F, 3, 3]| (out_shape: N, 64, 8, 8) 45 | +---------------------------------+ 46 | | 47 | v 48 | +-------------+ 49 | |pool[1, 8, 8]| (out_shape: N, 64, 1, 1) 50 | +-------------+ 51 | | 52 | v 53 | +-------+ 54 | |softmax| (out_shape: N, num_classes) 55 | +-------+ 56 | | 57 | v 58 | 
OUTPUT 59 | 60 | Every convolution layer has a pad=1 and stride=1, except for the dimension 61 | increasing layers, which have a stride of 2 to maintain the computational 62 | complexity. 63 | Optionally, m affine layers can be placed immediately before the softmax layer by setting the hidden_dims parameter, which should be a list of integers representing the number of neurons for each affine layer. 64 | 65 | Each residual block is composed as below: 66 | 67 | Input 68 | | 69 | ,-------+-----. 70 | Downsampling 3x3 convolution+dimensionality reduction 71 | | | 72 | v v 73 | Zero-padding 3x3 convolution 74 | | | 75 | `-----( Add )---' 76 | | 77 | Output 78 | 79 | After every layer, a batch normalization with momentum .1 is applied. 80 | 81 | 82 | ## Requirements 83 | 84 | - [Python 2.7](https://www.python.org/) 85 | - [numpy](http://www.numpy.org/) 86 | - [pyfunt](https://github.com/dnlcrl/PyFunt) 87 | - [pydatset](https://github.com/dnlcrl/PyDatSet) 88 | 89 | 90 | After you get Python, you can get [pip](https://pypi.python.org/pypi/pip) and install all requirements by running: 91 | 92 | pip install -r requirements.txt 93 | 94 | ## Usage 95 | 96 | If you want to train the network on the CIFAR-10 dataset, simply run: 97 | 98 | python train.py --help 99 | 100 | Otherwise, you have to get the right train.py for the MNIST or SFDDD datasets; they are on the mnist and sfddd git branches, respectively: 101 | 102 | - train.py for MNIST: https://github.com/dnlcrl/PyResNet/blob/mnist/train.py 103 | 104 | - train.py for SFDDD: https://github.com/dnlcrl/PyResNet/blob/sfddd/train.py 105 | 106 | ## Experiments Results 107 | 108 | You can view all the experiment results in the [./docs directory](https://github.com/dnlcrl/PyResNet/tree/master/docs). 
Main results are shown below: 109 | 110 | ### [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) 111 | 112 | best error: 9.59 % (accuracy: 0.9041) with a 20-layer residual network (n=3): 113 | 114 | [![CIFAR-10 results](https://github.com/dnlcrl/PyResNet/blob/master/docs/imgs/cifar.png)](https://github.com/dnlcrl/PyResNet/blob/master/docs/CIFAR-10%20Experiments.ipynb) 115 | 116 | [CIFAR-10 Results - iPython notebook](https://github.com/dnlcrl/PyResNet/blob/master/docs/CIFAR-10%20Experiments.ipynb) 117 | 118 | ### [MNIST](http://yann.lecun.com/exdb/mnist/) 119 | 120 | best error: 0.36 % (accuracy: 0.9964) with a 32-layer residual network (n=5): 121 | 122 | [![MNIST results](https://github.com/dnlcrl/PyResNet/blob/master/docs/imgs/mnistres.png)](https://github.com/dnlcrl/PyResNet/blob/master/docs/MNIST%20Experiments.ipynb) 123 | 124 | [MNIST Results - iPython notebook](https://github.com/dnlcrl/PyResNet/blob/master/docs/MNIST%20Experiments.ipynb) 125 | 126 | 127 | ### [SFDDD](https://www.kaggle.com/c/state-farm-distracted-driver-detection) 128 | 129 | best error: 0.25 % (accuracy: 0.9975) on a subset (1000 samples) of the train data (~21k images) with a 44-layer residual network (n=7), resizing the images to 64x48, randomly cropping 32x32 images for training and cropping a 32x32 image from the center of the original images for testing. Unfortunately, I got more than 2% error on Kaggle's test set (~80k images). 
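The depths quoted in the results above (20, 32 and 44 layers) follow from the (6*n)+2 rule stated earlier in this README. As a quick sanity check, the sketch below recomputes them and traces the per-stage output shapes from the architecture diagram. `depth` and `stage_shapes` are hypothetical helper names used only for illustration; they are not part of res_net.py, and the sketch only mirrors the README's description rather than the model-building code.

```python
def depth(n):
    # Layer count as stated in the README: (6 * n) + 2.
    return 6 * n + 2


def stage_shapes(num_filters=16, height=32, width=32):
    """Per-image output shape (C, H, W) after each of the three stages.

    Assumes, as in the README diagram, that each new stage doubles the
    number of filters and halves the spatial size (stride 2).
    """
    shapes = []
    c, h, w = num_filters, height, width
    for stage in range(3):
        if stage > 0:
            c, h, w = 2 * c, h // 2, w // 2
        shapes.append((c, h, w))
    return shapes


if __name__ == '__main__':
    for n in (3, 5, 7):
        print('n=%d -> %d layers' % (n, depth(n)))
    print(stage_shapes())
```

With the default arguments, the three stage shapes correspond to the `out_shape` annotations in the diagram (without the leading batch dimension N).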
130 | 131 | [![SFDDD results](https://github.com/dnlcrl/deep-residual-networks-pyfunt/blob/master/docs/imgs/SFDDD.png)](https://github.com/dnlcrl/PyResNet/blob/master/docs/SFDDD%20Experiments.ipynb) 132 | 133 | [SFDDD Results - iPython notebook](https://github.com/dnlcrl/PyResNet/blob/master/docs/SFDDD%20Experiments.ipynb) 134 | 135 | 144 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "nnet", 3 | "res_net" 4 | ] 5 | -------------------------------------------------------------------------------- /docs/imgs/SFDDD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dnlcrl/deep-residual-networks-pyfunt/801ce775a907ff85c5c8844cd927255ddc93c4ac/docs/imgs/SFDDD.png -------------------------------------------------------------------------------- /docs/imgs/cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dnlcrl/deep-residual-networks-pyfunt/801ce775a907ff85c5c8844cd927255ddc93c4ac/docs/imgs/cifar.png -------------------------------------------------------------------------------- /docs/imgs/mnistres.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dnlcrl/deep-residual-networks-pyfunt/801ce775a907ff85c5c8844cd927255ddc93c4ac/docs/imgs/mnistres.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy >= 1.7.1 2 | pydatset 3 | git+git://github.com/dnlcrl/PyDatSet.git 4 | pyfunt >= 0.1b 5 | git+git://github.com/dnlcrl/PyFunt.git 6 | -------------------------------------------------------------------------------- /res_net.py: 
-------------------------------------------------------------------------------- 1 | from pyfunt import (SpatialConvolution, SpatialBatchNormalization, 2 | SpatialAveragePooling, Sequential, ReLU, Linear, 3 | Reshape, LogSoftMax, Padding, Identity, ConcatTable, 4 | CAddTable) 5 | 6 | 7 | def residual_layer(n_channels, n_out_channels=None, stride=1): 8 | n_out_channels = n_out_channels or n_channels 9 | 10 | convs = Sequential() 11 | add = convs.add 12 | add(SpatialConvolution( 13 | n_channels, n_out_channels, 3, 3, stride, stride, 1, 1)) 14 | add(SpatialBatchNormalization(n_out_channels)) 15 | add(SpatialConvolution(n_out_channels, n_out_channels, 3, 3, 1, 1, 1, 1)) 16 | add(SpatialBatchNormalization(n_out_channels)) 17 | 18 | if stride > 1: 19 | shortcut = Sequential() 20 | shortcut.add(SpatialAveragePooling(2, 2, stride, stride)) 21 | shortcut.add(Padding(1, (n_out_channels - n_channels)/2, 3)) 22 | else: 23 | shortcut = Identity() 24 | 25 | res = Sequential() 26 | res.add(ConcatTable().add(convs).add(shortcut)).add(CAddTable()) 27 | # https://github.com/szagoruyko/wide-residual-networks/blob/master/models/resnet-pre-act.lua 28 | 29 | res.add(ReLU(True)) 30 | 31 | return res 32 | 33 | 34 | def resnet(n_size, num_starting_filters, reg): 35 | ''' 36 | Implementation of ["Deep Residual Learning for Image Recognition",Kaiming \ 37 | He, Xiangyu Zhang, Shaoqing Ren, Jian Sun - http://arxiv.org/abs/1512.03385 38 | 39 | Inspired by https://github.com/gcr/torch-residual-networks 40 | 41 | This network should model a similiar behaviour of gcr's implementation. 42 | Check https://github.com/gcr/torch-residual-networks for more infos about \ 43 | the structure. 44 | 45 | The network operates on minibatches of data that have shape (N, C, H, W) 46 | consisting of N images, each with height H and width W and with C input 47 | channels. 
48 | 49 | The network has, like in the reference paper (except for the final optional 50 | affine layers), (6*n)+2 layers, composed as below: 51 | 52 | (image_dim: 3, 32, 32; F=16) 53 | (input_dim: N, *image_dim) 54 | INPUT 55 | | 56 | v 57 | +-------------------+ 58 | |conv[F, *image_dim]| (out_shape: N, 16, 32, 32) 59 | +-------------------+ 60 | | 61 | v 62 | +-------------------------+ 63 | |n * res_block[F, F, 3, 3]| (out_shape: N, 16, 32, 32) 64 | +-------------------------+ 65 | | 66 | v 67 | +-------------------------+ 68 | |res_block[2*F, F, 3, 3] | (out_shape: N, 32, 16, 16) 69 | +-------------------------+ 70 | | 71 | v 72 | +---------------------------------+ 73 | |(n-1) * res_block[2*F, 2*F, 3, 3]| (out_shape: N, 32, 16, 16) 74 | +---------------------------------+ 75 | | 76 | v 77 | +-------------------------+ 78 | |res_block[4*F, 2*F, 3, 3]| (out_shape: N, 64, 8, 8) 79 | +-------------------------+ 80 | | 81 | v 82 | +---------------------------------+ 83 | |(n-1) * res_block[4*F, 4*F, 3, 3]| (out_shape: N, 64, 8, 8) 84 | +---------------------------------+ 85 | | 86 | v 87 | +-------------+ 88 | |pool[1, 8, 8]| (out_shape: N, 64, 1, 1) 89 | +-------------+ 90 | | 91 | v 92 | +- - - - - - - - -+ 93 | |(opt) m * affine | (out_shape: N, 64, 1, 1) 94 | +- - - - - - - - -+ 95 | | 96 | v 97 | +-------+ 98 | |softmax| (out_shape: N, num_classes) 99 | +-------+ 100 | | 101 | v 102 | OUTPUT 103 | 104 | Every convolution layer has a pad=1 and stride=1, except for the dimension 105 | increasing layers, which have a stride of 2 to maintain the computational 106 | complexity. 107 | Optionally, m affine layers can be placed immediately 108 | before the softmax layer by setting the hidden_dims parameter, which should 109 | be a list of integers representing the number of neurons for each affine 110 | layer. 111 | 112 | Each residual block is composed as below: 113 | 114 | Input 115 | | 116 | ,-------+-----. 
117 | Downsampling 3x3 convolution+dimensionality reduction 118 | | | 119 | v v 120 | Zero-padding 3x3 convolution 121 | | | 122 | `-----( Add )---' 123 | | 124 | Output 125 | 126 | After every layer, a batch normalization with momentum .1 is applied. 127 | 128 | Weight initialization (check also layers/init.py and layers/README.md): 129 | - Initialize the weights and biases for the affine layers in the same 130 | way as torch's default mode by calling _init_affine_wb that returns a 131 | tuple (w, b). 132 | - Initialize the weights for the conv layers in the same 133 | way as torch's default mode by calling init_conv_w. 134 | - Initialize the weights for the conv layers in the same 135 | way as Kaiming's mode by calling init_conv_w_kaiming 136 | (http://arxiv.org/abs/1502.01852 and 137 | http://andyljones.tumblr.com/post/110998971763/an-explanation-of-xavier-\ 138 | initialization) 139 | - Initialize batch normalization layer's weights like torch's default by 140 | calling init_bn_w 141 | - Initialize batch normalization layer's weights like gcr's first resblock\ 142 | 's bn (https://github.com/gcr/torch-residual-networks/blob/master/residual\ 143 | -layers.lua#L57-L59) by calling init_bn_w_gcr. 144 | 145 | num_filters=[16, 16, 32, 32, 64, 64], 146 | Initialize a new network. 147 | 148 | Inputs: 149 | - input_dim: Tuple (C, H, W) giving size of input data. 150 | - num_starting_filters: Number of filters for the first convolution 151 | layer. 152 | - n_size: nSize for the residual network like in the reference paper 153 | - hidden_dims: Optional list of the number of units to use in the 154 | fully-connected hidden layers between the final pool and the softmax 155 | layer. 156 | - num_classes: Number of scores to produce from the final affine layer. 157 | - reg: Scalar giving L2 regularization strength 158 | - dtype: numpy datatype to use for computation. 
159 | ''' 160 | 161 | nfs = num_starting_filters 162 | model = Sequential() 163 | add = model.add 164 | add(SpatialConvolution(3, nfs, 3, 3, 1, 1, 1, 1)) 165 | add(SpatialBatchNormalization(nfs)) 166 | add(ReLU()) 167 | 168 | for i in xrange(1, n_size): 169 | add(residual_layer(nfs)) 170 | add(residual_layer(nfs, 2*nfs, 2)) 171 | 172 | for i in xrange(1, n_size-1): 173 | add(residual_layer(2*nfs)) 174 | add(residual_layer(2*nfs, 4*nfs, 2)) 175 | 176 | for i in xrange(1, n_size-1): 177 | add(residual_layer(4*nfs)) 178 | 179 | add(SpatialAveragePooling(8, 8)) 180 | add(Reshape(nfs*4)) 181 | add(Linear(nfs*4, 10)) 182 | add(LogSoftMax()) 183 | return model 184 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import uuid 5 | import numpy as np 6 | # import matplotlib.pyplot as plt 7 | from pydatset.cifar10 import get_CIFAR10_data 8 | from pydatset.data_augmentation import (random_flips, 9 | random_crops) 10 | from res_net import resnet 11 | from pyfunt.solver import Solver as Solver 12 | 13 | import inspect 14 | import argparse 15 | 16 | from pyfunt.class_nll_criterion import ClassNLLCriterion 17 | 18 | np.seterr(all='raise') 19 | 20 | np.random.seed(0) 21 | 22 | DATA_PATH = '../CIFAR_DATASET_PATH' 23 | 24 | path_set = False 25 | while not path_set: 26 | try: 27 | with open(DATA_PATH) as f: 28 | DATASET_PATH = f.read() 29 | path_set = True 30 | except: 31 | data_path = raw_input('Enter the path for the CIFAR10 dataset: ') 32 | with open(DATA_PATH, "w") as f: 33 | f.write(data_path) 34 | 35 | 36 | EXPERIMENT_PATH = '../Experiments/' + str(uuid.uuid4())[-10:] 37 | 38 | # residual network constants 39 | NSIZE = 3 40 | N_STARTING_FILTERS = 16 41 | 42 | # solver constants 43 | NUM_PROCESSES = 4 44 | 45 | NUM_TRAIN = 50000 46 | NUM_TEST = 10000 47 | 48 | WEIGHT_DEACY = 1e-4 49 | 
REGULARIZATION = 0 50 | LEARNING_RATE = .1 51 | MOMENTUM = .99 52 | NUM_EPOCHS = 160 53 | BATCH_SIZE = 64 54 | CHECKPOINT_EVERY = 20 55 | 56 | XH, XW = 32, 32 57 | 58 | args = argparse.Namespace() 59 | 60 | 61 | def parse_args(): 62 | """ 63 | Parse the options for running the Residual Network on CIFAR-10. 64 | """ 65 | desc = 'Train a Residual Network on CIFAR-10.' 66 | parser = argparse.ArgumentParser(description=desc, 67 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 68 | add = parser.add_argument 69 | add('--dataset_path', 70 | metavar='DIRECTORY', 71 | default=DATASET_PATH, 72 | type=str, 73 | help='directory containing the CIFAR-10 dataset') 74 | add('--experiment_path', 75 | metavar='DIRECTORY', 76 | default=EXPERIMENT_PATH, 77 | type=str, 78 | help='directory where results will be saved') 79 | add('-load', '--load_checkpoint', 80 | metavar='DIRECTORY', 81 | default='', 82 | type=str, 83 | help='load checkpoint from load_checkpoint') 84 | add('--n_size', 85 | metavar='INT', 86 | default=NSIZE, 87 | type=int, 88 | help='Network will have (6*n)+2 conv layers') 89 | add('--n_starting_filters', 90 | metavar='INT', 91 | default=N_STARTING_FILTERS, 92 | type=int, 93 | help='Number of filters in the first convolution layer') 94 | add('--n_processes', '-np', 95 | metavar='INT', 96 | default=NUM_PROCESSES, 97 | type=int, 98 | help='Number of processes for each step') 99 | add('--n_train', 100 | metavar='INT', 101 | default=NUM_TRAIN, 102 | type=int, 103 | help='Number of total images to select for training') 104 | add('--n_test', 105 | metavar='INT', 106 | default=NUM_TEST, 107 | type=int, 108 | help='Number of total images to select for validation') 109 | add('-wd', '--weight_decay', 110 | metavar='FLOAT', 111 | default=WEIGHT_DEACY, 112 | type=float, 113 | help='Weight decay for sgd_th') 114 | add('-reg', '--network_regularization', 115 | metavar='FLOAT', 116 | default=REGULARIZATION, 117 | type=float, 118 | help='L2 regularization term for the network') 119 
| add('-lr', '--learning_rate', 120 | metavar='FLOAT', 121 | default=LEARNING_RATE, 122 | type=float, 123 | help='Learning rate to use with sgd_th') 124 | add('-mom', '--momentum', 125 | metavar='FLOAT', 126 | default=MOMENTUM, 127 | type=float, 128 | help='Nesterov momentum use with sgd_th') 129 | add('--n_epochs', '-nep', 130 | metavar='INT', 131 | default=NUM_EPOCHS, 132 | type=int, 133 | help='Number of epochs for training') 134 | add('--batch_size', '-bs', 135 | metavar='INT', 136 | default=BATCH_SIZE, 137 | type=int, 138 | help='Number of images for each iteration') 139 | add('--checkpoint_every', '-cp', 140 | metavar='INT', 141 | default=CHECKPOINT_EVERY, 142 | type=int, 143 | help='Number of epochs between each checkpoint') 144 | parser.parse_args(namespace=args) 145 | assert not (args.network_regularization and args.weight_decay) 146 | 147 | 148 | def data_augm(batch): 149 | p = 2 150 | h, w = XH, XW 151 | 152 | # batch = random_tint(batch) 153 | # batch = random_contrast(batch) 154 | batch = random_flips(batch) 155 | # batch = random_rotate(batch, 10) 156 | batch = random_crops(batch, (h, w), pad=p) 157 | return batch 158 | 159 | 160 | def custom_update_decay(epoch): 161 | if epoch in (80, 120): 162 | return 0.1 163 | return 1 164 | 165 | 166 | def print_infos(solver): 167 | print('Model: \n%s' % solver.model) 168 | 169 | print('Solver: \n%s' % solver) 170 | 171 | print('Data Augmentation Function: \n') 172 | print(''.join(['\t' + i for i in inspect.getsourcelines(data_augm)[0]])) 173 | print('Custom Weight Decay Update Rule: \n') 174 | print(''.join(['\t' + i for i in inspect.getsourcelines(custom_update_decay)[0]])) 175 | 176 | 177 | def main(): 178 | parse_args() 179 | 180 | data = get_CIFAR10_data(args.dataset_path, 181 | num_training=args.n_train, num_validation=0, num_test=args.n_test) 182 | 183 | data = { 184 | 'X_train': data['X_train'], 185 | 'y_train': data['y_train'], 186 | 'X_val': data['X_test'], 187 | 'y_val': data['y_test'], 188 | } 189 | 
190 | exp_path = args.experiment_path 191 | nf = args.n_starting_filters 192 | reg = args.network_regularization 193 | 194 | model = resnet(n_size=args.n_size, 195 | num_starting_filters=nf, 196 | reg=reg) 197 | 198 | wd = args.weight_decay 199 | lr = args.learning_rate 200 | mom = args.momentum 201 | 202 | optim_config = {'learning_rate': lr, 'nesterov': True, 203 | 'momentum': mom, 'weight_decay': wd} 204 | 205 | epochs = args.n_epochs 206 | bs = args.batch_size 207 | num_p = args.n_processes 208 | cp = args.checkpoint_every 209 | criterion = ClassNLLCriterion() 210 | solver = Solver(model, data, args.load_checkpoint, 211 | criterion=criterion, 212 | num_epochs=epochs, batch_size=bs, # 20 213 | update_rule='sgd_th', 214 | optim_config=optim_config, 215 | custom_update_ld=custom_update_decay, 216 | batch_augment_func=data_augm, 217 | checkpoint_every=cp, 218 | num_processes=num_p) 219 | 220 | print_infos(solver) 221 | solver.train() 222 | 223 | solver.export_model(exp_path) 224 | solver.export_histories(exp_path) 225 | 226 | print('finish') 227 | 228 | 229 | if __name__ == '__main__': 230 | main() 231 | --------------------------------------------------------------------------------
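As a rough illustration of the step schedule implied by `custom_update_decay` in train.py, the sketch below computes the learning rate that would be in effect at a given epoch. It assumes that the solver multiplies its running learning rate by the returned factor once per epoch, and that epochs are counted from 1; both are assumptions about PyFunt's `Solver`, which this repository does not show. `effective_learning_rate` is a hypothetical helper, not part of the code base.

```python
def custom_update_decay(epoch):
    # Same rule as in train.py: shrink by a factor of 10 at epochs 80 and 120.
    if epoch in (80, 120):
        return 0.1
    return 1


def effective_learning_rate(base_lr, epoch):
    """Hypothetical helper: learning rate in effect during `epoch`.

    Assumes the running learning rate is multiplied by
    custom_update_decay(e) once for every epoch e in 1..epoch.
    """
    lr = base_lr
    for e in range(1, epoch + 1):
        lr *= custom_update_decay(e)
    return lr


if __name__ == '__main__':
    # Default LEARNING_RATE in train.py is .1 and NUM_EPOCHS is 160.
    for epoch in (1, 79, 80, 119, 120, 160):
        print('epoch %3d: lr = %g' % (epoch, effective_learning_rate(0.1, epoch)))
```

Under these assumptions the default run trains at roughly 0.1 until epoch 79, 0.01 from epoch 80, and 0.001 from epoch 120 to the end.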