├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── accuracy.jpg ├── googlenet.py ├── loss.jpg ├── main.py └── nutszebra_optimizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # trainer 92 | # trainer 93 | 94 | # model 95 | model 96 | model.dot 97 | 98 | # log 99 | log 100 | log.json 101 | 102 | # cifar10 103 | cifar10.pkl 104 | cifar-10-batches-py 105 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "trainer"] 2 | path = trainer 3 | url = https://github.com/nutszebra/trainer.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 ikki kishida 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # What's this 2 | Implementation of GoogLeNet in Chainer 3 | 4 | 5 | # Dependencies 6 | 7 | git clone https://github.com/nutszebra/googlenet.git 8 | cd googlenet 9 | git submodule init 10 | git submodule update 11 | 12 | # How to run 13 | python main.py -p ./ -g 0 14 | 15 | 16 | # Details about my implementation 17 | 18 | * Data augmentation 19 | Train: Pictures are randomly resized within the range [256, 512], then 224x224 patches are randomly extracted and normalized locally. Horizontal flipping is applied with probability 0.5. 20 | Test: Pictures are resized to 384x384 and normalized locally. A single-crop test is used to calculate total accuracy. 21 | 22 | * Auxiliary classifiers 23 | Not implemented 24 | 25 | * Learning rate 26 | As described in [[1]][Paper], the learning rate is multiplied by 0.96 every 8 epochs. The initial learning rate is not given in [[1]][Paper], so it is set to 0.0015, the value used in [[2]][Paper2]. 27 | 28 | * Weight decay 29 | Weight decay is not specified in [[1]][Paper], so based on [[2]][Paper2] and [[3]][Paper3] it is set to 2.0*10^-4. 30 | 31 | # Cifar10 result 32 | 33 | | network | depth | total accuracy (%) | 34 | |:---------------------|--------|-------------------:| 35 | | my implementation | 22 | 91.33 | 36 | 37 | ![loss](loss.jpg) 38 | ![total accuracy](accuracy.jpg) 39 | 40 | # References 41 | Going Deeper with Convolutions [[1]][Paper] 42 | Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift [[2]][Paper2] 43 | Rethinking the Inception Architecture for Computer Vision [[3]][Paper3] 44 | [paper]: https://arxiv.org/abs/1409.4842 "Paper" 45 | [paper2]: https://arxiv.org/abs/1502.03167 "Paper2" 46 | [paper3]: https://arxiv.org/abs/1512.00567 "Paper3" 47 | -------------------------------------------------------------------------------- /accuracy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nutszebra/googlenet/e5e01b25085443a4b1f616c6ddd6d7ad24e2f363/accuracy.jpg -------------------------------------------------------------------------------- /googlenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import functools 3 | import chainer.links as L 4 | import chainer.functions as F 5 | from collections import defaultdict 6 | import nutszebra_chainer 7 | 8 | 9 | class Inception(nutszebra_chainer.Model): 10 | 11 | def __init__(self, in_channel, conv1x1=64, reduce3x3=96, conv3x3=128, reduce5x5=16, conv5x5=32, pool_proj=32): 12 | super(Inception, self).__init__() 13 | modules = [] 14 | modules.append(('conv1x1', L.Convolution2D(in_channel, conv1x1, 1, 1, 0))) 15 | modules.append(('reduce3x3', L.Convolution2D(in_channel, reduce3x3, 1, 1, 0))) 16 | modules.append(('conv3x3', L.Convolution2D(reduce3x3, conv3x3, 3, 1, 1))) 17 | modules.append(('reduce5x5', L.Convolution2D(in_channel, reduce5x5, 1, 1, 0))) 18 | modules.append(('conv5x5', L.Convolution2D(reduce5x5, conv5x5, 5, 1, 2))) 19 | modules.append(('pool_proj', L.Convolution2D(in_channel, pool_proj, 1, 1, 0))) 20 | # register layers 21 | [self.add_link(*link) for link in modules] 22 | self.modules = modules 23 | 24 | def weight_initialization(self): 25 | for name, link in self.modules: 26 | 
self[name].W.data = self.weight_relu_initialization(link) 27 | self[name].b.data = self.bias_initialization(link, constant=0) 28 | 29 | def __call__(self, x, train=False): 30 | a = F.relu(self.conv1x1(x)) 31 | b = F.relu(self.conv3x3(F.relu(self.reduce3x3(x)))) 32 | c = F.relu(self.conv5x5(F.relu(self.reduce5x5(x)))) 33 | d = F.relu(self.pool_proj(F.max_pooling_2d(x, ksize=(3, 3), stride=(1, 1), pad=(1, 1)))) 34 | return F.concat((a, b, c, d), axis=1) 35 | 36 | @staticmethod 37 | def _conv_count_parameters(conv): 38 | return functools.reduce(lambda a, b: a * b, conv.W.data.shape) 39 | 40 | def count_parameters(self): 41 | count = 0 42 | for name, link in self.modules: 43 | count += Inception._conv_count_parameters(link) 44 | return count 45 | 46 | 47 | class Googlenet(nutszebra_chainer.Model): 48 | 49 | def __init__(self, category_num): 50 | super(Googlenet, self).__init__() 51 | modules = [] 52 | modules += [('conv1', L.Convolution2D(3, 64, (7, 7), (2, 2), (3, 3)))] 53 | modules += [('conv2_1x1', L.Convolution2D(64, 64, (1, 1), (1, 1), (0, 0)))] 54 | modules += [('conv2_3x3', L.Convolution2D(64, 192, (3, 3), (1, 1), (1, 1)))] 55 | modules += [('inception3a', Inception(192, 64, 96, 128, 16, 32, 32))] 56 | modules += [('inception3b', Inception(256, 128, 128, 192, 32, 96, 64))] 57 | modules += [('inception4a', Inception(480, 192, 96, 208, 16, 48, 64))] 58 | modules += [('inception4b', Inception(512, 160, 112, 224, 24, 64, 64))] 59 | modules += [('inception4c', Inception(512, 128, 128, 256, 24, 64, 64))] 60 | modules += [('inception4d', Inception(512, 112, 144, 288, 32, 64, 64))] 61 | modules += [('inception4e', Inception(528, 256, 160, 320, 32, 128, 128))] 62 | modules += [('inception5a', Inception(832, 256, 160, 320, 32, 128, 128))] 63 | modules += [('inception5b', Inception(832, 384, 192, 384, 48, 128, 128))] 64 | modules += [('linear', L.Linear(1024, category_num))] 65 | # register layers 66 | [self.add_link(*link) for link in modules] 67 | self.modules = modules 68 | self.name = 'googlenet_{}'.format(category_num) 69 | 70 | def count_parameters(self): 71 | count = 0 72 | count += functools.reduce(lambda a, b: a * b, self.conv1.W.data.shape) 73 | count += functools.reduce(lambda a, b: a * b, self.conv2_1x1.W.data.shape) 74 | count += functools.reduce(lambda a, b: a * b, self.conv2_3x3.W.data.shape) 75 | count += self.inception3a.count_parameters() 76 | count += self.inception3b.count_parameters() 77 | count += self.inception4a.count_parameters() 78 | count += self.inception4b.count_parameters() 79 | count += self.inception4c.count_parameters() 80 | count += self.inception4d.count_parameters() 81 | count += self.inception4e.count_parameters() 82 | count += self.inception5a.count_parameters() 83 | count += self.inception5b.count_parameters() 84 | count += functools.reduce(lambda a, b: a * b, self.linear.W.data.shape) 85 | return count 86 | 87 | def weight_initialization(self): 88 | self.conv1.W.data = self.weight_relu_initialization(self.conv1) 89 | self.conv1.b.data = self.bias_initialization(self.conv1, constant=0) 90 | self.conv2_1x1.W.data = self.weight_relu_initialization(self.conv2_1x1) 91 | self.conv2_1x1.b.data = self.bias_initialization(self.conv2_1x1, constant=0) 92 | self.conv2_3x3.W.data = self.weight_relu_initialization(self.conv2_3x3) 93 | self.conv2_3x3.b.data = self.bias_initialization(self.conv2_3x3, constant=0) 94 | self.inception3a.weight_initialization() 95 | self.inception3b.weight_initialization() 96 | self.inception4a.weight_initialization() 97 | 
self.inception4b.weight_initialization() 98 | self.inception4c.weight_initialization() 99 | self.inception4d.weight_initialization() 100 | self.inception4e.weight_initialization() 101 | self.inception5a.weight_initialization() 102 | self.inception5b.weight_initialization() 103 | self.linear.W.data = self.weight_relu_initialization(self.linear) 104 | self.linear.b.data = self.bias_initialization(self.linear, constant=0) 105 | 106 | def __call__(self, x, train=True): 107 | h = F.relu(self.conv1(x)) 108 | h = F.max_pooling_2d(h, ksize=(3, 3), stride=(2, 2), pad=(1, 1)) 109 | h = F.relu(self.conv2_1x1(h)) 110 | h = F.relu(self.conv2_3x3(h)) 111 | h = F.max_pooling_2d(h, ksize=(3, 3), stride=(2, 2), pad=(1, 1)) 112 | h = self.inception3a(h) 113 | h = self.inception3b(h) 114 | h = F.max_pooling_2d(h, ksize=(3, 3), stride=(2, 2), pad=(1, 1)) 115 | h = self.inception4a(h) 116 | h = self.inception4b(h) 117 | h = self.inception4c(h) 118 | h = self.inception4d(h) 119 | h = self.inception4e(h) 120 | h = F.max_pooling_2d(h, ksize=(3, 3), stride=(2, 2), pad=(1, 1)) 121 | h = self.inception5a(h) 122 | h = F.relu(self.inception5b(h)) 123 | num, categories, y, x = h.data.shape 124 | # global average pooling 125 | h = F.reshape(F.average_pooling_2d(h, (y, x)), (num, categories)) 126 | h = F.dropout(h, ratio=0.4, train=train) 127 | h = self.linear(h) 128 | return h 129 | 130 | def calc_loss(self, y, t): 131 | loss = F.softmax_cross_entropy(y, t) 132 | return loss 133 | 134 | def accuracy(self, y, t, xp=np): 135 | y.to_cpu() 136 | t.to_cpu() 137 | indices = np.where((t.data == np.argmax(y.data, axis=1)) == True)[0] 138 | accuracy = defaultdict(int) 139 | for i in indices: 140 | accuracy[t.data[i]] += 1 141 | indices = np.where((t.data == np.argmax(y.data, axis=1)) == False)[0] 142 | false_accuracy = defaultdict(int) 143 | false_y = np.argmax(y.data, axis=1) 144 | for i in indices: 145 | false_accuracy[(t.data[i], false_y[i])] += 1 146 | return accuracy, false_accuracy 147 | -------------------------------------------------------------------------------- /loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nutszebra/googlenet/e5e01b25085443a4b1f616c6ddd6d7ad24e2f363/loss.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./trainer') 3 | import argparse 4 | import googlenet 5 | import nutszebra_data_augmentation 6 | import nutszebra_cifar10 7 | import nutszebra_optimizer 8 | 9 | if __name__ == '__main__': 10 | 11 | parser = argparse.ArgumentParser(description='cifar10') 12 | parser.add_argument('--load_model', '-m', 13 | default=None, 14 | help='trained model') 15 | parser.add_argument('--load_optimizer', '-o', 16 | default=None, 17 | help='optimizer for trained model') 18 | parser.add_argument('--load_log', '-l', 19 | default=None, 20 | help='log of trained model') 21 | parser.add_argument('--save_path', '-p', 22 | default='./', 23 | help='model and optimizer will be saved every epoch') 24 | parser.add_argument('--epoch', '-e', type=int, 25 | default=200, 26 | help='maximum epoch') 27 | parser.add_argument('--batch', '-b', type=int, 28 | default=32, 29 | help='mini batch number') 30 | parser.add_argument('--gpu', '-g', type=int, 31 | default=-1, 32 | help='-1 means cpu mode, put gpu id here') 33 | parser.add_argument('--start_epoch', '-s', type=int, 34 | default=1, 35 | 
help='start from this epoch') 36 | parser.add_argument('--train_batch_divide', '-trb', type=int, 37 | default=4, 38 | help='divide batch number by this') 39 | parser.add_argument('--test_batch_divide', '-teb', type=int, 40 | default=4, 41 | help='divide batch number by this') 42 | parser.add_argument('--lr', '-lr', type=float, 43 | default=0.0015, 44 | help='learning rate') 45 | 46 | args = parser.parse_args().__dict__ 47 | lr = args.pop('lr') 48 | 49 | print('generating model') 50 | model = googlenet.Googlenet(10) 51 | print('Done') 52 | optimizer = nutszebra_optimizer.OptimizerGooglenet(model, lr=lr) 53 | args['model'] = model 54 | args['optimizer'] = optimizer 55 | args['da'] = nutszebra_data_augmentation.DataAugmentationCifar10NormalizeBigger 56 | main = nutszebra_cifar10.TrainCifar10(**args) 57 | main.run() 58 | -------------------------------------------------------------------------------- /nutszebra_optimizer.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | from chainer import optimizers 3 | import nutszebra_basic_print 4 | 5 | 6 | class Optimizer(object): 7 | 8 | def __init__(self, model=None): 9 | self.model = model 10 | self.optimizer = None 11 | 12 | def __call__(self, i): 13 | pass 14 | 15 | def update(self): 16 | self.optimizer.update() 17 | 18 | 19 | class OptimizerResnet(Optimizer): 20 | 21 | def __init__(self, model=None, schedule=(int(32000. / (50000. / 128)), int(48000. / (50000. / 128))), lr=0.1, momentum=0.9, weight_decay=1.0e-4, warm_up_lr=0.01): 22 | super(OptimizerResnet, self).__init__(model) 23 | optimizer = optimizers.MomentumSGD(warm_up_lr, momentum) 24 | weight_decay = chainer.optimizer.WeightDecay(weight_decay) 25 | optimizer.setup(self.model) 26 | optimizer.add_hook(weight_decay) 27 | self.optimizer = optimizer 28 | self.schedule = schedule 29 | self.lr = lr 30 | self.warmup_lr = warm_up_lr 31 | self.momentum = momentum 32 | self.weight_decay = weight_decay 33 | 34 | def __call__(self, i): 35 | if i == 1: 36 | lr = self.lr 37 | print('finished warming up') 38 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr)) 39 | self.optimizer.lr = lr 40 | if i in self.schedule: 41 | lr = self.optimizer.lr / 10 42 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr)) 43 | self.optimizer.lr = lr 44 | 45 | 46 | class OptimizerDense(Optimizer): 47 | 48 | def __init__(self, model=None, schedule=(150, 225), lr=0.1, momentum=0.9, weight_decay=1.0e-4): 49 | super(OptimizerDense, self).__init__(model) 50 | optimizer = optimizers.MomentumSGD(lr, momentum) 51 | weight_decay = chainer.optimizer.WeightDecay(weight_decay) 52 | optimizer.setup(self.model) 53 | optimizer.add_hook(weight_decay) 54 | self.optimizer = optimizer 55 | self.schedule = schedule 56 | self.lr = lr 57 | self.momentum = momentum 58 | self.weight_decay = weight_decay 59 | 60 | def __call__(self, i): 61 | if i in self.schedule: 62 | lr = self.optimizer.lr / 10 63 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr)) 64 | self.optimizer.lr = lr 65 | 66 | 67 | class OptimizerWideRes(Optimizer): 68 | 69 | def __init__(self, model=None, schedule=(60, 120, 160), lr=0.1, momentum=0.9, weight_decay=5.0e-4): 70 | super(OptimizerWideRes, self).__init__(model) 71 | optimizer = optimizers.MomentumSGD(lr, momentum) 72 | weight_decay = chainer.optimizer.WeightDecay(weight_decay) 73 | optimizer.setup(self.model) 74 | optimizer.add_hook(weight_decay) 75 | self.optimizer = optimizer 76 | self.schedule = schedule 77 | self.lr = lr 78 | 
self.momentum = momentum 79 | self.weight_decay = weight_decay 80 | 81 | def __call__(self, i): 82 | if i in self.schedule: 83 | lr = self.optimizer.lr * 0.2 84 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr)) 85 | self.optimizer.lr = lr 86 | 87 | 88 | class OptimizerSwapout(Optimizer): 89 | 90 | def __init__(self, model=None, schedule=(196, 224), lr=0.1, momentum=0.9, weight_decay=1.0e-4): 91 | super(OptimizerSwapout, self).__init__(model) 92 | optimizer = optimizers.MomentumSGD(lr, momentum) 93 | weight_decay = chainer.optimizer.WeightDecay(weight_decay) 94 | optimizer.setup(self.model) 95 | optimizer.add_hook(weight_decay) 96 | self.optimizer = optimizer 97 | self.schedule = schedule 98 | self.lr = lr 99 | self.momentum = momentum 100 | self.weight_decay = weight_decay 101 | 102 | def __call__(self, i): 103 | if i in self.schedule: 104 | lr = self.optimizer.lr / 10 105 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr)) 106 | self.optimizer.lr = lr 107 | 108 | 109 | class OptimizerXception(Optimizer): 110 | 111 | def __init__(self, model=None, lr=0.045, momentum=0.9, weight_decay=1.0e-5, period=2): 112 | super(OptimizerXception, self).__init__(model) 113 | optimizer = optimizers.MomentumSGD(lr, momentum) 114 | weight_decay = chainer.optimizer.WeightDecay(weight_decay) 115 | optimizer.setup(self.model) 116 | optimizer.add_hook(weight_decay) 117 | self.optimizer = optimizer 118 | self.lr = lr 119 | self.momentum = momentum 120 | self.weight_decay = weight_decay 121 | self.period = int(period) 122 | 123 | def __call__(self, i): 124 | if i % self.period == 0: 125 | lr = self.optimizer.lr * 0.94 126 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr)) 127 | self.optimizer.lr = lr 128 | 129 | 130 | class OptimizerVGG(Optimizer): 131 | 132 | def __init__(self, model=None, lr=0.01, momentum=0.9, weight_decay=5.0e-4): 133 | super(OptimizerVGG, self).__init__(model) 134 | optimizer = optimizers.MomentumSGD(lr, momentum) 135 | weight_decay = chainer.optimizer.WeightDecay(weight_decay) 136 | optimizer.setup(self.model) 137 | optimizer.add_hook(weight_decay) 138 | self.optimizer = optimizer 139 | self.lr = lr 140 | self.momentum = momentum 141 | self.weight_decay = weight_decay 142 | 143 | def __call__(self, i): 144 | # 150 epoch means (0.94 ** 75) * lr 145 | # if lr is 0.01, then (0.94 ** 75) * 0.01 is 0.0001 at the end 146 | if i % 2 == 0: 147 | lr = self.optimizer.lr * 0.94 148 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr)) 149 | self.optimizer.lr = lr 150 | 151 | 152 | class OptimizerGooglenet(Optimizer): 153 | 154 | def __init__(self, model=None, lr=0.0015, momentum=0.9, weight_decay=2.0e-4): 155 | super(OptimizerGooglenet, self).__init__(model) 156 | optimizer = optimizers.MomentumSGD(lr, momentum) 157 | weight_decay = chainer.optimizer.WeightDecay(weight_decay) 158 | optimizer.setup(self.model) 159 | optimizer.add_hook(weight_decay) 160 | self.optimizer = optimizer 161 | self.lr = lr 162 | self.momentum = momentum 163 | self.weight_decay = weight_decay 164 | 165 | def __call__(self, i): 166 | if i % 8 == 0: 167 | lr = self.optimizer.lr * 0.96 168 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr)) 169 | self.optimizer.lr = lr 170 | 171 | 172 | class OptimizerNetworkInNetwork(Optimizer): 173 | 174 | def __init__(self, model=None, lr=0.1, momentum=0.9, weight_decay=1.0e-4, schedule=(int(1.0e5 / (50000. 
/ 128)), )): 175 | super(OptimizerNetworkInNetwork, self).__init__(model) 176 | optimizer = optimizers.MomentumSGD(lr, momentum) 177 | weight_decay = chainer.optimizer.WeightDecay(weight_decay) 178 | optimizer.setup(self.model) 179 | optimizer.add_hook(weight_decay) 180 | self.optimizer = optimizer 181 | self.lr = lr 182 | self.momentum = momentum 183 | self.weight_decay = weight_decay 184 | self.schedule = schedule 185 | 186 | def __call__(self, i): 187 | if i in self.schedule: 188 | lr = self.optimizer.lr / 10 189 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr)) 190 | self.optimizer.lr = lr 191 | --------------------------------------------------------------------------------
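A quick standalone check of the learning-rate schedule implemented by OptimizerGooglenet above and described in the README: the learning rate is multiplied by 0.96 at every epoch divisible by 8, so with main.py's defaults (initial lr 0.0015, 200 epochs) it decays 25 times. This sketch is not a file in the repository, and the helper name is made up for illustration.

    # standalone sketch (not a repository file): effective lr under OptimizerGooglenet's schedule
    def effective_lr(initial_lr=0.0015, epochs=200, decay=0.96, period=8):
        lr = initial_lr
        for epoch in range(1, epochs + 1):
            if epoch % period == 0:
                lr *= decay
        return lr

    print(effective_lr())  # 0.0015 * 0.96 ** 25 ~= 0.00054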