├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── accuracy.jpg
├── googlenet.py
├── loss.jpg
├── main.py
└── nutszebra_optimizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 | # trainer
92 | # trainer
93 |
94 | # model
95 | model
96 | model.dot
97 |
98 | # log
99 | log
100 | log.json
101 |
102 | # cifar10
103 | cifar10.pkl
104 | cifar-10-batches-py
105 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "trainer"]
2 | path = trainer
3 | url = https://github.com/nutszebra/trainer.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016 ikki kishida
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # What's this
2 | Implementation of GoogLeNet in Chainer
3 |
4 |
5 | # Dependencies: Chainer and the trainer submodule
6 |
7 | git clone https://github.com/nutszebra/googlenet.git
8 | cd googlenet
9 | git submodule init
10 | git submodule update
11 |
12 | # How to run
13 | python main.py -p ./ -g 0
14 |
15 |
16 | # Details about my implementation
17 |
18 | * Data augmentation  
19 | Train: images are randomly resized within the range [256, 512], then 224x224 patches are randomly cropped and normalized locally. Horizontal flipping is applied with probability 0.5 (a rough sketch of this pipeline is given at the end of this README).  
20 | Test: images are resized to 384x384 and normalized locally. A single crop per image is used to calculate total accuracy.
21 |
22 | * Auxiliary classifiers  
23 | Not implemented.
24 |
25 | * Learning rate  
26 | As described in [[1]][Paper], the learning rate is multiplied by 0.96 every 8 epochs. [[1]][Paper] does not state the initial learning rate, so it is set to 0.0015, the value used in [[2]][Paper2].
27 |
28 | * Weight decay  
29 | [[1]][Paper] does not state the weight decay either, so based on [[2]][Paper2] and [[3]][Paper3] it is guessed to be 2.0*10^-4.
30 |
31 | # Cifar10 result
32 |
33 | | network | depth | total accuracy (%) |
34 | |:---------------------|--------|-------------------:|
35 | | my implementation | 22 | 91.33 |
36 |
37 |
38 |
39 |
40 | # References
41 | Going Deeper with Convolutions [[1]][Paper]
42 | Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift [[2]][Paper2]
43 | Rethinking the Inception Architecture for Computer Vision [[3]][Paper3]
44 | [paper]: https://arxiv.org/abs/1409.4842 "Paper"
45 | [paper2]: https://arxiv.org/abs/1502.03167 "Paper2"
46 | [paper3]: https://arxiv.org/abs/1512.00567 "Paper3"
47 |
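
The train-time augmentation described above is implemented in the trainer submodule (`nutszebra_data_augmentation`), which is not part of this listing. The following is a minimal, hypothetical sketch of that pipeline as described in this README (random resize into [256, 512], random 224x224 crop, per-image normalization, horizontal flip with probability 0.5); the function name and helper choices are illustrative and are not the submodule's actual API.

```python
# Hypothetical sketch of the train-time augmentation described in the README.
# The real implementation lives in the trainer submodule and may differ.
import numpy as np
from PIL import Image


def augment_train(img):
    """img: HxWx3 uint8 array -> 3x224x224 float32 array (CHW for Chainer)."""
    # randomly resize so that the image size falls in [256, 512]
    size = np.random.randint(256, 513)
    img = np.asarray(Image.fromarray(img).resize((size, size)), dtype=np.float32)
    # extract a random 224x224 patch
    top = np.random.randint(0, size - 224 + 1)
    left = np.random.randint(0, size - 224 + 1)
    patch = img[top:top + 224, left:left + 224, :]
    # horizontal flip with probability 0.5
    if np.random.rand() < 0.5:
        patch = patch[:, ::-1, :]
    # local (per-image) normalization
    patch = (patch - patch.mean()) / (patch.std() + 1e-8)
    return patch.transpose(2, 0, 1)
```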
--------------------------------------------------------------------------------
/accuracy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nutszebra/googlenet/e5e01b25085443a4b1f616c6ddd6d7ad24e2f363/accuracy.jpg
--------------------------------------------------------------------------------
/googlenet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import functools
3 | import chainer.links as L
4 | import chainer.functions as F
5 | from collections import defaultdict
6 | import nutszebra_chainer
7 |
8 |
9 | class Inception(nutszebra_chainer.Model):
10 |
11 | def __init__(self, in_channel, conv1x1=64, reduce3x3=96, conv3x3=128, reduce5x5=16, conv5x5=32, pool_proj=32):
12 | super(Inception, self).__init__()
13 | modules = []
14 | modules.append(('conv1x1', L.Convolution2D(in_channel, conv1x1, 1, 1, 0)))
15 | modules.append(('reduce3x3', L.Convolution2D(in_channel, reduce3x3, 1, 1, 0)))
16 | modules.append(('conv3x3', L.Convolution2D(reduce3x3, conv3x3, 3, 1, 1)))
17 | modules.append(('reduce5x5', L.Convolution2D(in_channel, reduce5x5, 1, 1, 0)))
18 | modules.append(('conv5x5', L.Convolution2D(reduce5x5, conv5x5, 5, 1, 2)))
19 | modules.append(('pool_proj', L.Convolution2D(in_channel, pool_proj, 1, 1, 0)))
20 | # register layers
21 | [self.add_link(*link) for link in modules]
22 | self.modules = modules
23 |
24 | def weight_initialization(self):
25 | for name, link in self.modules:
26 | self[name].W.data = self.weight_relu_initialization(link)
27 | self[name].b.data = self.bias_initialization(link, constant=0)
28 |
29 | def __call__(self, x, train=False):
30 | a = F.relu(self.conv1x1(x))
31 | b = F.relu(self.conv3x3(F.relu(self.reduce3x3(x))))
32 | c = F.relu(self.conv5x5(F.relu(self.reduce5x5(x))))
33 | d = F.relu(self.pool_proj(F.max_pooling_2d(x, ksize=(3, 3), stride=(1, 1), pad=(1, 1))))
34 | return F.concat((a, b, c, d), axis=1)
35 |
36 | @staticmethod
37 | def _conv_count_parameters(conv):
38 | return functools.reduce(lambda a, b: a * b, conv.W.data.shape)
39 |
40 | def count_parameters(self):
41 | count = 0
42 | for name, link in self.modules:
43 | count += Inception._conv_count_parameters(link)
44 | return count
45 |
46 |
47 | class Googlenet(nutszebra_chainer.Model):
48 |
49 | def __init__(self, category_num):
50 | super(Googlenet, self).__init__()
51 | modules = []
52 | modules += [('conv1', L.Convolution2D(3, 64, (7, 7), (2, 2), (3, 3)))]
53 | modules += [('conv2_1x1', L.Convolution2D(64, 64, (1, 1), (1, 1), (0, 0)))]
54 | modules += [('conv2_3x3', L.Convolution2D(64, 192, (3, 3), (1, 1), (1, 1)))]
55 | modules += [('inception3a', Inception(192, 64, 96, 128, 16, 32, 32))]
56 | modules += [('inception3b', Inception(256, 128, 128, 192, 32, 96, 64))]
57 | modules += [('inception4a', Inception(480, 192, 96, 208, 16, 48, 64))]
58 | modules += [('inception4b', Inception(512, 160, 112, 224, 24, 64, 64))]
59 | modules += [('inception4c', Inception(512, 128, 128, 256, 24, 64, 64))]
60 | modules += [('inception4d', Inception(512, 112, 144, 288, 32, 64, 64))]
61 | modules += [('inception4e', Inception(528, 256, 160, 320, 32, 128, 128))]
62 | modules += [('inception5a', Inception(832, 256, 160, 320, 32, 128, 128))]
63 | modules += [('inception5b', Inception(832, 384, 192, 384, 48, 128, 128))]
64 | modules += [('linear', L.Linear(1024, category_num))]
65 | # register layers
66 | [self.add_link(*link) for link in modules]
67 | self.modules = modules
68 | self.name = 'googlenet_{}'.format(category_num)
69 |
70 | def count_parameters(self):
71 | count = 0
72 | count += functools.reduce(lambda a, b: a * b, self.conv1.W.data.shape)
73 | count += functools.reduce(lambda a, b: a * b, self.conv2_1x1.W.data.shape)
74 | count += functools.reduce(lambda a, b: a * b, self.conv2_3x3.W.data.shape)
75 | count += self.inception3a.count_parameters()
76 | count += self.inception3b.count_parameters()
77 | count += self.inception4a.count_parameters()
78 | count += self.inception4b.count_parameters()
79 | count += self.inception4c.count_parameters()
80 | count += self.inception4d.count_parameters()
81 | count += self.inception4e.count_parameters()
82 | count += self.inception5a.count_parameters()
83 | count += self.inception5b.count_parameters()
84 | count += functools.reduce(lambda a, b: a * b, self.linear.W.data.shape)
85 | return count
86 |
87 | def weight_initialization(self):
88 | self.conv1.W.data = self.weight_relu_initialization(self.conv1)
89 | self.conv1.b.data = self.bias_initialization(self.conv1, constant=0)
90 | self.conv2_1x1.W.data = self.weight_relu_initialization(self.conv2_1x1)
91 | self.conv2_1x1.b.data = self.bias_initialization(self.conv2_1x1, constant=0)
92 | self.conv2_3x3.W.data = self.weight_relu_initialization(self.conv2_3x3)
93 | self.conv2_3x3.b.data = self.bias_initialization(self.conv2_3x3, constant=0)
94 | self.inception3a.weight_initialization()
95 | self.inception3b.weight_initialization()
96 | self.inception4a.weight_initialization()
97 | self.inception4b.weight_initialization()
98 | self.inception4c.weight_initialization()
99 | self.inception4d.weight_initialization()
100 | self.inception4e.weight_initialization()
101 | self.inception5a.weight_initialization()
102 | self.inception5b.weight_initialization()
103 | self.linear.W.data = self.weight_relu_initialization(self.linear)
104 | self.linear.b.data = self.bias_initialization(self.linear, constant=0)
105 |
106 | def __call__(self, x, train=True):
107 | h = F.relu(self.conv1(x))
108 | h = F.max_pooling_2d(h, ksize=(3, 3), stride=(2, 2), pad=(1, 1))
109 | h = F.relu(self.conv2_1x1(h))
110 | h = F.relu(self.conv2_3x3(h))
111 | h = F.max_pooling_2d(h, ksize=(3, 3), stride=(2, 2), pad=(1, 1))
112 | h = self.inception3a(h)
113 | h = self.inception3b(h)
114 | h = F.max_pooling_2d(h, ksize=(3, 3), stride=(2, 2), pad=(1, 1))
115 | h = self.inception4a(h)
116 | h = self.inception4b(h)
117 | h = self.inception4c(h)
118 | h = self.inception4d(h)
119 | h = self.inception4e(h)
120 | h = F.max_pooling_2d(h, ksize=(3, 3), stride=(2, 2), pad=(1, 1))
121 | h = self.inception5a(h)
122 | h = F.relu(self.inception5b(h))
123 | num, categories, y, x = h.data.shape
124 | # global average pooling
125 | h = F.reshape(F.average_pooling_2d(h, (y, x)), (num, categories))
126 | h = F.dropout(h, ratio=0.4, train=train)
127 | h = self.linear(h)
128 | return h
129 |
130 | def calc_loss(self, y, t):
131 | loss = F.softmax_cross_entropy(y, t)
132 | return loss
133 |
134 | def accuracy(self, y, t, xp=np):
135 | y.to_cpu()
136 | t.to_cpu()
137 |         indices = np.where(t.data == np.argmax(y.data, axis=1))[0]
138 | accuracy = defaultdict(int)
139 | for i in indices:
140 | accuracy[t.data[i]] += 1
141 |         indices = np.where(t.data != np.argmax(y.data, axis=1))[0]
142 | false_accuracy = defaultdict(int)
143 | false_y = np.argmax(y.data, axis=1)
144 | for i in indices:
145 | false_accuracy[(t.data[i], false_y[i])] += 1
146 | return accuracy, false_accuracy
147 |
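
A minimal smoke test for the model above, assuming the trainer submodule has been checked out (it provides `nutszebra_chainer`, the base class of `Googlenet`) and using the Chainer v1-era API that this repository targets:

```python
# Sketch: instantiate the network, count its parameters, and run one forward pass.
import sys
sys.path.append('./trainer')  # nutszebra_chainer lives in the trainer submodule
import numpy as np
import chainer
import googlenet

model = googlenet.Googlenet(category_num=10)
model.weight_initialization()
print('parameters:', model.count_parameters())

# one dummy 224x224 RGB image; global average pooling keeps the input size flexible
x = chainer.Variable(np.zeros((1, 3, 224, 224), dtype=np.float32))
y = model(x, train=False)
print('output shape:', y.data.shape)  # expected: (1, 10)
```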
--------------------------------------------------------------------------------
/loss.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nutszebra/googlenet/e5e01b25085443a4b1f616c6ddd6d7ad24e2f363/loss.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./trainer')
3 | import argparse
4 | import googlenet
5 | import nutszebra_data_augmentation
6 | import nutszebra_cifar10
7 | import nutszebra_optimizer
8 |
9 | if __name__ == '__main__':
10 |
11 | parser = argparse.ArgumentParser(description='cifar10')
12 | parser.add_argument('--load_model', '-m',
13 | default=None,
14 | help='trained model')
15 | parser.add_argument('--load_optimizer', '-o',
16 | default=None,
17 | help='optimizer for trained model')
18 | parser.add_argument('--load_log', '-l',
19 | default=None,
20 |                         help='log of the trained model')
21 | parser.add_argument('--save_path', '-p',
22 | default='./',
23 |                         help='directory where the model and optimizer are saved every epoch')
24 | parser.add_argument('--epoch', '-e', type=int,
25 | default=200,
26 | help='maximum epoch')
27 | parser.add_argument('--batch', '-b', type=int,
28 | default=32,
29 | help='mini batch number')
30 | parser.add_argument('--gpu', '-g', type=int,
31 | default=-1,
32 | help='-1 means cpu mode, put gpu id here')
33 | parser.add_argument('--start_epoch', '-s', type=int,
34 | default=1,
35 | help='start from this epoch')
36 | parser.add_argument('--train_batch_divide', '-trb', type=int,
37 | default=4,
38 |                         help='divide the batch number by this')
39 | parser.add_argument('--test_batch_divide', '-teb', type=int,
40 | default=4,
41 |                         help='divide the batch number by this')
42 | parser.add_argument('--lr', '-lr', type=float,
43 | default=0.0015,
44 |                         help='learning rate')
45 |
46 | args = parser.parse_args().__dict__
47 | lr = args.pop('lr')
48 |
49 | print('generating model')
50 | model = googlenet.Googlenet(10)
51 | print('Done')
52 | optimizer = nutszebra_optimizer.OptimizerGooglenet(model, lr=lr)
53 | args['model'] = model
54 | args['optimizer'] = optimizer
55 | args['da'] = nutszebra_data_augmentation.DataAugmentationCifar10NormalizeBigger
56 | main = nutszebra_cifar10.TrainCifar10(**args)
57 | main.run()
58 |
--------------------------------------------------------------------------------
/nutszebra_optimizer.py:
--------------------------------------------------------------------------------
1 | import chainer
2 | from chainer import optimizers
3 | import nutszebra_basic_print
4 |
5 |
6 | class Optimizer(object):
7 |
8 | def __init__(self, model=None):
9 | self.model = model
10 | self.optimizer = None
11 |
12 | def __call__(self, i):
13 | pass
14 |
15 | def update(self):
16 | self.optimizer.update()
17 |
18 |
19 | class OptimizerResnet(Optimizer):
20 |
21 | def __init__(self, model=None, schedule=(int(32000. / (50000. / 128)), int(48000. / (50000. / 128))), lr=0.1, momentum=0.9, weight_decay=1.0e-4, warm_up_lr=0.01):
22 | super(OptimizerResnet, self).__init__(model)
23 | optimizer = optimizers.MomentumSGD(warm_up_lr, momentum)
24 | weight_decay = chainer.optimizer.WeightDecay(weight_decay)
25 | optimizer.setup(self.model)
26 | optimizer.add_hook(weight_decay)
27 | self.optimizer = optimizer
28 | self.schedule = schedule
29 | self.lr = lr
30 | self.warmup_lr = warm_up_lr
31 | self.momentum = momentum
32 | self.weight_decay = weight_decay
33 |
34 | def __call__(self, i):
35 | if i == 1:
36 | lr = self.lr
37 |             print('finished warming up')
38 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
39 | self.optimizer.lr = lr
40 | if i in self.schedule:
41 | lr = self.optimizer.lr / 10
42 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
43 | self.optimizer.lr = lr
44 |
45 |
46 | class OptimizerDense(Optimizer):
47 |
48 | def __init__(self, model=None, schedule=(150, 225), lr=0.1, momentum=0.9, weight_decay=1.0e-4):
49 | super(OptimizerDense, self).__init__(model)
50 | optimizer = optimizers.MomentumSGD(lr, momentum)
51 | weight_decay = chainer.optimizer.WeightDecay(weight_decay)
52 | optimizer.setup(self.model)
53 | optimizer.add_hook(weight_decay)
54 | self.optimizer = optimizer
55 | self.schedule = schedule
56 | self.lr = lr
57 | self.momentum = momentum
58 | self.weight_decay = weight_decay
59 |
60 | def __call__(self, i):
61 | if i in self.schedule:
62 | lr = self.optimizer.lr / 10
63 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
64 | self.optimizer.lr = lr
65 |
66 |
67 | class OptimizerWideRes(Optimizer):
68 |
69 | def __init__(self, model=None, schedule=(60, 120, 160), lr=0.1, momentum=0.9, weight_decay=5.0e-4):
70 | super(OptimizerWideRes, self).__init__(model)
71 | optimizer = optimizers.MomentumSGD(lr, momentum)
72 | weight_decay = chainer.optimizer.WeightDecay(weight_decay)
73 | optimizer.setup(self.model)
74 | optimizer.add_hook(weight_decay)
75 | self.optimizer = optimizer
76 | self.schedule = schedule
77 | self.lr = lr
78 | self.momentum = momentum
79 | self.weight_decay = weight_decay
80 |
81 | def __call__(self, i):
82 | if i in self.schedule:
83 | lr = self.optimizer.lr * 0.2
84 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
85 | self.optimizer.lr = lr
86 |
87 |
88 | class OptimizerSwapout(Optimizer):
89 |
90 | def __init__(self, model=None, schedule=(196, 224), lr=0.1, momentum=0.9, weight_decay=1.0e-4):
91 | super(OptimizerSwapout, self).__init__(model)
92 | optimizer = optimizers.MomentumSGD(lr, momentum)
93 | weight_decay = chainer.optimizer.WeightDecay(weight_decay)
94 | optimizer.setup(self.model)
95 | optimizer.add_hook(weight_decay)
96 | self.optimizer = optimizer
97 | self.schedule = schedule
98 | self.lr = lr
99 | self.momentum = momentum
100 | self.weight_decay = weight_decay
101 |
102 | def __call__(self, i):
103 | if i in self.schedule:
104 | lr = self.optimizer.lr / 10
105 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
106 | self.optimizer.lr = lr
107 |
108 |
109 | class OptimizerXception(Optimizer):
110 |
111 | def __init__(self, model=None, lr=0.045, momentum=0.9, weight_decay=1.0e-5, period=2):
112 | super(OptimizerXception, self).__init__(model)
113 | optimizer = optimizers.MomentumSGD(lr, momentum)
114 | weight_decay = chainer.optimizer.WeightDecay(weight_decay)
115 | optimizer.setup(self.model)
116 | optimizer.add_hook(weight_decay)
117 | self.optimizer = optimizer
118 | self.lr = lr
119 | self.momentum = momentum
120 | self.weight_decay = weight_decay
121 | self.period = int(period)
122 |
123 | def __call__(self, i):
124 | if i % self.period == 0:
125 | lr = self.optimizer.lr * 0.94
126 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
127 | self.optimizer.lr = lr
128 |
129 |
130 | class OptimizerVGG(Optimizer):
131 |
132 | def __init__(self, model=None, lr=0.01, momentum=0.9, weight_decay=5.0e-4):
133 | super(OptimizerVGG, self).__init__(model)
134 | optimizer = optimizers.MomentumSGD(lr, momentum)
135 | weight_decay = chainer.optimizer.WeightDecay(weight_decay)
136 | optimizer.setup(self.model)
137 | optimizer.add_hook(weight_decay)
138 | self.optimizer = optimizer
139 | self.lr = lr
140 | self.momentum = momentum
141 | self.weight_decay = weight_decay
142 |
143 | def __call__(self, i):
144 |         # after 150 epochs the lr has been multiplied by 0.94 ** 75
145 |         # if the initial lr is 0.01, then (0.94 ** 75) * 0.01 is about 1.0e-4 at the end
146 | if i % 2 == 0:
147 | lr = self.optimizer.lr * 0.94
148 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
149 | self.optimizer.lr = lr
150 |
151 |
152 | class OptimizerGooglenet(Optimizer):
153 |
154 | def __init__(self, model=None, lr=0.0015, momentum=0.9, weight_decay=2.0e-4):
155 | super(OptimizerGooglenet, self).__init__(model)
156 | optimizer = optimizers.MomentumSGD(lr, momentum)
157 | weight_decay = chainer.optimizer.WeightDecay(weight_decay)
158 | optimizer.setup(self.model)
159 | optimizer.add_hook(weight_decay)
160 | self.optimizer = optimizer
161 | self.lr = lr
162 | self.momentum = momentum
163 | self.weight_decay = weight_decay
164 |
165 | def __call__(self, i):
166 | if i % 8 == 0:
167 | lr = self.optimizer.lr * 0.96
168 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
169 | self.optimizer.lr = lr
170 |
171 |
172 | class OptimizerNetworkInNetwork(Optimizer):
173 |
174 | def __init__(self, model=None, lr=0.1, momentum=0.9, weight_decay=1.0e-4, schedule=(int(1.0e5 / (50000. / 128)), )):
175 | super(OptimizerNetworkInNetwork, self).__init__(model)
176 | optimizer = optimizers.MomentumSGD(lr, momentum)
177 | weight_decay = chainer.optimizer.WeightDecay(weight_decay)
178 | optimizer.setup(self.model)
179 | optimizer.add_hook(weight_decay)
180 | self.optimizer = optimizer
181 | self.lr = lr
182 | self.momentum = momentum
183 | self.weight_decay = weight_decay
184 | self.schedule = schedule
185 |
186 | def __call__(self, i):
187 | if i in self.schedule:
188 | lr = self.optimizer.lr / 10
189 | print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
190 | self.optimizer.lr = lr
191 |
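
As a worked check of the `OptimizerGooglenet` schedule above (initial learning rate 0.0015 from the README, multiplied by 0.96 every 8 epochs, weight decay 2.0e-4), the snippet below traces the learning rate over the default 200 epochs of main.py; it only illustrates the arithmetic and is not part of the training code.

```python
# lr *= 0.96 at every epoch divisible by 8, mirroring OptimizerGooglenet.__call__
lr = 0.0015
for epoch in range(1, 201):
    if epoch % 8 == 0:
        lr *= 0.96
# 25 decays in 200 epochs: 0.0015 * 0.96 ** 25 is roughly 0.00054
print(lr)
```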
--------------------------------------------------------------------------------