├── .gitignore
├── LICENSE
├── README.md
├── clr.py
├── images
│   ├── lr.png
│   ├── momentum.png
│   ├── one_cycle_lr.png
│   ├── one_cycle_momentum.png
│   └── weight_decay.png
├── models
│   ├── mobilenet
│   │   ├── find_lr_schedule.py
│   │   ├── find_momentum_schedule.py
│   │   ├── find_weight_decay_schedule.py
│   │   ├── mobilenets.py
│   │   ├── train_cifar_10.py
│   │   └── weights
│   │       ├── losses.npy
│   │       ├── lrs.npy
│   │       ├── mobilenet_v2 - 9033.h5
│   │       ├── mobilenet_v2.h5
│   │       ├── momentum
│   │       │   ├── momentum-0.9
│   │       │   │   ├── losses.npy
│   │       │   │   └── lrs.npy
│   │       │   ├── momentum-0.95
│   │       │   │   ├── losses.npy
│   │       │   │   └── lrs.npy
│   │       │   └── momentum-0.99
│   │       │       ├── losses.npy
│   │       │       └── lrs.npy
│   │       └── weight_decay
│   │           ├── weight_decay-1e-05
│   │           │   ├── losses.npy
│   │           │   └── lrs.npy
│   │           ├── weight_decay-1e-06
│   │           │   ├── losses.npy
│   │           │   └── lrs.npy
│   │           ├── weight_decay-1e-07
│   │           │   ├── losses.npy
│   │           │   └── lrs.npy
│   │           ├── weight_decay-3e-05
│   │           │   ├── losses.npy
│   │           │   └── lrs.npy
│   │           └── weight_decay-3e-06
│   │               ├── losses.npy
│   │               └── lrs.npy
│   └── small
│       ├── find_lr_schedule.py
│       ├── find_momentum_schedule.py
│       ├── find_weight_decay_schedule.py
│       ├── model.py
│       ├── train_cifar_10.py
│       └── weights
│           ├── losses.npy
│           ├── lrs.npy
│           ├── mini_vgg.h5
│           ├── momentum
│           │   ├── momentum-0.9
│           │   │   ├── losses.npy
│           │   │   └── lrs.npy
│           │   ├── momentum-0.95
│           │   │   ├── losses.npy
│           │   │   └── lrs.npy
│           │   └── momentum-0.99
│           │       ├── losses.npy
│           │       └── lrs.npy
│           └── weight_decay
│               ├── weight_decay-0.0001
│               │   ├── losses.npy
│               │   └── lrs.npy
│               ├── weight_decay-0.0003
│               │   ├── losses.npy
│               │   └── lrs.npy
│               ├── weight_decay-0.001
│               │   ├── losses.npy
│               │   └── lrs.npy
│               ├── weight_decay-0.003
│               │   ├── losses.npy
│               │   └── lrs.npy
│               ├── weight_decay-1e-05
│               │   ├── losses.npy
│               │   └── lrs.npy
│               ├── weight_decay-1e-06
│               │   ├── losses.npy
│               │   └── lrs.npy
│               └── weight_decay-1e-07
│                   ├── losses.npy
│                   └── lrs.npy
└── plot_clr.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # PyCharm
104 | .idea/*
105 |
106 | # Weights
107 | weights/*.h5
108 | models/**/*.h5
109 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Somshubra Majumdar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # One Cycle Learning Rate Policy for Keras
2 | Implementation of One-Cycle Learning rate policy from the papers by Leslie N. Smith.
3 |
4 | - [A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay](https://arxiv.org/abs/1803.09820)
5 | - [Super-Convergence: Very Fast Training of Residual Networks Using Large Learning Rates](https://arxiv.org/abs/1708.07120)
6 |
7 | Contains two Keras callbacks, `LRFinder` and `OneCycleLR`, which are ported from the PyTorch *Fast.ai* library.
8 |
9 | # What is One Cycle Learning Rate
10 | It is the combination of gradually increasing the learning rate, and optionally gradually decreasing the momentum, during the first half of the cycle, then gradually decreasing the learning rate and optionally increasing the momentum during the latter half of the cycle.
11 |
12 | Finally, for a certain percentage of training at the end of the cycle, the learning rate is sharply reduced every iteration, down to roughly one-hundredth of its lowest cycle value.
13 |
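The three phases can be summarized with a minimal sketch. This is an illustration, not the library's internal code; `pct` is the fraction of training completed, `end_pct` mirrors `end_percentage`, and the 10x ratio between the peak and starting learning rate follows the callback's defaults:

```python
def one_cycle_lr(pct, max_lr, end_pct=0.1, div=10.):
    low = max_lr / div                  # starting / lowest cycle LR
    half = (1. - end_pct) / 2.          # length of each half-cycle
    if pct < half:                      # phase 1: ramp up to max_lr
        return low + (max_lr - low) * (pct / half)
    elif pct < 2. * half:               # phase 2: ramp back down to low
        return max_lr - (max_lr - low) * ((pct - half) / half)
    else:                               # phase 3: sharp decay to low / 100
        return low * (1. - 0.99 * (pct - 2. * half) / end_pct)
```
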
14 | The Learning rate schedule is visualized as:
15 |
16 | ![One Cycle learning rate schedule](images/one_cycle_lr.png)
17 |
18 | The optional Momentum schedule is visualized as:
19 |
20 | ![One Cycle momentum schedule](images/one_cycle_momentum.png)
21 |
22 | # Usage
23 |
24 | ## Finding a good learning rate
25 | Use `LRFinder` to obtain a loss plot, and visually inspect it to determine a good maximum learning rate. Provided below is an example, used for the `MiniMobileNetV2` model.
26 |
27 | An example script has been provided in `find_lr_schedule.py` inside the `models/mobilenet/` directory.
28 |
29 | Essentially,
30 |
31 | ```python
32 | from clr import LRFinder
33 |
34 | lr_callback = LRFinder(num_samples, batch_size,
35 | minimum_lr, maximum_lr,
36 | # validation_data=(X_val, Y_val),
37 | lr_scale='exp', save_dir='path/to/save/directory')
38 |
39 | # Ensure that number of epochs = 1 when calling fit()
40 | model.fit(X, Y, epochs=1, batch_size=batch_size, callbacks=[lr_callback])
41 | ```
42 | The above callback does a few things.
43 |
44 | - Must supply number of samples in the dataset (here, 50k from CIFAR 10) and the batch size that will be used during training.
45 | - `lr_scale` is set to `exp` - useful when searching over a large range of learning rates. Set to `linear` to search a smaller space.
46 | - `save_dir` - automatically saves the results of the LRFinder search to the specified directory path. This is highly encouraged.
47 | - `validation_data` - provide the validation data as a tuple to use it for the loss plot instead of the training batch loss. Since the validation dataset can be very large, we randomly sample `k` batches (`k * batch_size` samples) from the validation set to provide a quick estimate of the validation loss. The default value of `k` can be changed via the `validation_sample_rate` argument.
48 |
49 | **Note : When using this, be careful about setting the learning rate, momentum and weight decay schedule. The loss plots will be more erratic due to the sampling of the validation set.**
50 |
51 | **NOTE 2 :**
52 |
53 | - It is faster to get the learning rate without using `validation_data`, and then find the weight decay and momentum based on that learning rate while using `validation_data`.
54 | - You can also use `LRFinder` to find the optimal weight decay and momentum values, using the examples `find_momentum_schedule.py` and `find_weight_decay_schedule.py` inside the `models/mobilenet/` folder.
55 |
56 | To visualize the plot, there are two ways -
57 |
58 | - Use `lr_callback.plot_schedule()` after the fit() call. This uses the current training session results.
59 | - Use the class method `LRFinder.plot_schedule_from_file('path/to/save/directory')` to visualize the plot separately from the training session. This only works if you used the `save_dir` argument to save the results of the search to some location.
60 |
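For example (the directory path below is a placeholder for whatever you passed to `save_dir`):

```python
# after fit() completes with the LRFinder callback attached:
lr_callback.plot_schedule()

# or later, from the saved numpy arrays (clipping the noisy ends is optional):
LRFinder.plot_schedule_from_file('path/to/save/directory',
                                 clip_beginning=10, clip_endding=5)
```
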
61 | ## Finding the optimal Momentum
62 |
63 | Use the `find_momentum_schedule.py` script inside `models/mobilenet/` for an example.
64 |
65 | Some notes :
66 |
67 | - Use a grid search over a few possible momentum values, such as `[0.8, 0.85, 0.9, 0.95, 0.99]`, with `linear` as the `lr_scale` argument value (see the sketch after this list).
68 | - Set the momentum value manually to the SGD optimizer before compiling the model.
69 | - Plot the curves at the end and visually check which momentum value yields the least noisy / lowest losses overall. The absolute values on the loss plot matter less than the shape of the curve.
70 |
71 | - It is better to supply the `validation_data` here.
72 | - The plot will be very noisy, so if you wish, you can use a larger value of `loss_smoothing_beta` (such as `0.99` or `0.995`).
73 | - The actual curve values matter less than the overall movement of the curve. Choose the value that is steadier and reaches the lowest loss even at large learning rates.
74 |
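A minimal sketch of this grid search, mirroring `find_momentum_schedule.py` (`build_model` and the data variables are placeholders for your own model and dataset setup):

```python
from keras import backend as K
from keras.optimizers import SGD
from clr import LRFinder

for momentum in [0.8, 0.85, 0.9, 0.95, 0.99]:
    K.clear_session()

    model = build_model()  # placeholder: construct a fresh model each run
    optimizer = SGD(lr=0.002, momentum=momentum, nesterov=True)
    model.compile(loss='categorical_crossentropy', optimizer=optimizer,
                  metrics=['accuracy'])

    # LR range found earlier with LRFinder; each run saves to its own directory
    lr_finder = LRFinder(num_samples, batch_size,
                         minimum_lr=0.002, maximum_lr=0.02,
                         lr_scale='linear',
                         validation_data=(X_test, Y_test),
                         save_dir='weights/momentum/momentum-%s/' % momentum)
    model.fit(X_train, Y_train, epochs=1, batch_size=batch_size,
              callbacks=[lr_finder])
```
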
75 | ## Finding the optimal Weight Decay
76 |
77 | Use the `find_weight_decay_schedule.py` script inside `models/mobilenet/` for an example.
78 |
79 | Some notes :
80 |
81 | - Use a grid search over a few weight decay values, such as `[1e-3, 1e-4, 1e-5, 1e-6, 1e-7]`. Call this "coarse search" and use `linear` for the `lr_scale` argument.
82 | - Use a grid search over a select few weight decay values, such as `[3e-7, 1e-7, 3e-6]`. Call this "fine search" and use `linear` scale for the `lr_scale` argument.
83 | - Set the weight decay value manually to the model when building the model.
84 | - Plot the curves at the end and visually check which weight decay value yields the least noisy / lowest losses overall. The absolute values on the loss plot matter less than the shape of the curve (see the sketch after this list).
85 |
86 | - It is better to supply the `validation_data` here.
87 | - The plot will be very noisy, so if you wish, you can use a larger value of `loss_smoothing_beta` (such as `0.99` or `0.995`).
88 | - The actual curve values matter less than the overall movement of the curve. Choose the value that is steadier and reaches the lowest loss even at large learning rates.
89 |
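A short sketch of comparing the saved runs, mirroring `find_weight_decay_schedule.py` (the directory layout assumes each run was saved via `save_dir` as in the momentum sketch above):

```python
import matplotlib.pyplot as plt
from clr import LRFinder

for wd in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]:
    directory = 'weights/weight_decay/weight_decay-%s/' % str(wd)
    # clip the first 10 and last 5 points to drop the noisy extremes
    losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5)
    plt.plot(lrs, losses, label='weight_decay=%s' % str(wd))

plt.xlabel('learning rate')
plt.ylabel('validation loss')
plt.legend()
plt.show()
```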
90 |
91 | ## Interpreting the plot
92 |
93 | ### Learning Rate
94 |
95 | ![LRFinder loss plot](images/lr.png)
96 |
99 | Consider the above plot from using the `LRFinder` on the MiniMobileNetV2 model. In particular, there are a few regions above that we need to carefully interpret.
100 |
101 | **Note : The x-axis values are in log10 scale (since `exp` was used for `lr_scale`).** All values discussed below refer to the x-axis (learning rate):
102 |
103 | - After the -1.5 point on the graph, the loss becomes erratic.
104 | - After the 0.5 point on the graph, the loss is noisy but doesn't decrease any further.
105 | - **-1.7** is the last relatively smooth portion before the **-1.5** region. To be safe, we can choose to move a little more to the left, closer to -1.8, but this will reduce the performance.
106 | - It is usually important to visualize the first 2-3 epochs of `OneCycleLR` training with values close to these edges to determine which is the best.
107 |
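For instance, picking **-1.7** on the x-axis corresponds to:

```python
max_lr = 10 ** -1.7   # ~0.0199, rounded up to 0.02 in the Results section below
```
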
108 | ### Momentum
109 |
110 | Using the learning rate found above, next calculate the optimal momentum (`find_momentum_schedule.py`):
111 |
112 | ![Momentum search plot](images/momentum.png)
113 |
116 | See the notes in the `Finding the optimal momentum` section on how to interpret the plot.
117 |
118 | ### Weight Decay
119 |
120 | Similarly, it is possible to use the above learning rate and momentum values to calculate the optimal weight decay (`find_weight_decay_schedule.py`).
121 |
122 | **Note : Due to large learning rates acting as a strong regularizer, other regularization techniques like weight decay and dropout should be decreased significantly to properly train the model.**
123 |
124 | ![Weight decay search plot](images/weight_decay.png)
125 |
128 | It is best to first search a range of regularization strengths from 1e-3 to 1e-7, and then fine-search the region that provided the best overall plot.
129 |
130 | See the notes in the `Finding the optimal weight decay` section on how to interpret the plot.
131 |
132 | ## Training with `OneCycleLR`
133 | Once we find the maximum learning rate, we can then move on to using the `OneCycleLR` callback with SGD to train our model.
134 |
135 | ```python
136 | from clr import OneCycleLR
137 |
138 | lr_manager = OneCycleLR(num_samples, batch_size, max_lr,
139 |                         end_percentage=0.1, scale_percentage=None,
140 |                         maximum_momentum=0.95, minimum_momentum=0.85)
141 |
142 | model.fit(X, Y, epochs=EPOCHS, batch_size=batch_size, callbacks=[model_checkpoint, lr_manager],
143 | ...)
144 | ```
145 |
146 | There are many parameters, but a few of the important ones :
147 | - Must provide the basic training information - `number of samples`, `batch size` and `max learning rate`. The number of epochs is read from the `fit()` call itself.
148 | - `end_percentage` is used to determine what percentage of the training epochs will be used for the steep reduction in the learning rate. At its minimum, the lowest learning rate will be calculated as 1/1000th of the `max_lr` provided.
149 | - `scale_percentage` is a confusing parameter. It dictates the scaling factor of the learning rate in the second half of the training cycle. **It is best to test this out visually using the `plot_clr.py` script to ensure there are no mistakes**. Leaving it as None defaults to using the same percentage as the provided `end_percentage`.
150 | - `maximum/minimum_momentum` are preset according to the paper and `Fast.ai`. However, if you don't wish to scale it, set both to the same value; generally `0.9` is preferred as the momentum value for SGD. If you don't want to update the momentum / are not using SGD (not advisable), set both to None to ignore the momentum updates.
151 |
152 | # Results
153 |
154 | - **-1.7** is chosen to be the maximum learning rate (in log10 space) for the `OneCycleLR` schedule. Since this is in log10 scale, we use `10 ^ (x)` to get the actual maximum learning rate. Here, `10 ^ -1.7 ~ 0.01995`. Therefore, we round up to a **maximum learning rate of 0.02**.
155 | - **0.9** is chosen as the maximum momentum from the momentum plot. Using Cyclic Momentum updates, choose a slightly lower value (**0.85**) as the minimum for faster training.
156 | - **3e-6** is chosen as the weight decay factor.
157 |
158 | For the MiniMobileNetV2 model, 2 passes of the OneCycle LR with SGD (40 epochs - max lr = 0.02, 30 epochs - max lr = 0.005) obtained 90.33% accuracy. This may not seem like much, but this is a model with only 650k parameters, and in comparison, the same model trained with Adam at an initial learning rate of 2e-3 did not converge to the same score in over 100 epochs (89.14%).
159 |
160 | # Requirements
161 | - Keras 2.1.6+
162 | - Tensorflow (tested) / Theano / CNTK for the backend
163 | - matplotlib to visualize the plots.
164 |
--------------------------------------------------------------------------------
/clr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import warnings
4 |
5 | from keras.callbacks import Callback
6 | from keras import backend as K
7 |
8 |
9 | # Code is ported from https://github.com/fastai/fastai
10 | class OneCycleLR(Callback):
11 | def __init__(self,
12 | num_samples,
13 | batch_size,
14 | max_lr,
15 | end_percentage=0.1,
16 | scale_percentage=None,
17 | maximum_momentum=0.95,
18 | minimum_momentum=0.85,
19 | verbose=True):
20 | """ This callback implements a cyclical learning rate policy (CLR).
21 | This is a special case of Cyclic Learning Rates, where we have only 1 cycle.
22 |         After the completion of 1 cycle, the learning rate will decrease rapidly to
23 |         one-hundredth of its initial lowest value.
24 |
25 | # Arguments:
26 | num_samples: Integer. Number of samples in the dataset.
27 | batch_size: Integer. Batch size during training.
28 |             max_lr: Float. Maximum learning rate at the peak of the
29 |                 cycle. The starting learning rate will be 10x smaller than
30 |                 this, and will increase to this value during the first half-cycle.
31 | end_percentage: Float. The percentage of all the epochs of training
32 | that will be dedicated to sharply decreasing the learning
33 | rate after the completion of 1 cycle. Must be between 0 and 1.
34 | scale_percentage: Float or None. If float, must be between 0 and 1.
35 | If None, it will compute the scale_percentage automatically
36 | based on the `end_percentage`.
37 | maximum_momentum: Optional. Sets the maximum momentum (initial)
38 | value, which gradually drops to its lowest value in half-cycle,
39 | then gradually increases again to stay constant at this max value.
40 | Can only be used with SGD Optimizer.
41 | minimum_momentum: Optional. Sets the minimum momentum at the end of
42 | the half-cycle. Can only be used with SGD Optimizer.
43 | verbose: Bool. Whether to print the current learning rate after every
44 | epoch.
45 |
46 | # Reference
47 |         - [A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay](https://arxiv.org/abs/1803.09820)
48 | - [Super-Convergence: Very Fast Training of Residual Networks Using Large Learning Rates](https://arxiv.org/abs/1708.07120)
49 | """
50 | super(OneCycleLR, self).__init__()
51 |
52 | if end_percentage < 0. or end_percentage > 1.:
53 | raise ValueError("`end_percentage` must be between 0 and 1")
54 |
55 | if scale_percentage is not None and (scale_percentage < 0. or scale_percentage > 1.):
56 | raise ValueError("`scale_percentage` must be between 0 and 1")
57 |
58 | self.initial_lr = max_lr
59 | self.end_percentage = end_percentage
60 | self.scale = float(scale_percentage) if scale_percentage is not None else float(end_percentage)
61 | self.max_momentum = maximum_momentum
62 | self.min_momentum = minimum_momentum
63 | self.verbose = verbose
64 |
65 | if self.max_momentum is not None and self.min_momentum is not None:
66 | self._update_momentum = True
67 | else:
68 | self._update_momentum = False
69 |
70 | self.clr_iterations = 0.
71 | self.history = {}
72 |
73 | self.epochs = None
74 | self.batch_size = batch_size
75 | self.samples = num_samples
76 | self.steps = None
77 | self.num_iterations = None
78 | self.mid_cycle_id = None
79 |
80 | def _reset(self):
81 | """
82 | Reset the callback.
83 | """
84 | self.clr_iterations = 0.
85 | self.history = {}
86 |
87 | def compute_lr(self):
88 | """
89 | Compute the learning rate based on which phase of the cycle it is in.
90 |
91 | - If in the first half of training, the learning rate gradually increases.
92 | - If in the second half of training, the learning rate gradually decreases.
93 | - If in the final `end_percentage` portion of training, the learning rate
94 |           is quickly reduced to nearly one-hundredth of the original minimum learning rate.
95 |
96 | # Returns:
97 | the new learning rate
98 | """
99 | if self.clr_iterations > 2 * self.mid_cycle_id:
100 | current_percentage = (self.clr_iterations - 2 * self.mid_cycle_id)
101 | current_percentage /= float((self.num_iterations - 2 * self.mid_cycle_id))
102 | new_lr = self.initial_lr * (1. + (current_percentage *
103 | (1. - 100.) / 100.)) * self.scale
104 |
105 | elif self.clr_iterations > self.mid_cycle_id:
106 | current_percentage = 1. - (
107 | self.clr_iterations - self.mid_cycle_id) / self.mid_cycle_id
108 | new_lr = self.initial_lr * (1. + current_percentage *
109 | (self.scale * 100 - 1.)) * self.scale
110 |
111 | else:
112 | current_percentage = self.clr_iterations / self.mid_cycle_id
113 | new_lr = self.initial_lr * (1. + current_percentage *
114 | (self.scale * 100 - 1.)) * self.scale
115 |
116 | if self.clr_iterations == self.num_iterations:
117 | self.clr_iterations = 0
118 |
119 | return new_lr
120 |
121 | def compute_momentum(self):
122 | """
123 | Compute the momentum based on which phase of the cycle it is in.
124 |
125 | - If in the first half of training, the momentum gradually decreases.
126 | - If in the second half of training, the momentum gradually increases.
127 | - If in the final `end_percentage` portion of training, the momentum value
128 | is kept constant at the maximum initial value.
129 |
130 | # Returns:
131 | the new momentum value
132 | """
133 | if self.clr_iterations > 2 * self.mid_cycle_id:
134 | new_momentum = self.max_momentum
135 |
136 | elif self.clr_iterations > self.mid_cycle_id:
137 | current_percentage = 1. - ((self.clr_iterations - self.mid_cycle_id) / float(
138 | self.mid_cycle_id))
139 | new_momentum = self.max_momentum - current_percentage * (
140 | self.max_momentum - self.min_momentum)
141 |
142 | else:
143 | current_percentage = self.clr_iterations / float(self.mid_cycle_id)
144 | new_momentum = self.max_momentum - current_percentage * (
145 | self.max_momentum - self.min_momentum)
146 |
147 | return new_momentum
148 |
149 |     def on_train_begin(self, logs=None):
150 | logs = logs or {}
151 |
152 | self.epochs = self.params['epochs']
153 | # When fit generator is used
154 | # self.params don't have the elements 'batch_size' and 'samples'
155 | # self.batch_size = self.params['batch_size']
156 | # self.samples = self.params['samples']
157 | self.steps = self.params['steps']
158 |
159 | if self.steps is not None:
160 | self.num_iterations = self.epochs * self.steps
161 | else:
162 |             # number of batches per epoch, rounding up to count a partial final batch
163 |             steps_per_epoch = (self.samples + self.batch_size - 1) // self.batch_size
164 |             self.num_iterations = self.epochs * steps_per_epoch
167 |
168 | self.mid_cycle_id = int(self.num_iterations * ((1. - self.end_percentage)) / float(2))
169 |
170 | self._reset()
171 | K.set_value(self.model.optimizer.lr, self.compute_lr())
172 |
173 | if self._update_momentum:
174 | if not hasattr(self.model.optimizer, 'momentum'):
175 | raise ValueError("Momentum can be updated only on SGD optimizer !")
176 |
177 | new_momentum = self.compute_momentum()
178 | K.set_value(self.model.optimizer.momentum, new_momentum)
179 |
180 |     def on_batch_end(self, batch, logs=None):
181 | logs = logs or {}
182 |
183 | self.clr_iterations += 1
184 | new_lr = self.compute_lr()
185 |
186 | self.history.setdefault('lr', []).append(
187 | K.get_value(self.model.optimizer.lr))
188 | K.set_value(self.model.optimizer.lr, new_lr)
189 |
190 | if self._update_momentum:
191 | if not hasattr(self.model.optimizer, 'momentum'):
192 | raise ValueError("Momentum can be updated only on SGD optimizer !")
193 |
194 | new_momentum = self.compute_momentum()
195 |
196 | self.history.setdefault('momentum', []).append(
197 | K.get_value(self.model.optimizer.momentum))
198 | K.set_value(self.model.optimizer.momentum, new_momentum)
199 |
200 | for k, v in logs.items():
201 | self.history.setdefault(k, []).append(v)
202 |
203 | def on_epoch_end(self, epoch, logs=None):
204 | if self.verbose:
205 | if self._update_momentum:
206 | print(" - lr: %0.5f - momentum: %0.2f " %
207 | (self.history['lr'][-1], self.history['momentum'][-1]))
208 |
209 | else:
210 | print(" - lr: %0.5f " % (self.history['lr'][-1]))
211 |
212 |
213 | class LRFinder(Callback):
214 | def __init__(self,
215 | num_samples,
216 | batch_size,
217 | minimum_lr=1e-5,
218 | maximum_lr=10.,
219 | lr_scale='exp',
220 | validation_data=None,
221 | validation_sample_rate=5,
222 | stopping_criterion_factor=4.,
223 | loss_smoothing_beta=0.98,
224 | save_dir=None,
225 | verbose=True):
226 | """
227 | This class uses the Cyclic Learning Rate history to find a
228 | set of learning rates that can be good initializations for the
229 | One-Cycle training proposed by Leslie Smith in the paper referenced
230 | below.
231 |
232 | A port of the Fast.ai implementation for Keras.
233 |
234 | # Note
235 | This requires that the model be trained for exactly 1 epoch. If the model
236 | is trained for more epochs, then the metric calculations are only done for
237 | the first epoch.
238 |
239 | # Interpretation
240 | Upon visualizing the loss plot, check where the loss starts to increase
241 |             rapidly. Choose a learning rate somewhat prior to the corresponding
242 |             position in the plot for faster convergence. This will be the maximum
243 |             learning rate. Pass this value as the `max_lr` argument
244 |             to the OneCycleLR callback.
245 |
246 |         Since the plot is in log10-scale, the actual learning rate is 10 ** (x-axis value)
247 |
248 | # Arguments:
249 | num_samples: Integer. Number of samples in the dataset.
250 | batch_size: Integer. Batch size during training.
251 | minimum_lr: Float. Initial learning rate (and the minimum).
252 | maximum_lr: Float. Final learning rate (and the maximum).
253 | lr_scale: Can be one of ['exp', 'linear']. Chooses the type of
254 | scaling for each update to the learning rate during subsequent
255 | batches. Choose 'exp' for large range and 'linear' for small range.
256 | validation_data: Requires the validation dataset as a tuple of
257 | (X, y) belonging to the validation set. If provided, will use the
258 | validation set to compute the loss metrics. Else uses the training
259 | batch loss. Will warn if not provided to alert the user.
260 | validation_sample_rate: Positive or Negative Integer. Number of batches to sample from the
261 | validation set per iteration of the LRFinder. Larger number of
262 | samples will reduce the variance but will take longer time to execute
263 | per batch.
264 |
265 |                 If positive (> 0), will sample that many batches from the validation dataset
266 |                 If negative, will use the entire validation dataset
267 | stopping_criterion_factor: Integer or None. A factor which is used
268 | to measure large increase in the loss value during training.
269 | Since callbacks cannot stop training of a model, it will simply
270 | stop logging the additional values from the epochs after this
271 | stopping criterion has been met.
272 | If None, this check will not be performed.
273 | loss_smoothing_beta: Float. The smoothing factor for the moving
274 | average of the loss function.
275 | save_dir: Optional, String. If passed a directory path, the callback
276 | will save the running loss and learning rates to two separate numpy
277 | arrays inside this directory. If the directory in this path does not
278 | exist, they will be created.
279 | verbose: Whether to print the learning rate after every batch of training.
280 |
281 | # References:
282 |         - [A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay](https://arxiv.org/abs/1803.09820)
283 | """
284 | super(LRFinder, self).__init__()
285 |
286 | if lr_scale not in ['exp', 'linear']:
287 | raise ValueError("`lr_scale` must be one of ['exp', 'linear']")
288 |
289 | if validation_data is not None:
290 | self.validation_data = validation_data
291 | self.use_validation_set = True
292 |
293 |             if validation_sample_rate != 0:
294 |                 self.validation_sample_rate = validation_sample_rate
295 |             else:
296 |                 raise ValueError("`validation_sample_rate` must be a positive or negative integer other than 0")
297 | else:
298 | self.use_validation_set = False
299 | self.validation_sample_rate = 0
300 |
301 | self.num_samples = num_samples
302 | self.batch_size = batch_size
303 | self.initial_lr = minimum_lr
304 | self.final_lr = maximum_lr
305 | self.lr_scale = lr_scale
306 | self.stopping_criterion_factor = stopping_criterion_factor
307 | self.loss_smoothing_beta = loss_smoothing_beta
308 | self.save_dir = save_dir
309 | self.verbose = verbose
310 |
311 | self.num_batches_ = num_samples // batch_size
312 | self.current_lr_ = minimum_lr
313 |
314 | if lr_scale == 'exp':
315 | self.lr_multiplier_ = (maximum_lr / float(minimum_lr)) ** (
316 | 1. / float(self.num_batches_))
317 | else:
318 | extra_batch = int((num_samples % batch_size) != 0)
319 | self.lr_multiplier_ = np.linspace(
320 | minimum_lr, maximum_lr, num=self.num_batches_ + extra_batch)
321 |
322 | # If negative, use entire validation set
323 | if self.validation_sample_rate < 0:
324 | self.validation_sample_rate = self.validation_data[0].shape[0] // batch_size
325 |
326 | self.current_batch_ = 0
327 | self.current_epoch_ = 0
328 | self.best_loss_ = 1e6
329 | self.running_loss_ = 0.
330 |
331 | self.history = {}
332 |
333 | def on_train_begin(self, logs=None):
334 |
335 | self.current_epoch_ = 1
336 | K.set_value(self.model.optimizer.lr, self.initial_lr)
337 |
338 | warnings.simplefilter("ignore")
339 |
340 | def on_epoch_begin(self, epoch, logs=None):
341 | self.current_batch_ = 0
342 |
343 | if self.current_epoch_ > 1:
344 | warnings.warn(
345 | "\n\nLearning rate finder should be used only with a single epoch. "
346 | "Hereafter, the callback will not measure the losses.\n\n")
347 |
348 | def on_batch_begin(self, batch, logs=None):
349 | self.current_batch_ += 1
350 |
351 | def on_batch_end(self, batch, logs=None):
352 | if self.current_epoch_ > 1:
353 | return
354 |
355 | if self.use_validation_set:
356 | X, Y = self.validation_data[0], self.validation_data[1]
357 |
358 |             # use `validation_sample_rate` random batches from the validation set for a fast approximation of the loss
359 | num_samples = self.batch_size * self.validation_sample_rate
360 |
361 | if num_samples > X.shape[0]:
362 | num_samples = X.shape[0]
363 |
364 | idx = np.random.choice(X.shape[0], num_samples, replace=False)
365 | x = X[idx]
366 | y = Y[idx]
367 |
368 | values = self.model.evaluate(x, y, batch_size=self.batch_size, verbose=False)
369 | loss = values[0]
370 | else:
371 | loss = logs['loss']
372 |
373 |         # smooth the loss value with an exponential moving average and bias-correct it
374 |         self.running_loss_ = self.loss_smoothing_beta * self.running_loss_ + (
375 |             1. - self.loss_smoothing_beta) * loss
376 |         running_loss = self.running_loss_ / (
377 |             1. - self.loss_smoothing_beta**self.current_batch_)
378 |
379 | # stop logging if loss is too large
380 | if self.current_batch_ > 1 and self.stopping_criterion_factor is not None and (
381 | running_loss >
382 | self.stopping_criterion_factor * self.best_loss_):
383 |
384 | if self.verbose:
385 | print(" - LRFinder: Skipping iteration since loss is %d times as large as best loss (%0.4f)"
386 | % (self.stopping_criterion_factor, self.best_loss_))
387 | return
388 |
389 | if running_loss < self.best_loss_ or self.current_batch_ == 1:
390 | self.best_loss_ = running_loss
391 |
392 | current_lr = K.get_value(self.model.optimizer.lr)
393 |
394 | self.history.setdefault('running_loss_', []).append(running_loss)
395 | if self.lr_scale == 'exp':
396 | self.history.setdefault('log_lrs', []).append(np.log10(current_lr))
397 | else:
398 | self.history.setdefault('log_lrs', []).append(current_lr)
399 |
400 | # compute the lr for the next batch and update the optimizer lr
401 | if self.lr_scale == 'exp':
402 | current_lr *= self.lr_multiplier_
403 | else:
404 | current_lr = self.lr_multiplier_[self.current_batch_ - 1]
405 |
406 | K.set_value(self.model.optimizer.lr, current_lr)
407 |
408 | # save the other metrics as well
409 | for k, v in logs.items():
410 | self.history.setdefault(k, []).append(v)
411 |
412 | if self.verbose:
413 | if self.use_validation_set:
414 | print(" - LRFinder: val_loss: %1.4f - lr = %1.8f " %
415 | (values[0], current_lr))
416 | else:
417 | print(" - LRFinder: lr = %1.8f " % current_lr)
418 |
419 | def on_epoch_end(self, epoch, logs=None):
420 | if self.save_dir is not None and self.current_epoch_ <= 1:
421 | if not os.path.exists(self.save_dir):
422 | os.makedirs(self.save_dir)
423 |
424 | losses_path = os.path.join(self.save_dir, 'losses.npy')
425 | lrs_path = os.path.join(self.save_dir, 'lrs.npy')
426 |
427 | np.save(losses_path, self.losses)
428 | np.save(lrs_path, self.lrs)
429 |
430 | if self.verbose:
431 | print("\tLR Finder : Saved the losses and learning rate values in path : {%s}"
432 | % (self.save_dir))
433 |
434 | self.current_epoch_ += 1
435 |
436 | warnings.simplefilter("default")
437 |
438 | def plot_schedule(self, clip_beginning=None, clip_endding=None):
439 | """
440 | Plots the schedule from the callback itself.
441 |
442 | # Arguments:
443 | clip_beginning: Integer or None. If positive integer, it will
444 | remove the specified portion of the loss graph to remove the large
445 | loss values in the beginning of the graph.
446 | clip_endding: Integer or None. If negative integer, it will
447 | remove the specified portion of the ending of the loss graph to
448 | remove the sharp increase in the loss values at high learning rates.
449 | """
450 | try:
451 | import matplotlib.pyplot as plt
452 | plt.style.use('seaborn-white')
453 | except ImportError:
454 | print(
455 | "Matplotlib not found. Please use `pip install matplotlib` first."
456 | )
457 | return
458 |
459 | if clip_beginning is not None and clip_beginning < 0:
460 | clip_beginning = -clip_beginning
461 |
462 | if clip_endding is not None and clip_endding > 0:
463 | clip_endding = -clip_endding
464 |
465 | losses = self.losses
466 | lrs = self.lrs
467 |
468 | if clip_beginning:
469 | losses = losses[clip_beginning:]
470 | lrs = lrs[clip_beginning:]
471 |
472 | if clip_endding:
473 | losses = losses[:clip_endding]
474 | lrs = lrs[:clip_endding]
475 |
476 | plt.plot(lrs, losses)
477 | plt.title('Learning rate vs Loss')
478 | plt.xlabel('learning rate')
479 | plt.ylabel('loss')
480 | plt.show()
481 |
482 | @classmethod
483 | def restore_schedule_from_dir(cls,
484 | directory,
485 | clip_beginning=None,
486 | clip_endding=None):
487 | """
488 | Loads the training history from the saved numpy files in the given directory.
489 |
490 | # Arguments:
491 | directory: String. Path to the directory where the serialized numpy
492 | arrays of the loss and learning rates are saved.
493 | clip_beginning: Integer or None. If positive integer, it will
494 | remove the specified portion of the loss graph to remove the large
495 | loss values in the beginning of the graph.
496 | clip_endding: Integer or None. If negative integer, it will
497 | remove the specified portion of the ending of the loss graph to
498 | remove the sharp increase in the loss values at high learning rates.
499 |
500 | Returns:
501 | tuple of (losses, learning rates)
502 | """
503 | if clip_beginning is not None and clip_beginning < 0:
504 | clip_beginning = -clip_beginning
505 |
506 | if clip_endding is not None and clip_endding > 0:
507 | clip_endding = -clip_endding
508 |
509 | losses_path = os.path.join(directory, 'losses.npy')
510 | lrs_path = os.path.join(directory, 'lrs.npy')
511 |
512 | if not os.path.exists(losses_path) or not os.path.exists(lrs_path):
513 | print("%s and %s could not be found at directory : {%s}" %
514 | (losses_path, lrs_path, directory))
515 |
516 | losses = None
517 | lrs = None
518 |
519 | else:
520 | losses = np.load(losses_path)
521 | lrs = np.load(lrs_path)
522 |
523 | if clip_beginning:
524 | losses = losses[clip_beginning:]
525 | lrs = lrs[clip_beginning:]
526 |
527 | if clip_endding:
528 | losses = losses[:clip_endding]
529 | lrs = lrs[:clip_endding]
530 |
531 | return losses, lrs
532 |
533 | @classmethod
534 | def plot_schedule_from_file(cls,
535 | directory,
536 | clip_beginning=None,
537 | clip_endding=None):
538 | """
539 | Plots the schedule from the saved numpy arrays of the loss and learning
540 | rate values in the specified directory.
541 |
542 | # Arguments:
543 | directory: String. Path to the directory where the serialized numpy
544 | arrays of the loss and learning rates are saved.
545 | clip_beginning: Integer or None. If positive integer, it will
546 | remove the specified portion of the loss graph to remove the large
547 | loss values in the beginning of the graph.
548 | clip_endding: Integer or None. If negative integer, it will
549 | remove the specified portion of the ending of the loss graph to
550 | remove the sharp increase in the loss values at high learning rates.
551 | """
552 | try:
553 | import matplotlib.pyplot as plt
554 | plt.style.use('seaborn-white')
555 | except ImportError:
556 | print("Matplotlib not found. Please use `pip install matplotlib` first.")
557 | return
558 |
559 | losses, lrs = cls.restore_schedule_from_dir(
560 | directory,
561 | clip_beginning=clip_beginning,
562 | clip_endding=clip_endding)
563 |
564 | if losses is None or lrs is None:
565 | return
566 | else:
567 | plt.plot(lrs, losses)
568 | plt.title('Learning rate vs Loss')
569 | plt.xlabel('learning rate')
570 | plt.ylabel('loss')
571 | plt.show()
572 |
573 | @property
574 | def lrs(self):
575 | return np.array(self.history['log_lrs'])
576 |
577 | @property
578 | def losses(self):
579 | return np.array(self.history['running_loss_'])
580 |
--------------------------------------------------------------------------------
/images/lr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/lr.png
--------------------------------------------------------------------------------
/images/momentum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/momentum.png
--------------------------------------------------------------------------------
/images/one_cycle_lr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/one_cycle_lr.png
--------------------------------------------------------------------------------
/images/one_cycle_momentum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/one_cycle_momentum.png
--------------------------------------------------------------------------------
/images/weight_decay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/weight_decay.png
--------------------------------------------------------------------------------
/models/mobilenet/find_lr_schedule.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from keras example cifar10_cnn.py
3 | Find the learning rate schedule for MiniMobileNetV2 on the CIFAR-10 small images dataset.
4 | """
5 | from __future__ import print_function
6 | import os
7 |
8 | from keras.datasets import cifar10
9 | from keras.preprocessing.image import ImageDataGenerator
10 | from keras.utils import np_utils
11 | from keras.callbacks import ModelCheckpoint
12 | from keras.optimizers import SGD
13 | import numpy as np
14 |
15 | from clr import LRFinder
16 | from models.mobilenet.mobilenets import MiniMobileNetV2
17 |
18 | if not os.path.exists('weights/'):
19 | os.makedirs('weights/')
20 |
21 | weights_file = 'weights/mobilenet_v2_schedule.h5'
22 | model_checkpoint = ModelCheckpoint(weights_file, monitor='val_acc', save_best_only=True,
23 | save_weights_only=True, mode='max')
24 |
25 | batch_size = 128
26 | nb_classes = 10
27 | nb_epoch = 1 # Only finding lr
28 | data_augmentation = True
29 |
30 | # input image dimensions
31 | img_rows, img_cols = 32, 32
32 | # The CIFAR10 images are RGB.
33 | img_channels = 3
34 |
35 | # The data, shuffled and split between train and test sets:
36 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
37 |
38 | # Convert class vectors to binary class matrices.
39 | Y_train = np_utils.to_categorical(y_train, nb_classes)
40 | Y_test = np_utils.to_categorical(y_test, nb_classes)
41 |
42 | X_train = X_train.astype('float32')
43 | X_test = X_test.astype('float32')
44 |
45 | # preprocess input
46 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
47 | std = np.std(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
48 |
49 | print("Channel Mean : ", mean)
50 | print("Channel Std : ", std)
51 |
52 | X_train = (X_train - mean) / (std)
53 | X_test = (X_test - mean) / (std)
54 |
55 | # Learning rate finder callback setup
56 | num_samples = X_train.shape[0]
57 |
58 | # Exponential lr finder
59 | # USE THIS FOR A LARGE RANGE SEARCH
60 | # Uncomment the validation_data flag to reduce speed but get a better idea of the learning rate
61 | lr_finder = LRFinder(num_samples, batch_size, minimum_lr=1e-3, maximum_lr=10.,
62 | lr_scale='exp',
63 | # validation_data=(X_test, Y_test), # use the validation data for losses
64 | validation_sample_rate=5,
65 | save_dir='weights/', verbose=True)
66 |
67 | # Linear lr finder
68 | # USE THIS FOR A CLOSE SEARCH
69 | # Uncomment the validation_data flag to reduce speed but get a better idea of the learning rate
70 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=5e-4, maximum_lr=1e-2,
71 | # lr_scale='linear',
72 | # validation_data=(X_test, y_test), # use the validation data for losses
73 | # validation_sample_rate=5,
74 | # save_dir='weights/', verbose=True)
75 |
76 | # plot the previous values if present
77 | LRFinder.plot_schedule_from_file('weights/', clip_beginning=10, clip_endding=5)
78 |
79 | # Build the model
80 |
81 | model = MiniMobileNetV2((img_rows, img_cols, img_channels), alpha=1.4,
82 | dropout=0, weights=None, classes=nb_classes)
83 | model.summary()
84 |
85 | optimizer = SGD(lr=0.1, momentum=0.9, nesterov=True)
86 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
87 |
88 | # model.load_weights(weights_file)
89 |
90 | if not data_augmentation:
91 | print('Not using data augmentation.')
92 | model.fit(X_train, Y_train,
93 | batch_size=batch_size,
94 | epochs=nb_epoch,
95 | validation_data=(X_test, Y_test),
96 | shuffle=True,
97 | verbose=1,
98 | callbacks=[lr_finder, model_checkpoint])
99 | else:
100 | print('Using real-time data augmentation.')
101 | # This will do preprocessing and realtime data augmentation:
102 | datagen = ImageDataGenerator(
103 | featurewise_center=False, # set input mean to 0 over the dataset
104 | samplewise_center=False, # set each sample mean to 0
105 | featurewise_std_normalization=False, # divide inputs by std of the dataset
106 | samplewise_std_normalization=False, # divide each input by its std
107 | zca_whitening=False, # apply ZCA whitening
108 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
109 | width_shift_range=0, # randomly shift images horizontally (fraction of total width)
110 | height_shift_range=0, # randomly shift images vertically (fraction of total height)
111 | horizontal_flip=True, # randomly flip images
112 | vertical_flip=False) # randomly flip images
113 |
114 | # Compute quantities required for featurewise normalization
115 | # (std, mean, and principal components if ZCA whitening is applied).
116 | datagen.fit(X_train)
117 |
118 | # Fit the model on the batches generated by datagen.flow().
119 | model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
120 | steps_per_epoch=X_train.shape[0] // batch_size,
121 | validation_data=(X_test, Y_test),
122 | epochs=nb_epoch, verbose=1,
123 | callbacks=[lr_finder, model_checkpoint])
124 |
125 | lr_finder.plot_schedule(clip_beginning=10, clip_endding=5)
126 |
127 | scores = model.evaluate(X_test, Y_test, batch_size=batch_size)
128 | for score, metric_name in zip(scores, model.metrics_names):
129 | print("%s : %0.4f" % (metric_name, score))
130 |
--------------------------------------------------------------------------------
/models/mobilenet/find_momentum_schedule.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from keras example cifar10_cnn.py
3 | Find the momentum schedule for MiniMobileNetV2 on the CIFAR-10 small images dataset.
4 | """
5 | from __future__ import print_function
6 | import os
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | from keras.datasets import cifar10
11 | from keras.preprocessing.image import ImageDataGenerator
12 | from keras.utils import np_utils
13 | from keras.callbacks import ModelCheckpoint
14 | from keras.optimizers import SGD
15 | from keras import backend as K
16 |
17 | from clr import LRFinder
18 | from models.mobilenet.mobilenets import MiniMobileNetV2
19 |
20 | plt.style.use('seaborn-white')
21 |
22 | batch_size = 128
23 | nb_classes = 10
24 | nb_epoch = 1 # Only finding lr
25 | data_augmentation = True
26 |
27 | # input image dimensions
28 | img_rows, img_cols = 32, 32
29 | # The CIFAR10 images are RGB.
30 | img_channels = 3
31 |
32 | # The data, shuffled and split between train and test sets:
33 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
34 |
35 | # Convert class vectors to binary class matrices.
36 | Y_train = np_utils.to_categorical(y_train, nb_classes)
37 | Y_test = np_utils.to_categorical(y_test, nb_classes)
38 |
39 | X_train = X_train.astype('float32')
40 | X_test = X_test.astype('float32')
41 |
42 | # preprocess input
43 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
44 | std = np.std(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
45 |
46 | print("Channel Mean : ", mean)
47 | print("Channel Std : ", std)
48 |
49 | X_train = (X_train - mean) / (std)
50 | X_test = (X_test - mean) / (std)
51 |
52 | # Learning rate finder callback setup
53 | num_samples = X_train.shape[0]
54 |
55 | MOMENTUMS = [0.9, 0.95, 0.99]
56 |
57 | # for momentum in MOMENTUMS:
58 | # K.clear_session()
59 | #
60 | # # Learning rate range obtained from `find_lr_schedule.py`
61 | # # NOTE : Minimum is 10x smaller than the max found above !
62 | # # NOTE : It is preferable to use the validation data here to get a correct value
63 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=0.002, maximum_lr=0.02,
64 | # validation_data=(X_test, Y_test),
65 | # validation_sample_rate=5,
66 | # lr_scale='linear', save_dir='weights/momentum/momentum-%s/' % str(momentum),
67 | # verbose=True)
68 | #
69 | # model = MiniMobileNetV2((img_rows, img_cols, img_channels), alpha=1.4,
70 | # dropout=0, weights=None, classes=nb_classes)
71 | # model.summary()
72 | #
73 | # # set the momentum here !
74 | # # lr doesn't matter as it will be overwritten by the callback
75 | # optimizer = SGD(lr=0.002, momentum=momentum, nesterov=True)
76 | # model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
77 | #
78 | # # model.load_weights(weights_file)
79 | #
80 | # if not data_augmentation:
81 | # print('Not using data augmentation.')
82 | # model.fit(X_train, Y_train,
83 | # batch_size=batch_size,
84 | # epochs=nb_epoch,
85 | # validation_data=(X_test, Y_test),
86 | # shuffle=True,
87 | # verbose=1,
88 | # callbacks=[lr_finder])
89 | # else:
90 | # print('Using real-time data augmentation.')
91 | # # This will do preprocessing and realtime data augmentation:
92 | # datagen = ImageDataGenerator(
93 | # featurewise_center=False, # set input mean to 0 over the dataset
94 | # samplewise_center=False, # set each sample mean to 0
95 | # featurewise_std_normalization=False, # divide inputs by std of the dataset
96 | # samplewise_std_normalization=False, # divide each input by its std
97 | # zca_whitening=False, # apply ZCA whitening
98 | # rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
99 | # width_shift_range=0, # randomly shift images horizontally (fraction of total width)
100 | # height_shift_range=0, # randomly shift images vertically (fraction of total height)
101 | # horizontal_flip=True, # randomly flip images
102 | # vertical_flip=False) # randomly flip images
103 | #
104 | # # Compute quantities required for featurewise normalization
105 | # # (std, mean, and principal components if ZCA whitening is applied).
106 | # datagen.fit(X_train)
107 | #
108 | # # Fit the model on the batches generated by datagen.flow().
109 | # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
110 | # steps_per_epoch=X_train.shape[0] // batch_size,
111 | # validation_data=(X_test, Y_test),
112 | # epochs=nb_epoch, verbose=1,
113 | # callbacks=[lr_finder])
114 |
115 | # From the plots below, pick the momentum whose curve is steadiest and
116 | # reaches the lowest validation loss across the learning rate range.
117 |
118 | for momentum in MOMENTUMS:
119 | directory = 'weights/momentum/momentum-%s/' % str(momentum)
120 |
121 | losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5)
122 | plt.plot(lrs, losses, label='momentum=%0.2f' % momentum)
123 |
124 | plt.title("Momentum")
125 | plt.xlabel("Learning rate")
126 | plt.ylabel("Validation Loss")
127 | plt.legend()
128 | plt.show()
129 |
--------------------------------------------------------------------------------
/models/mobilenet/find_weight_decay_schedule.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from keras example cifar10_cnn.py
3 | Find the weight decay schedule for MiniMobileNetV2 on the CIFAR-10 small images dataset.
4 | """
5 | from __future__ import print_function
6 | import os
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | from keras.datasets import cifar10
11 | from keras.preprocessing.image import ImageDataGenerator
12 | from keras.utils import np_utils
13 | from keras.callbacks import ModelCheckpoint
14 | from keras.optimizers import SGD
15 | from keras import backend as K
16 |
17 | from clr import LRFinder
18 | from models.mobilenet.mobilenets import MiniMobileNetV2
19 |
20 | plt.style.use('seaborn-white')
21 |
22 | batch_size = 128
23 | nb_classes = 10
24 | nb_epoch = 1 # Only finding lr
25 | data_augmentation = True
26 |
27 | # input image dimensions
28 | img_rows, img_cols = 32, 32
29 | # The CIFAR10 images are RGB.
30 | img_channels = 3
31 |
32 | # The data, shuffled and split between train and test sets:
33 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
34 |
35 | # Convert class vectors to binary class matrices.
36 | Y_train = np_utils.to_categorical(y_train, nb_classes)
37 | Y_test = np_utils.to_categorical(y_test, nb_classes)
38 |
39 | X_train = X_train.astype('float32')
40 | X_test = X_test.astype('float32')
41 |
42 | # preprocess input
43 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
44 | std = np.std(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
45 |
46 | print("Channel Mean : ", mean)
47 | print("Channel Std : ", std)
48 |
49 | X_train = (X_train - mean) / (std)
50 | X_test = (X_test - mean) / (std)
51 |
52 | # Learning rate finder callback setup
53 | num_samples = X_train.shape[0]
54 |
55 | # INITIAL WEIGHT DECAY FACTORS
56 | # WEIGHT_DECAY_FACTORS = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
57 |
58 | # FINEGRAINED WEIGHT DECAY FACTORS
59 | WEIGHT_DECAY_FACTORS = [1e-7, 3e-7, 3e-6]
60 |
61 | # for weight_decay in WEIGHT_DECAY_FACTORS:
62 | # K.clear_session()
63 | #
64 | # # Learning rate range obtained from `find_lr_schedule.py`
65 | # # NOTE : Minimum is 10x smaller than the max found above !
66 | # # NOTE : It is preferable to use the validation data here to get a correct value
67 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=0.002, maximum_lr=0.02,
68 | # validation_data=(X_test, Y_test),
69 | # validation_sample_rate=5,
70 | # lr_scale='linear', save_dir='weights/weight_decay/weight_decay-%s/' % str(weight_decay),
71 | # verbose=True)
72 | #
73 | # # SETUP THE WEIGHT DECAY IN THE MODEL
74 | # model = MiniMobileNetV2((img_rows, img_cols, img_channels), alpha=1.4,
75 | # weight_decay=weight_decay, dropout=0,
76 | # weights=None, classes=nb_classes)
77 | # model.summary()
78 | #
79 | # # (the weight decay was already set in the model above)
80 | # # lr doesn't matter as it will be overwritten by the callback
81 | # optimizer = SGD(lr=0.002, momentum=0.9, nesterov=True)
82 | # model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
83 | #
84 | # # model.load_weights(weights_file)
85 | #
86 | # if not data_augmentation:
87 | # print('Not using data augmentation.')
88 | # model.fit(X_train, Y_train,
89 | # batch_size=batch_size,
90 | # epochs=nb_epoch,
91 | # validation_data=(X_test, Y_test),
92 | # shuffle=True,
93 | # verbose=1,
94 | # callbacks=[lr_finder])
95 | # else:
96 | # print('Using real-time data augmentation.')
97 | # # This will do preprocessing and realtime data augmentation:
98 | # datagen = ImageDataGenerator(
99 | # featurewise_center=False, # set input mean to 0 over the dataset
100 | # samplewise_center=False, # set each sample mean to 0
101 | # featurewise_std_normalization=False, # divide inputs by std of the dataset
102 | # samplewise_std_normalization=False, # divide each input by its std
103 | # zca_whitening=False, # apply ZCA whitening
104 | # rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
105 | # width_shift_range=0, # randomly shift images horizontally (fraction of total width)
106 | # height_shift_range=0, # randomly shift images vertically (fraction of total height)
107 | # horizontal_flip=True, # randomly flip images
108 | # vertical_flip=False) # randomly flip images
109 | #
110 | # # Compute quantities required for featurewise normalization
111 | # # (std, mean, and principal components if ZCA whitening is applied).
112 | # datagen.fit(X_train)
113 | #
114 | # # Fit the model on the batches generated by datagen.flow().
115 | # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
116 | # steps_per_epoch=X_train.shape[0] // batch_size,
117 | # validation_data=(X_test, Y_test),
118 | # epochs=nb_epoch, verbose=1,
119 | # callbacks=[lr_finder])
120 |
121 | # From the plot we see the model isn't impacted very much by the weight decay
122 | # value at all, so we can use any of them.
123 |
124 | for weight_decay in WEIGHT_DECAY_FACTORS:
125 | directory = 'weights/weight_decay/weight_decay-%s/' % str(weight_decay)
126 |
127 | losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5)
128 | plt.plot(lrs, losses, label='weight_decay=%0.7f' % weight_decay)
129 |
130 | plt.title("Weight Decay")
131 | plt.xlabel("Learning rate")
132 | plt.ylabel("Validation Loss")
133 | plt.legend()
134 | plt.show()
135 |
--------------------------------------------------------------------------------
/models/mobilenet/mobilenets.py:
--------------------------------------------------------------------------------
1 | """MobileNet v1 models for Keras.
2 | MobileNet is a general architecture and can be used for multiple use cases.
3 | Depending on the use case, it can use different input layer size and
4 | different width factors. This allows different width models to reduce
5 | the number of multiply-adds and thereby
6 | reduce inference cost on mobile devices.
7 | MobileNets support any input size greater than 32 x 32, with larger image sizes
8 | offering better performance.
9 | The number of parameters and number of multiply-adds
10 | can be modified by using the `alpha` parameter,
11 | which increases/decreases the number of filters in each layer.
12 | By altering the image size and `alpha` parameter,
13 | all 16 models from the paper can be built, with ImageNet weights provided.
14 | The paper demonstrates the performance of MobileNets using `alpha` values of
15 | 1.0 (also called 100 % MobileNet), 0.75, 0.5 and 0.25.
16 | For each of these `alpha` values, weights for 4 different input image sizes
17 | are provided (224, 192, 160, 128).
18 | The following table describes the size and accuracy of the 100% MobileNet
19 | on size 224 x 224:
20 | ----------------------------------------------------------------------------
21 | Width Multiplier (alpha) | ImageNet Acc | Multiply-Adds (M) | Params (M)
22 | ----------------------------------------------------------------------------
23 | | 1.0 MobileNet-224 | 70.6 % | 529 | 4.2 |
24 | | 0.75 MobileNet-224 | 68.4 % | 325 | 2.6 |
25 | | 0.50 MobileNet-224 | 63.7 % | 149 | 1.3 |
26 | | 0.25 MobileNet-224 | 50.6 % | 41 | 0.5 |
27 | ----------------------------------------------------------------------------
28 | The following table describes the performance of
29 | the 100 % MobileNet on various input sizes:
30 | ------------------------------------------------------------------------
31 | Resolution | ImageNet Acc | Multiply-Adds (M) | Params (M)
32 | ------------------------------------------------------------------------
33 | | 1.0 MobileNet-224 | 70.6 % | 529 | 4.2 |
34 | | 1.0 MobileNet-192 | 69.1 % | 529 | 4.2 |
35 | | 1.0 MobileNet-160 | 67.2 % | 529 | 4.2 |
36 | | 1.0 MobileNet-128 | 64.4 % | 529 | 4.2 |
37 | ------------------------------------------------------------------------
38 | The weights for all 16 models are obtained and translated
39 | from Tensorflow checkpoints found at
40 | https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md
41 | # Reference
42 | - [MobileNets: Efficient Convolutional Neural Networks for
43 |      Mobile Vision Applications](https://arxiv.org/pdf/1704.04861.pdf)
44 | """
45 | from __future__ import print_function
46 | from __future__ import absolute_import
47 | from __future__ import division
48 |
49 | import warnings
50 | import math
51 |
52 | from keras.models import Model
53 | from keras.layers import Input
54 | from keras.layers import Activation
55 | from keras.layers import Dropout
56 | from keras.layers import Reshape
57 | from keras.layers import BatchNormalization
58 | from keras.layers import GlobalAveragePooling2D
59 | from keras.layers import GlobalMaxPooling2D
60 | from keras.layers import Conv2D
61 | from keras.layers import DepthwiseConv2D
62 | from keras.layers import add
63 | from keras import initializers
64 | from keras import regularizers
65 | from keras import constraints
66 | from keras.utils import conv_utils
67 | from keras.utils.data_utils import get_file
68 | from keras.engine.topology import get_source_inputs
69 | from keras.engine import InputSpec
70 | # changed the next 3 lines from keras.applications to keras_applications for Keras version 2.2.4
71 | from keras_applications.imagenet_utils import _obtain_input_shape
72 | from keras_applications.inception_v3 import preprocess_input
73 | from keras_applications.imagenet_utils import decode_predictions
74 | from keras import backend as K
75 |
76 | import tensorflow as tf
77 |
78 | BASE_WEIGHT_PATH = ''
79 | BASE_WEIGHT_PATH_V2 = ''
80 |
81 |
82 | def relu6(x):
83 | return K.relu(x, max_value=6)
84 |
85 |
86 | def MiniMobileNetV2(input_shape=None,
87 | alpha=1.0,
88 | expansion_factor=6,
89 | depth_multiplier=1,
90 | dropout=0.,
91 | weight_decay=0.,
92 | include_top=True,
93 | weights=None,
94 | input_tensor=None,
95 | pooling=None,
96 | classes=10):
97 |     """Instantiates the MobileNet V2 architecture.
98 | MobileNet V2 is from the paper:
99 | - [Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation](https://arxiv.org/abs/1801.04381)
100 |
101 | Note that only TensorFlow is supported for now,
102 | therefore it only works with the data format
103 | `image_data_format='channels_last'` in your Keras config
104 | at `~/.keras/keras.json`.
105 | To load a MobileNet model via `load_model`, import the custom
106 | objects `relu6` and `DepthwiseConv2D` and pass them to the
107 | `custom_objects` parameter.
108 | E.g.
109 | model = load_model('mobilenet.h5', custom_objects={
110 | 'relu6': mobilenet.relu6,
111 | 'DepthwiseConv2D': mobilenet.DepthwiseConv2D})
112 | # Arguments
113 | input_shape: optional shape tuple, only to be specified
114 | if `include_top` is False (otherwise the input shape
115 | has to be `(224, 224, 3)` (with `channels_last` data format)
116 |             or `(3, 224, 224)` (with `channels_first` data format)).
117 |             It should have exactly 3 input channels,
118 | and width and height should be no smaller than 32.
119 | E.g. `(200, 200, 3)` would be one valid value.
120 | alpha: controls the width of the network.
121 | - If `alpha` < 1.0, proportionally decreases the number
122 | of filters in each layer.
123 | - If `alpha` > 1.0, proportionally increases the number
124 | of filters in each layer.
125 | - If `alpha` = 1, default number of filters from the paper
126 | are used at each layer.
127 | expansion_factor: controls the expansion of the internal bottleneck
128 | blocks. Should be a positive integer >= 1
129 | depth_multiplier: depth multiplier for depthwise convolution
130 | (also called the resolution multiplier)
131 | dropout: dropout rate
132 | weight_decay: Weight decay factor.
133 | include_top: whether to include the fully-connected
134 | layer at the top of the network.
135 | weights: `None` (random initialization) or
136 | `imagenet` (ImageNet weights)
137 | input_tensor: optional Keras tensor (i.e. output of
138 | `layers.Input()`)
139 | to use as image input for the model.
140 | pooling: Optional pooling mode for feature extraction
141 | when `include_top` is `False`.
142 | - `None` means that the output of the model
143 | will be the 4D tensor output of the
144 | last convolutional layer.
145 | - `avg` means that global average pooling
146 | will be applied to the output of the
147 | last convolutional layer, and thus
148 | the output of the model will be a
149 | 2D tensor.
150 | - `max` means that global max pooling will
151 | be applied.
152 | classes: optional number of classes to classify images
153 | into, only to be specified if `include_top` is True, and
154 | if no `weights` argument is specified.
155 | # Returns
156 | A Keras model instance.
157 | # Raises
158 | ValueError: in case of invalid argument for `weights`,
159 | or invalid input shape.
160 | RuntimeError: If attempting to run this model with a
161 | backend that does not support separable convolutions.
162 | """
163 |
164 | if K.backend() != 'tensorflow':
165 | raise RuntimeError('Only Tensorflow backend is currently supported, '
166 | 'as other backends do not support '
167 | 'depthwise convolution.')
168 |
169 | if weights not in {'imagenet', None}:
170 | raise ValueError('The `weights` argument should be either '
171 | '`None` (random initialization) or `imagenet` '
172 | '(pre-training on ImageNet).')
173 |
174 | if weights == 'imagenet' and include_top and classes != 1000:
175 | raise ValueError('If using `weights` as ImageNet with `include_top` '
176 | 'as true, `classes` should be 1000')
177 |
178 | # Determine proper input shape and default size.
179 | if input_shape is None:
180 | default_size = 224
181 | else:
182 | if K.image_data_format() == 'channels_first':
183 | rows = input_shape[1]
184 | cols = input_shape[2]
185 | else:
186 | rows = input_shape[0]
187 | cols = input_shape[1]
188 |
189 | if rows == cols and rows in [96, 128, 160, 192, 224]:
190 | default_size = rows
191 | else:
192 | default_size = 224
193 |
194 | input_shape = _obtain_input_shape(input_shape,
195 | default_size=default_size,
196 | min_size=32,
197 | data_format=K.image_data_format(),
198 | require_flatten=include_top or weights)
199 | if K.image_data_format() == 'channels_last':
200 | row_axis, col_axis = (0, 1)
201 | else:
202 | row_axis, col_axis = (1, 2)
203 | rows = input_shape[row_axis]
204 | cols = input_shape[col_axis]
205 |
206 | if weights == 'imagenet':
207 | if depth_multiplier != 1:
208 | raise ValueError('If imagenet weights are being loaded, '
209 | 'depth multiplier must be 1')
210 |
211 | if alpha not in [0.35, 0.50, 0.75, 1.0, 1.3, 1.4]:
212 | raise ValueError('If imagenet weights are being loaded, '
213 | 'alpha can be one of'
214 | '`0.35`, `0.50`, `0.75`, `1.0`, `1.3` and `1.4` only.')
215 |
216 | if rows != cols or rows not in [96, 128, 160, 192, 224]:
217 | raise ValueError('If imagenet weights are being loaded, '
218 | 'input must have a static square shape (one of '
219 |                              '(96, 96), (128, 128), (160, 160), (192, 192), or '
220 |                              '(224, 224)). Input shape provided = %s' % (input_shape,))
221 |
222 | if K.image_data_format() != 'channels_last':
223 | warnings.warn('The MobileNet family of models is only available '
224 | 'for the input data format "channels_last" '
225 | '(width, height, channels). '
226 | 'However your settings specify the default '
227 | 'data format "channels_first" (channels, width, height).'
228 | ' You should set `image_data_format="channels_last"` '
229 | 'in your Keras config located at ~/.keras/keras.json. '
230 | 'The model being returned right now will expect inputs '
231 | 'to follow the "channels_last" data format.')
232 | K.set_image_data_format('channels_last')
233 | old_data_format = 'channels_first'
234 | else:
235 | old_data_format = None
236 |
237 | if input_tensor is None:
238 | img_input = Input(shape=input_shape)
239 | else:
240 | if not K.is_keras_tensor(input_tensor):
241 | img_input = Input(tensor=input_tensor, shape=input_shape)
242 | else:
243 | img_input = input_tensor
244 |
245 | x = _conv_block(img_input, 32, alpha, bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay)
246 | x = _depthwise_conv_block_v2(x, 16, alpha, 1, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99,
247 | weight_decay=weight_decay, block_id=1)
248 |
249 | x = _depthwise_conv_block_v2(x, 24, alpha, expansion_factor, depth_multiplier, block_id=2,
250 | bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, strides=(2, 2))
251 | x = _depthwise_conv_block_v2(x, 24, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99,
252 | weight_decay=weight_decay, block_id=3)
253 |
254 | x = _depthwise_conv_block_v2(x, 32, alpha, expansion_factor, depth_multiplier, block_id=4,
255 | bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay)
256 | x = _depthwise_conv_block_v2(x, 32, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99,
257 | weight_decay=weight_decay, block_id=5)
258 | x = _depthwise_conv_block_v2(x, 32, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99,
259 | weight_decay=weight_decay, block_id=6)
260 |
261 | x = _depthwise_conv_block_v2(x, 64, alpha, expansion_factor, depth_multiplier, block_id=7,
262 | bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, strides=(2, 2))
263 | x = _depthwise_conv_block_v2(x, 64, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99,
264 | weight_decay=weight_decay, block_id=8)
265 | x = _depthwise_conv_block_v2(x, 64, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99,
266 | weight_decay=weight_decay, block_id=9)
267 | x = _depthwise_conv_block_v2(x, 64, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99,
268 | weight_decay=weight_decay, block_id=10)
269 |
270 | if alpha <= 1.0:
271 | penultimate_filters = 1280
272 | else:
273 | penultimate_filters = int(1280 * alpha)
274 |
275 | x = _conv_block(x, penultimate_filters, alpha=1.0, kernel=(1, 1), bn_epsilon=1e-3, bn_momentum=0.99,
276 |                     weight_decay=weight_decay, block_id=18)
277 |
278 | if include_top:
279 | if K.image_data_format() == 'channels_first':
280 | shape = (penultimate_filters, 1, 1)
281 | else:
282 | shape = (1, 1, penultimate_filters)
283 |
284 | x = GlobalAveragePooling2D()(x)
285 | x = Reshape(shape, name='reshape_1')(x)
286 | x = Dropout(dropout, name='dropout')(x)
287 | x = Conv2D(classes, (1, 1),
288 | kernel_initializer=initializers.he_normal(),
289 | kernel_regularizer=regularizers.l2(weight_decay),
290 | padding='same', name='conv_preds')(x)
291 | x = Activation('softmax', name='act_softmax')(x)
292 | x = Reshape((classes,), name='reshape_2')(x)
293 | else:
294 | if pooling == 'avg':
295 | x = GlobalAveragePooling2D()(x)
296 | elif pooling == 'max':
297 | x = GlobalMaxPooling2D()(x)
298 |
299 | # Ensure that the model takes into account
300 | # any potential predecessors of `input_tensor`.
301 | if input_tensor is not None:
302 | inputs = get_source_inputs(input_tensor)
303 | else:
304 | inputs = img_input
305 |
306 | # Create model.
307 | model = Model(inputs, x, name='mobilenetV2_%0.2f_%s' % (alpha, rows))
308 |
309 | # load weights
310 | if weights == 'imagenet':
311 | if K.image_data_format() == 'channels_first':
312 |             raise ValueError('Weights for "channels_first" format '
313 | 'are not available.')
314 | if alpha == 1.0:
315 | alpha_text = '1_0'
316 | elif alpha == 1.3:
317 | alpha_text = '1_3'
318 | elif alpha == 1.4:
319 | alpha_text = '1_4'
320 | elif alpha == 0.75:
321 | alpha_text = '7_5'
322 | elif alpha == 0.50:
323 | alpha_text = '5_0'
324 | else:
325 | alpha_text = '3_5'
326 |
327 | if include_top:
328 | model_name = 'mobilenet_v2_%s_%d_tf.h5' % (alpha_text, rows)
329 |             weight_path = BASE_WEIGHT_PATH_V2 + model_name
330 |             weights_path = get_file(model_name,
331 |                                     weight_path,
332 |                                     cache_subdir='models')
333 | else:
334 | model_name = 'mobilenet_v2_%s_%d_tf_no_top.h5' % (alpha_text, rows)
335 |             weight_path = BASE_WEIGHT_PATH_V2 + model_name
336 |             weights_path = get_file(model_name,
337 |                                     weight_path,
338 |                                     cache_subdir='models')
339 | model.load_weights(weights_path)
340 |
341 | if old_data_format:
342 | K.set_image_data_format(old_data_format)
343 | return model
344 |
345 |
346 | # taken from https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/conv_blocks.py
347 | def _make_divisible(v, divisor=8, min_value=8):
348 | if min_value is None:
349 | min_value = divisor
350 |
351 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
352 | # Make sure that round down does not go down by more than 10%.
353 | if new_v < 0.9 * v:
354 | new_v += divisor
355 | return new_v
356 |
357 |
358 | def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1), bn_epsilon=1e-3,
359 | bn_momentum=0.99, weight_decay=0., block_id=1):
360 | """Adds an initial convolution layer (with batch normalization and relu6).
361 | # Arguments
362 | inputs: Input tensor of shape `(rows, cols, 3)`
363 | (with `channels_last` data format) or
364 | (3, rows, cols) (with `channels_first` data format).
365 |             It should have exactly 3 input channels,
366 | and width and height should be no smaller than 32.
367 | E.g. `(224, 224, 3)` would be one valid value.
368 | filters: Integer, the dimensionality of the output space
369 | (i.e. the number output of filters in the convolution).
370 | alpha: controls the width of the network.
371 | - If `alpha` < 1.0, proportionally decreases the number
372 | of filters in each layer.
373 | - If `alpha` > 1.0, proportionally increases the number
374 | of filters in each layer.
375 | - If `alpha` = 1, default number of filters from the paper
376 | are used at each layer.
377 | kernel: An integer or tuple/list of 2 integers, specifying the
378 | width and height of the 2D convolution window.
379 | Can be a single integer to specify the same value for
380 | all spatial dimensions.
381 | strides: An integer or tuple/list of 2 integers,
382 | specifying the strides of the convolution along the width and height.
383 | Can be a single integer to specify the same value for
384 | all spatial dimensions.
385 | Specifying any stride value != 1 is incompatible with specifying
386 | any `dilation_rate` value != 1.
387 | bn_epsilon: Epsilon value for BatchNormalization
388 | bn_momentum: Momentum value for BatchNormalization
389 | # Input shape
390 | 4D tensor with shape:
391 | `(samples, channels, rows, cols)` if data_format='channels_first'
392 | or 4D tensor with shape:
393 | `(samples, rows, cols, channels)` if data_format='channels_last'.
394 | # Output shape
395 | 4D tensor with shape:
396 | `(samples, filters, new_rows, new_cols)` if data_format='channels_first'
397 | or 4D tensor with shape:
398 | `(samples, new_rows, new_cols, filters)` if data_format='channels_last'.
399 | `rows` and `cols` values might have changed due to stride.
400 | # Returns
401 | Output tensor of block.
402 | """
403 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
404 | filters = filters * alpha
405 | filters = _make_divisible(filters)
406 | x = Conv2D(filters, kernel,
407 | padding='same',
408 | use_bias=False,
409 | strides=strides,
410 | kernel_initializer=initializers.he_normal(),
411 | kernel_regularizer=regularizers.l2(weight_decay),
412 | name='conv%d' % block_id)(inputs)
413 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon,
414 | name='conv%d_bn' % block_id)(x)
415 | return Activation(relu6, name='conv%d_relu' % block_id)(x)
416 |
417 |
418 | def _depthwise_conv_block_v2(inputs, pointwise_conv_filters, alpha, expansion_factor,
419 | depth_multiplier=1, strides=(1, 1), bn_epsilon=1e-3,
420 | bn_momentum=0.99, weight_decay=0.0, block_id=1):
421 | """Adds a depthwise convolution block V2.
422 | A depthwise convolution V2 block consists of a depthwise conv,
423 | batch normalization, relu6, pointwise convolution,
424 | batch normalization and relu6 activation.
425 | # Arguments
426 | inputs: Input tensor of shape `(rows, cols, channels)`
427 | (with `channels_last` data format) or
428 | (channels, rows, cols) (with `channels_first` data format).
429 | pointwise_conv_filters: Integer, the dimensionality of the output space
430 | (i.e. the number output of filters in the pointwise convolution).
431 | alpha: controls the width of the network.
432 | - If `alpha` < 1.0, proportionally decreases the number
433 | of filters in each layer.
434 | - If `alpha` > 1.0, proportionally increases the number
435 | of filters in each layer.
436 | - If `alpha` = 1, default number of filters from the paper
437 | are used at each layer.
438 | expansion_factor: controls the expansion of the internal bottleneck
439 | blocks. Should be a positive integer >= 1
440 | depth_multiplier: The number of depthwise convolution output channels
441 | for each input channel.
442 | The total number of depthwise convolution output
443 | channels will be equal to `filters_in * depth_multiplier`.
444 | strides: An integer or tuple/list of 2 integers,
445 | specifying the strides of the convolution along the width and height.
446 | Can be a single integer to specify the same value for
447 | all spatial dimensions.
448 | Specifying any stride value != 1 is incompatible with specifying
449 | any `dilation_rate` value != 1.
450 | bn_epsilon: Epsilon value for BatchNormalization
451 | bn_momentum: Momentum value for BatchNormalization
452 | block_id: Integer, a unique identification designating the block number.
453 | # Input shape
454 | 4D tensor with shape:
455 | `(batch, channels, rows, cols)` if data_format='channels_first'
456 | or 4D tensor with shape:
457 | `(batch, rows, cols, channels)` if data_format='channels_last'.
458 | # Output shape
459 | 4D tensor with shape:
460 | `(batch, filters, new_rows, new_cols)` if data_format='channels_first'
461 | or 4D tensor with shape:
462 | `(batch, new_rows, new_cols, filters)` if data_format='channels_last'.
463 | `rows` and `cols` values might have changed due to stride.
464 | # Returns
465 | Output tensor of block.
466 | """
467 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
468 | input_shape = K.int_shape(inputs)
469 | depthwise_conv_filters = _make_divisible(input_shape[channel_axis] * expansion_factor)
470 | pointwise_conv_filters = _make_divisible(pointwise_conv_filters * alpha)
471 |
472 | if depthwise_conv_filters > input_shape[channel_axis]:
473 | x = Conv2D(depthwise_conv_filters, (1, 1),
474 | padding='same',
475 | use_bias=False,
476 | strides=(1, 1),
477 | kernel_initializer=initializers.he_normal(),
478 | kernel_regularizer=regularizers.l2(weight_decay),
479 | name='conv_expand_%d' % block_id)(inputs)
480 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon,
481 | name='conv_expand_%d_bn' % block_id)(x)
482 | x = Activation(relu6, name='conv_expand_%d_relu' % block_id)(x)
483 | else:
484 | x = inputs
485 |
486 | x = DepthwiseConv2D((3, 3),
487 | padding='same',
488 | depth_multiplier=depth_multiplier,
489 | strides=strides,
490 | use_bias=False,
491 | depthwise_initializer=initializers.he_normal(),
492 | depthwise_regularizer=regularizers.l2(weight_decay),
493 | name='conv_dw_%d' % block_id)(x)
494 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon,
495 | name='conv_dw_%d_bn' % block_id)(x)
496 | x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
497 |
498 | x = Conv2D(pointwise_conv_filters, (1, 1),
499 | padding='same',
500 | use_bias=False,
501 | strides=(1, 1),
502 | kernel_initializer=initializers.he_normal(),
503 | kernel_regularizer=regularizers.l2(weight_decay),
504 | name='conv_pw_%d' % block_id)(x)
505 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon,
506 | name='conv_pw_%d_bn' % block_id)(x)
507 |
508 |     # Residual connection only when the block preserves the spatial size
509 |     # and the channel count matches the input.
510 |     if strides == (2, 2):
511 |         return x
512 | 
513 |     if input_shape[channel_axis] == pointwise_conv_filters:
514 |         x = add([inputs, x])
515 |     return x
516 |
517 |
518 | if __name__ == '__main__':
519 | import tensorflow as tf
520 | from keras import backend as K
521 |
522 | run_metadata = tf.RunMetadata()
523 |
524 | with tf.Session(graph=tf.Graph()) as sess:
525 | K.set_session(sess)
526 |
527 | model = MiniMobileNetV2(input_tensor=tf.placeholder('float32', shape=(1, 224, 224, 3)), alpha=1.0)
528 | opt = tf.profiler.ProfileOptionBuilder.float_operation()
529 | flops = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt)
530 |
531 | opt = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
532 | param_count = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt)
533 |
534 | print('flops:', flops.total_float_ops)
535 | print('param count:', param_count.total_parameters)
536 |
537 | model.summary()
538 |
--------------------------------------------------------------------------------
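A minimal serialization sketch for the model above (the file name here is
hypothetical): since the network uses the custom `relu6` activation, it has to
be registered via `custom_objects` when reloading a full saved model, as the
docstring of `MiniMobileNetV2` notes. With Keras 2.2.4, `DepthwiseConv2D` is a
built-in layer, so only `relu6` should need registering.

    from keras.models import load_model
    from models.mobilenet.mobilenets import MiniMobileNetV2, relu6

    model = MiniMobileNetV2(input_shape=(32, 32, 3), alpha=1.4,
                            weights=None, classes=10)
    model.save('mobilenet_v2_full.h5')  # hypothetical path for this example

    # Reload, supplying the custom activation used inside the network.
    restored = load_model('mobilenet_v2_full.h5',
                          custom_objects={'relu6': relu6})

--------------------------------------------------------------------------------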
/models/mobilenet/train_cifar_10.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from keras example cifar10_cnn.py
3 | Train MiniMobileNetV2 on the CIFAR10 small images dataset.
4 | """
5 | from __future__ import print_function
6 | import os
7 |
8 | from keras.datasets import cifar10
9 | from keras.preprocessing.image import ImageDataGenerator
10 | from keras.utils import np_utils
11 | from keras.callbacks import ModelCheckpoint
12 | from keras.optimizers import SGD
13 | import numpy as np
14 |
15 | from clr import OneCycleLR
16 | from models.mobilenet.mobilenets import MiniMobileNetV2
17 |
18 | if not os.path.exists('weights/'):
19 | os.makedirs('weights/')
20 |
21 | weights_file = 'weights/mobilenet_v2.h5'
22 | model_checkpoint = ModelCheckpoint(
23 | weights_file,
24 | monitor='val_acc',
25 | save_best_only=True,
26 | save_weights_only=True,
27 | mode='max')
28 | batch_size = 128
29 | nb_classes = 10
30 | nb_epoch = 100  # Full training run (the lr finder scripts use 1)
31 | data_augmentation = True
32 |
33 | # input image dimensions
34 | img_rows, img_cols = 32, 32
35 | # The CIFAR10 images are RGB.
36 | img_channels = 3
37 |
38 | # The data, shuffled and split between train and test sets:
39 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
40 |
41 | # Convert class vectors to binary class matrices.
42 | Y_train = np_utils.to_categorical(y_train, nb_classes)
43 | Y_test = np_utils.to_categorical(y_test, nb_classes)
44 |
45 | X_train = X_train.astype('float32')
46 | X_test = X_test.astype('float32')
47 |
48 | # preprocess input
49 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
50 | std = np.std(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
51 |
52 | print("Channel Mean : ", mean)
53 | print("Channel Std : ", std)
54 |
55 | X_train = (X_train - mean) / (std)
56 | X_test = (X_test - mean) / (std)
57 |
58 | # Learning rate finder callback setup
59 | num_samples = X_train.shape[0]
60 |
61 | lr_manager = OneCycleLR(max_lr=0.02, maximum_momentum=0.9, verbose=True)
62 |
63 | # Build the MiniMobileNetV2 model for CIFAR-10
64 | model = MiniMobileNetV2((img_rows, img_cols, img_channels),
65 | alpha=1.4,
66 | weight_decay=1e-6,
67 | weights=None,
68 | classes=nb_classes)
69 | model.summary()
70 |
71 | # These values will be overridden by the above callback
72 | optimizer = SGD(lr=0.002, momentum=0.9, nesterov=True)
73 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
74 |
75 | model.load_weights(weights_file)  # resume from an earlier checkpoint
76 |
77 | if not data_augmentation:
78 | print('Not using data augmentation.')
79 | model.fit(
80 | X_train,
81 | Y_train,
82 | batch_size=batch_size,
83 | epochs=nb_epoch,
84 | validation_data=(X_test, Y_test),
85 | shuffle=True,
86 | verbose=1,
87 | callbacks=[lr_manager, model_checkpoint])
88 | else:
89 | print('Using real-time data augmentation.')
90 | # This will do preprocessing and realtime data augmentation:
91 | datagen = ImageDataGenerator(
92 | featurewise_center=False, # set input mean to 0 over the dataset
93 | samplewise_center=False, # set each sample mean to 0
94 | featurewise_std_normalization=False, # divide inputs by std of the dataset
95 | samplewise_std_normalization=False, # divide each input by its std
96 | zca_whitening=False, # apply ZCA whitening
97 | # randomly rotate images in the range (degrees, 0 to 180)
98 | rotation_range=0,
99 | # randomly shift images horizontally (fraction of total width)
100 | width_shift_range=0,
101 | # randomly shift images vertically (fraction of total height)
102 | height_shift_range=0,
103 | horizontal_flip=True, # randomly flip images
104 | vertical_flip=False) # randomly flip images
105 |
106 | # Compute quantities required for featurewise normalization
107 | # (std, mean, and principal components if ZCA whitening is applied).
108 | datagen.fit(X_train)
109 |
110 | # Fit the model on the batches generated by datagen.flow().
111 | model.fit_generator(
112 | datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
113 | steps_per_epoch=X_train.shape[0] // batch_size,
114 | validation_data=(X_test, Y_test),
115 | epochs=nb_epoch,
116 | verbose=1,
117 | callbacks=[lr_manager, model_checkpoint])
118 |
119 | scores = model.evaluate(X_test, Y_test, batch_size=batch_size)
120 | for score, metric_name in zip(scores, model.metrics_names):
121 | print("%s : %0.4f" % (metric_name, score))
122 |
--------------------------------------------------------------------------------
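For intuition about what the `OneCycleLR` callback above does to the learning
rate, here is a standalone sketch of the one-cycle shape from Smith's paper:
a linear warm-up to `max_lr`, a symmetric ramp back down, then a short
annealing tail. The `end_fraction` and floor values are illustrative
assumptions, not the callback's exact internals.

    import numpy as np
    import matplotlib.pyplot as plt

    max_lr = 0.02                          # as passed to OneCycleLR above
    num_iterations = 100 * (50000 // 128)  # epochs * steps per epoch
    end_fraction = 0.1                     # assumed annealing fraction

    up = int(num_iterations * (1. - end_fraction) / 2)
    tail = num_iterations - 2 * up
    lr = np.concatenate([
        np.linspace(max_lr / 10., max_lr, up),             # ramp up
        np.linspace(max_lr, max_lr / 10., up),             # ramp down
        np.linspace(max_lr / 10., max_lr / 1000., tail),   # final annealing
    ])
    plt.plot(lr)
    plt.xlabel('Iteration')
    plt.ylabel('Learning rate')
    plt.show()

--------------------------------------------------------------------------------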
/models/mobilenet/weights/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/losses.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/lrs.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/mobilenet_v2 - 9033.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/mobilenet_v2 - 9033.h5
--------------------------------------------------------------------------------
/models/mobilenet/weights/mobilenet_v2.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/mobilenet_v2.h5
--------------------------------------------------------------------------------
/models/mobilenet/weights/momentum/momentum-0.9/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.9/losses.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/momentum/momentum-0.9/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.9/lrs.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/momentum/momentum-0.95/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.95/losses.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/momentum/momentum-0.95/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.95/lrs.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/momentum/momentum-0.99/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.99/losses.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/momentum/momentum-0.99/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.99/lrs.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-1e-05/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-05/losses.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-1e-05/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-05/lrs.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-1e-06/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-06/losses.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-1e-06/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-06/lrs.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-1e-07/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-07/losses.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-1e-07/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-07/lrs.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-3e-05/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-3e-05/losses.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-3e-05/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-3e-05/lrs.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-3e-06/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-3e-06/losses.npy
--------------------------------------------------------------------------------
/models/mobilenet/weights/weight_decay/weight_decay-3e-06/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-3e-06/lrs.npy
--------------------------------------------------------------------------------
/models/small/find_lr_schedule.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from keras example cifar10_cnn.py
3 | Find the learning rate range for MiniVGG on the CIFAR10 small images dataset.
4 | """
5 | from __future__ import print_function
6 | import os
7 |
8 | from keras.datasets import cifar10
9 | from keras.preprocessing.image import ImageDataGenerator
10 | from keras.utils import np_utils
11 | from keras.callbacks import ModelCheckpoint
12 | from keras.optimizers import SGD
13 | import numpy as np
14 |
15 | from clr import LRFinder
16 | from models.small.model import MiniVGG
17 |
18 | if not os.path.exists('weights/'):
19 | os.makedirs('weights/')
20 |
21 | weights_file = 'weights/small_vgg_v2_schedule.h5'
22 | model_checkpoint = ModelCheckpoint(weights_file, monitor='val_acc', save_best_only=True,
23 | save_weights_only=True, mode='max')
24 |
25 | batch_size = 128
26 | nb_classes = 10
27 | nb_epoch = 1 # Only finding lr
28 | data_augmentation = True
29 |
30 | # input image dimensions
31 | img_rows, img_cols = 32, 32
32 | # The CIFAR10 images are RGB.
33 | img_channels = 3
34 |
35 | # The data, shuffled and split between train and test sets:
36 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
37 |
38 | # Convert class vectors to binary class matrices.
39 | Y_train = np_utils.to_categorical(y_train, nb_classes)
40 | Y_test = np_utils.to_categorical(y_test, nb_classes)
41 |
42 | X_train = X_train.astype('float32')
43 | X_test = X_test.astype('float32')
44 |
45 | # preprocess input
46 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
47 | std = np.std(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
48 |
49 | print("Channel Mean : ", mean)
50 | print("Channel Std : ", std)
51 |
52 | X_train = (X_train - mean) / (std)
53 | X_test = (X_test - mean) / (std)
54 |
55 | # Learning rate finder callback setup
56 | num_samples = X_train.shape[0]
57 |
58 | # Exponential lr finder
59 | # USE THIS FOR A LARGE RANGE SEARCH
60 | # Using the validation data reduces speed but gives a better idea of the learning rate
61 | lr_finder = LRFinder(num_samples, batch_size, minimum_lr=1e-5, maximum_lr=10.,
62 | lr_scale='exp',
63 | validation_data=(X_test, Y_test), # use the validation data for losses
64 | validation_sample_rate=5,
65 | save_dir='weights/', verbose=True)
66 |
67 | # Linear lr finder
68 | # USE THIS FOR A CLOSE SEARCH
69 | # Uncomment the validation_data flag to reduce speed but get a better idea of the learning rate
70 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=5e-4, maximum_lr=1e-2,
71 | # lr_scale='linear',
72 | # validation_data=(X_test, y_test), # use the validation data for losses
73 | # validation_sample_rate=5,
74 | # save_dir='weights/', verbose=True)
75 |
76 | # plot the previous values if present
77 | LRFinder.plot_schedule_from_file('weights/', clip_beginning=10, clip_endding=5)
78 |
79 | # Build the small model whose learning rate range we want to find
80 |
81 | model = MiniVGG((img_rows, img_cols, img_channels),
82 | dropout=0, weights=None, classes=nb_classes)
83 | model.summary()
84 |
85 | optimizer = SGD(lr=0.1, momentum=0.9, nesterov=True)
86 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
87 |
88 | # model.load_weights(weights_file)
89 |
90 | if not data_augmentation:
91 | print('Not using data augmentation.')
92 | model.fit(X_train, Y_train,
93 | batch_size=batch_size,
94 | epochs=nb_epoch,
95 | validation_data=(X_test, Y_test),
96 | shuffle=True,
97 | verbose=1,
98 | callbacks=[lr_finder, model_checkpoint])
99 | else:
100 | print('Using real-time data augmentation.')
101 | # This will do preprocessing and realtime data augmentation:
102 | datagen = ImageDataGenerator(
103 | featurewise_center=False, # set input mean to 0 over the dataset
104 | samplewise_center=False, # set each sample mean to 0
105 | featurewise_std_normalization=False, # divide inputs by std of the dataset
106 | samplewise_std_normalization=False, # divide each input by its std
107 | zca_whitening=False, # apply ZCA whitening
108 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
109 | width_shift_range=0, # randomly shift images horizontally (fraction of total width)
110 | height_shift_range=0, # randomly shift images vertically (fraction of total height)
111 | horizontal_flip=True, # randomly flip images
112 | vertical_flip=False) # randomly flip images
113 |
114 | # Compute quantities required for featurewise normalization
115 | # (std, mean, and principal components if ZCA whitening is applied).
116 | datagen.fit(X_train)
117 |
118 | # Fit the model on the batches generated by datagen.flow().
119 | model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
120 | steps_per_epoch=X_train.shape[0] // batch_size,
121 | validation_data=(X_test, Y_test),
122 | epochs=nb_epoch, verbose=1,
123 | callbacks=[lr_finder, model_checkpoint])
124 |
125 | lr_finder.plot_schedule(clip_beginning=10, clip_endding=5)
126 |
127 | scores = model.evaluate(X_test, Y_test, batch_size=batch_size)
128 | for score, metric_name in zip(scores, model.metrics_names):
129 | print("%s : %0.4f" % (metric_name, score))
130 |
--------------------------------------------------------------------------------
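Once the run above has saved `losses.npy` and `lrs.npy` under `weights/`, a
rough heuristic sketch for turning the curve into a one-cycle range (the
factor-of-10 rule matches the NOTE in the momentum/weight-decay scripts below;
treat the loss-minimum rule as a rule of thumb, not part of the library):

    import numpy as np

    losses = np.load('weights/losses.npy')
    lrs = np.load('weights/lrs.npy')

    # Clip the noisy edges, mirroring clip_beginning=10 / clip_endding=5 above.
    losses, lrs = losses[10:-5], lrs[10:-5]

    max_lr = lrs[np.argmin(losses)]  # lr where the validation loss bottomed out
    min_lr = max_lr / 10.            # minimum 10x smaller than the max
    print('suggested one-cycle range: %0.5f .. %0.5f' % (min_lr, max_lr))

--------------------------------------------------------------------------------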
/models/small/find_momentum_schedule.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from keras example cifar10_cnn.py
3 | Find a good momentum value for MiniVGG on the CIFAR10 small images dataset.
4 | """
5 | from __future__ import print_function
6 | import os
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | from keras.datasets import cifar10
11 | from keras.preprocessing.image import ImageDataGenerator
12 | from keras.utils import np_utils
13 | from keras.callbacks import ModelCheckpoint
14 | from keras.optimizers import SGD
15 | from keras import backend as K
16 |
17 | from clr import LRFinder
18 | from models.small.model import MiniVGG
19 |
20 | plt.style.use('seaborn-white')
21 |
22 | batch_size = 128
23 | nb_classes = 10
24 | nb_epoch = 1 # Only finding lr
25 | data_augmentation = True
26 |
27 | # input image dimensions
28 | img_rows, img_cols = 32, 32
29 | # The CIFAR10 images are RGB.
30 | img_channels = 3
31 |
32 | # The data, shuffled and split between train and test sets:
33 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
34 |
35 | # Convert class vectors to binary class matrices.
36 | Y_train = np_utils.to_categorical(y_train, nb_classes)
37 | Y_test = np_utils.to_categorical(y_test, nb_classes)
38 |
39 | X_train = X_train.astype('float32')
40 | X_test = X_test.astype('float32')
41 |
42 | # preprocess input
43 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
44 | std = np.std(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
45 |
46 | print("Channel Mean : ", mean)
47 | print("Channel Std : ", std)
48 |
49 | X_train = (X_train - mean) / (std)
50 | X_test = (X_test - mean) / (std)
51 |
52 | # Learning rate finder callback setup
53 | num_samples = X_train.shape[0]
54 |
55 | MOMENTUMS = [0.9, 0.95, 0.99]
56 |
57 | # for momentum in MOMENTUMS:
58 | # K.clear_session()
59 | #
60 | # # Learning rate range obtained from `find_lr_schedule.py`
61 | # # NOTE : Minimum is 10x smaller than the max found above !
62 | # # NOTE : It is preferable to use the validation data here to get a correct value
63 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=0.00125, maximum_lr=0.0125,
64 | # validation_data=(X_test, Y_test),
65 | # validation_sample_rate=5,
66 | # lr_scale='linear', save_dir='weights/momentum/momentum-%s/' % str(momentum),
67 | # verbose=True)
68 | #
69 | # model = MiniVGG((img_rows, img_cols, img_channels),
70 | # dropout=0, weights=None, classes=nb_classes)
71 | # model.summary()
72 | #
73 | #     # set the momentum here !
74 | #     # lr doesn't matter as it will be overwritten by the callback
75 | # optimizer = SGD(lr=0.00125, momentum=momentum, nesterov=True)
76 | # model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
77 | #
78 | # # model.load_weights(weights_file)
79 | #
80 | # if not data_augmentation:
81 | # print('Not using data augmentation.')
82 | # model.fit(X_train, Y_train,
83 | # batch_size=batch_size,
84 | # epochs=nb_epoch,
85 | # validation_data=(X_test, Y_test),
86 | # shuffle=True,
87 | # verbose=1,
88 | # callbacks=[lr_finder])
89 | # else:
90 | # print('Using real-time data augmentation.')
91 | # # This will do preprocessing and realtime data augmentation:
92 | # datagen = ImageDataGenerator(
93 | # featurewise_center=False, # set input mean to 0 over the dataset
94 | # samplewise_center=False, # set each sample mean to 0
95 | # featurewise_std_normalization=False, # divide inputs by std of the dataset
96 | # samplewise_std_normalization=False, # divide each input by its std
97 | # zca_whitening=False, # apply ZCA whitening
98 | # rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
99 | # width_shift_range=0, # randomly shift images horizontally (fraction of total width)
100 | # height_shift_range=0, # randomly shift images vertically (fraction of total height)
101 | # horizontal_flip=True, # randomly flip images
102 | # vertical_flip=False) # randomly flip images
103 | #
104 | # # Compute quantities required for featurewise normalization
105 | # # (std, mean, and principal components if ZCA whitening is applied).
106 | # datagen.fit(X_train)
107 | #
108 | # # Fit the model on the batches generated by datagen.flow().
109 | # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
110 | # steps_per_epoch=X_train.shape[0] // batch_size,
111 | # validation_data=(X_test, Y_test),
112 | # epochs=nb_epoch, verbose=1,
113 | # callbacks=[lr_finder])
114 |
115 | # From the plot we see that the model isn't impacted much by the momentum value,
116 | # so we can use any of them.
117 |
118 | for momentum in MOMENTUMS:
119 | directory = 'weights/momentum/momentum-%s/' % str(momentum)
120 |
121 | losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5)
122 | plt.plot(lrs, losses, label='momentum=%0.2f' % momentum)
123 |
124 | plt.title("Momentum")
125 | plt.xlabel("Learning rate")
126 | plt.ylabel("Validation Loss")
127 | plt.legend()
128 | plt.show()
129 |
--------------------------------------------------------------------------------
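Whichever momentum gives the lowest, flattest curve then feeds straight into
the one-cycle callback; a sketch mirroring `train_cifar_10.py` (0.9 is simply
the value the training scripts use):

    from keras.optimizers import SGD
    from clr import OneCycleLR

    chosen_momentum = 0.9  # whichever curve above was best

    lr_manager = OneCycleLR(max_lr=0.02, maximum_momentum=chosen_momentum,
                            verbose=True)
    optimizer = SGD(lr=0.002, momentum=chosen_momentum, nesterov=True)
    # compile and fit as in train_cifar_10.py, passing callbacks=[lr_manager]

--------------------------------------------------------------------------------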
/models/small/find_weight_decay_schedule.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from keras example cifar10_cnn.py
3 | Find a good weight decay factor for MiniVGG on the CIFAR10 small images dataset.
4 | """
5 | from __future__ import print_function
6 | import os
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | from keras.datasets import cifar10
11 | from keras.preprocessing.image import ImageDataGenerator
12 | from keras.utils import np_utils
13 | from keras.callbacks import ModelCheckpoint
14 | from keras.optimizers import SGD
15 | from keras import backend as K
16 |
17 | from clr import LRFinder
18 | from models.small.model import MiniVGG
19 |
20 | plt.style.use('seaborn-white')
21 |
22 | batch_size = 128
23 | nb_classes = 10
24 | nb_epoch = 1 # Only finding lr
25 | data_augmentation = True
26 |
27 | # input image dimensions
28 | img_rows, img_cols = 32, 32
29 | # The CIFAR10 images are RGB.
30 | img_channels = 3
31 |
32 | # The data, shuffled and split between train and test sets:
33 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
34 |
35 | # Convert class vectors to binary class matrices.
36 | Y_train = np_utils.to_categorical(y_train, nb_classes)
37 | Y_test = np_utils.to_categorical(y_test, nb_classes)
38 |
39 | X_train = X_train.astype('float32')
40 | X_test = X_test.astype('float32')
41 |
42 | # preprocess input
43 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
44 | std = np.std(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
45 |
46 | print("Channel Mean : ", mean)
47 | print("Channel Std : ", std)
48 |
49 | X_train = (X_train - mean) / (std)
50 | X_test = (X_test - mean) / (std)
51 |
52 | # Learning rate finder callback setup
53 | num_samples = X_train.shape[0]
54 |
55 | # INITIAL WEIGHT DECAY FACTORS
56 | WEIGHT_DECAY_FACTORS = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
57 |
58 | # FINEGRAINED WEIGHT DECAY FACTORS
59 | # WEIGHT_DECAY_FACTORS = [3e-3, 3e-4]
60 |
61 | # for weight_decay in WEIGHT_DECAY_FACTORS:
62 | # K.clear_session()
63 | #
64 | # # Learning rate range obtained from `find_lr_schedule.py`
65 | # # NOTE : Minimum is 10x smaller than the max found above !
66 | # # NOTE : It is preferable to use the validation data here to get a correct value
67 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=0.00125, maximum_lr=0.0125,
68 | # validation_data=(X_test, Y_test),
69 | # validation_sample_rate=5,
70 | # lr_scale='linear', save_dir='weights/weight_decay/weight_decay-%s/' % str(weight_decay),
71 | # verbose=True)
72 | #
73 | # # SETUP THE WEIGHT DECAY IN THE MODEL
74 | # model = MiniVGG((img_rows, img_cols, img_channels),
75 | # weight_decay=weight_decay, dropout=0,
76 | # weights=None, classes=nb_classes)
77 | # model.summary()
78 | #
79 | #     # the weight decay was already set in the model above
80 | #     # lr doesn't matter as it will be overwritten by the callback
81 | # optimizer = SGD(lr=0.002, momentum=0.95, nesterov=True)
82 | # model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
83 | #
84 | # # model.load_weights(weights_file)
85 | #
86 | # if not data_augmentation:
87 | # print('Not using data augmentation.')
88 | # model.fit(X_train, Y_train,
89 | # batch_size=batch_size,
90 | # epochs=nb_epoch,
91 | # validation_data=(X_test, Y_test),
92 | # shuffle=True,
93 | # verbose=1,
94 | # callbacks=[lr_finder])
95 | # else:
96 | # print('Using real-time data augmentation.')
97 | # # This will do preprocessing and realtime data augmentation:
98 | # datagen = ImageDataGenerator(
99 | # featurewise_center=False, # set input mean to 0 over the dataset
100 | # samplewise_center=False, # set each sample mean to 0
101 | # featurewise_std_normalization=False, # divide inputs by std of the dataset
102 | # samplewise_std_normalization=False, # divide each input by its std
103 | # zca_whitening=False, # apply ZCA whitening
104 | # rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
105 | # width_shift_range=0, # randomly shift images horizontally (fraction of total width)
106 | # height_shift_range=0, # randomly shift images vertically (fraction of total height)
107 | # horizontal_flip=True, # randomly flip images
108 | # vertical_flip=False) # randomly flip images
109 | #
110 | # # Compute quantities required for featurewise normalization
111 | # # (std, mean, and principal components if ZCA whitening is applied).
112 | # datagen.fit(X_train)
113 | #
114 | # # Fit the model on the batches generated by datagen.flow().
115 | # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
116 | # steps_per_epoch=X_train.shape[0] // batch_size,
117 | # validation_data=(X_test, Y_test),
118 | # epochs=nb_epoch, verbose=1,
119 | # callbacks=[lr_finder])
120 |
121 | # From the plot we see that the model isn't impacted much by the weight decay at all,
122 | # so we can use any of them.
123 |
124 | WEIGHT_DECAY_FACTORS = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7] + [3e-3, 3e-4]
125 |
126 | for weight_decay in WEIGHT_DECAY_FACTORS:
127 | directory = 'weights/weight_decay/weight_decay-%s/' % str(weight_decay)
128 |
129 | losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5)
130 | plt.plot(lrs, losses, label='weight_decay=%0.7f' % weight_decay)
131 |
132 | plt.title("Weight Decay")
133 | plt.xlabel("Learning rate")
134 | plt.ylabel("Validation Loss")
135 | plt.legend()
136 | plt.show()
137 |
--------------------------------------------------------------------------------
/models/small/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 |
5 | import warnings
6 | import math
7 |
8 | from keras.models import Model
9 | from keras.layers import Input
10 | from keras.layers import Activation
11 | from keras.layers import Dropout
12 | from keras.layers import Reshape
13 | from keras.layers import BatchNormalization
14 | from keras.layers import GlobalAveragePooling2D
15 | from keras.layers import GlobalMaxPooling2D
16 | from keras.layers import Conv2D
17 | from keras.layers import Dense
18 | from keras.layers import add, concatenate
19 | from keras import initializers
20 | from keras import regularizers
21 | from keras import constraints
22 | from keras.utils import conv_utils
23 | from keras.utils.data_utils import get_file
24 | from keras.engine.topology import get_source_inputs
25 | from keras.engine import InputSpec
26 | from keras.applications.imagenet_utils import _obtain_input_shape
27 | from keras.applications.inception_v3 import preprocess_input
28 | from keras.applications.imagenet_utils import decode_predictions
29 | from keras import backend as K
30 |
31 | import tensorflow as tf
32 |
33 | BASE_WEIGHT_PATH = ''
34 | BASE_WEIGHT_PATH_V2 = ''
35 |
36 |
37 | def relu6(x):
38 | return K.relu(x, max_value=6)
39 |
40 |
41 | def MiniVGG(input_shape=None,
42 | dropout=0.,
43 | weight_decay=0.,
44 | include_top=True,
45 | weights=None,
46 | input_tensor=None,
47 | pooling=None,
48 | classes=10):
49 |     """MiniVGG is a small 3-layer CNN.
50 |
51 | # Arguments
52 | input_shape: optional shape tuple, only to be specified
53 | if `include_top` is False (otherwise the input shape
54 | has to be `(224, 224, 3)` (with `channels_last` data format)
55 |         or `(3, 224, 224)` (with `channels_first` data format)).
56 |         It should have exactly 3 input channels,
57 | and width and height should be no smaller than 32.
58 | E.g. `(200, 200, 3)` would be one valid value.
59 | dropout: dropout rate
60 | weight_decay: Weight decay factor.
61 | include_top: whether to include the fully-connected
62 | layer at the top of the network.
63 | weights: `None` (random initialization) or
64 | `imagenet` (ImageNet weights)
65 | input_tensor: optional Keras tensor (i.e. output of
66 | `layers.Input()`)
67 | to use as image input for the model.
68 | pooling: Optional pooling mode for feature extraction
69 | when `include_top` is `False`.
70 | - `None` means that the output of the model
71 | will be the 4D tensor output of the
72 | last convolutional layer.
73 | - `avg` means that global average pooling
74 | will be applied to the output of the
75 | last convolutional layer, and thus
76 | the output of the model will be a
77 | 2D tensor.
78 | - `max` means that global max pooling will
79 | be applied.
80 | classes: optional number of classes to classify images
81 | into, only to be specified if `include_top` is True, and
82 | if no `weights` argument is specified.
83 | # Returns
84 | A Keras model instance.
85 | # Raises
86 | ValueError: in case of invalid argument for `weights`,
87 | or invalid input shape.
88 | RuntimeError: If attempting to run this model with a
89 | backend that does not support separable convolutions.
90 | """
91 |
92 | if K.backend() != 'tensorflow':
93 |         raise RuntimeError('Only the Tensorflow backend '
94 |                            'is currently supported '
95 |                            'for this model.')
96 |
97 | if weights not in {'imagenet', None}:
98 | raise ValueError('The `weights` argument should be either '
99 | '`None` (random initialization) or `imagenet` '
100 | '(pre-training on ImageNet).')
101 |
102 | if weights == 'imagenet' and include_top and classes != 1000:
103 | raise ValueError('If using `weights` as ImageNet with `include_top` '
104 | 'as true, `classes` should be 1000')
105 |
106 | # Determine proper input shape and default size.
107 | if input_shape is None:
108 | default_size = 224
109 | else:
110 | if K.image_data_format() == 'channels_first':
111 | rows = input_shape[1]
112 | cols = input_shape[2]
113 | else:
114 | rows = input_shape[0]
115 | cols = input_shape[1]
116 |
117 | if rows == cols and rows in [96, 128, 160, 192, 224]:
118 | default_size = rows
119 | else:
120 | default_size = 224
121 |
122 | input_shape = _obtain_input_shape(input_shape,
123 | default_size=default_size,
124 | min_size=32,
125 | data_format=K.image_data_format(),
126 | require_flatten=include_top or weights)
127 | if K.image_data_format() == 'channels_last':
128 | row_axis, col_axis = (0, 1)
129 | channel_axis = -1
130 | else:
131 | row_axis, col_axis = (1, 2)
132 | channel_axis = 1
133 |
134 | rows = input_shape[row_axis]
135 | cols = input_shape[col_axis]
136 |
137 | if input_tensor is None:
138 | img_input = Input(shape=input_shape)
139 | else:
140 | if not K.is_keras_tensor(input_tensor):
141 | img_input = Input(tensor=input_tensor, shape=input_shape)
142 | else:
143 | img_input = input_tensor
144 |
145 | x = _conv_block(img_input, 32, bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, block_id=1)
146 | x = _conv_block(x, 64, bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, block_id=2)
147 | x = _conv_block(x, 96, bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, block_id=3)
148 |
149 | if include_top:
150 |
151 | # Fast.ai's Concat Pooling
152 | a = GlobalAveragePooling2D()(x)
153 | b = GlobalMaxPooling2D()(x)
154 |
155 | x = concatenate([a, b], axis=channel_axis)
156 |
157 | x = Dropout(dropout, name='dropout')(x)
158 | x = Dense(classes, activation='softmax', name='conv_preds')(x)
159 | else:
160 | if pooling == 'avg':
161 | x = GlobalAveragePooling2D()(x)
162 | elif pooling == 'max':
163 | x = GlobalMaxPooling2D()(x)
164 |
165 | # Ensure that the model takes into account
166 | # any potential predecessors of `input_tensor`.
167 | if input_tensor is not None:
168 | inputs = get_source_inputs(input_tensor)
169 | else:
170 | inputs = img_input
171 |
172 | # Create model.
173 |     model = Model(inputs, x, name='mini_vgg')
174 | return model
175 |
176 |
177 | # taken from https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/conv_blocks.py
178 | def _make_divisible(v, divisor=8, min_value=8):
179 | if min_value is None:
180 | min_value = divisor
181 |
182 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
183 | # Make sure that round down does not go down by more than 10%.
184 | if new_v < 0.9 * v:
185 | new_v += divisor
186 | return new_v
187 |
188 |
189 | def _conv_block(inputs, filters, kernel=(3, 3), strides=(1, 1), bn_epsilon=1e-3,
190 | bn_momentum=0.99, weight_decay=0., block_id=1):
191 | """Adds an initial convolution layer (with batch normalization and relu6).
192 | # Arguments
193 | inputs: Input tensor of shape `(rows, cols, 3)`
194 | (with `channels_last` data format) or
195 | (3, rows, cols) (with `channels_first` data format).
196 |             It should have exactly 3 input channels,
197 | and width and height should be no smaller than 32.
198 | E.g. `(224, 224, 3)` would be one valid value.
199 | filters: Integer, the dimensionality of the output space
200 | (i.e. the number of output filters in the convolution).
201 | The value is rounded to a multiple of 8 by
202 | `_make_divisible` before the layer is built.
203 | weight_decay: Float, L2 regularization factor applied
204 | to the convolution kernel (0. disables weight decay).
205 | block_id: Integer, unique id used to name the conv,
206 | batch normalization and activation layers of
207 | this block.
208 | kernel: An integer or tuple/list of 2 integers, specifying the
209 | width and height of the 2D convolution window.
210 | Can be a single integer to specify the same value for
211 | all spatial dimensions.
212 | strides: An integer or tuple/list of 2 integers,
213 | specifying the strides of the convolution along the width and height.
214 | Can be a single integer to specify the same value for
215 | all spatial dimensions.
216 | Specifying any stride value != 1 is incompatible with specifying
217 | any `dilation_rate` value != 1.
218 | bn_epsilon: Epsilon value for BatchNormalization
219 | bn_momentum: Momentum value for BatchNormalization
220 | # Input shape
221 | 4D tensor with shape:
222 | `(samples, channels, rows, cols)` if data_format='channels_first'
223 | or 4D tensor with shape:
224 | `(samples, rows, cols, channels)` if data_format='channels_last'.
225 | # Output shape
226 | 4D tensor with shape:
227 | `(samples, filters, new_rows, new_cols)` if data_format='channels_first'
228 | or 4D tensor with shape:
229 | `(samples, new_rows, new_cols, filters)` if data_format='channels_last'.
230 | `rows` and `cols` values might have changed due to stride.
231 | # Returns
232 | Output tensor of block.
233 | """
234 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
235 | filters = _make_divisible(filters)
236 | x = Conv2D(filters, kernel,
237 | padding='same',
238 | use_bias=False,
239 | strides=strides,
240 | kernel_initializer=initializers.he_normal(),
241 | kernel_regularizer=regularizers.l2(weight_decay),
242 | name='conv%d' % block_id)(inputs)
243 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon,
244 | name='conv%d_bn' % block_id)(x)
245 | return Activation(relu6, name='conv%d_relu' % block_id)(x)
246 |
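# Shape walk-through for the three blocks stacked in MiniVGG above: every
# convolution uses stride (1, 1) with 'same' padding, so spatial dimensions
# are preserved, and 32/64/96 are already multiples of 8. A (32, 32, 3) CIFAR
# input therefore becomes (32, 32, 32) -> (32, 32, 64) -> (32, 32, 96).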
247 |
248 | if __name__ == '__main__':
249 | import tensorflow as tf
250 | from keras import backend as K
251 |
252 | run_metadata = tf.RunMetadata()
253 |
254 | with tf.Session(graph=tf.Graph()) as sess:
255 | K.set_session(sess)
256 |
257 | model = MiniVGG(input_tensor=tf.placeholder('float32', shape=(1, 32, 32, 3)))
258 | opt = tf.profiler.ProfileOptionBuilder.float_operation()
259 | flops = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt)
260 |
261 | opt = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
262 | param_count = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt)
263 |
264 | print('flops:', flops.total_float_ops)
265 | print('param count:', param_count.total_parameters)
266 |
267 | model.summary()
268 |
--------------------------------------------------------------------------------
/models/small/train_cifar_10.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from keras example cifar10_cnn.py
3 | Train the MiniVGG model on the CIFAR-10 small images dataset.
4 | """
5 | from __future__ import print_function
6 | import os
7 |
8 | from keras.datasets import cifar10
9 | from keras.preprocessing.image import ImageDataGenerator
10 | from keras.utils import np_utils
11 | from keras.callbacks import ModelCheckpoint
12 | from keras.optimizers import SGD
13 | import numpy as np
14 |
15 | from clr import OneCycleLR
16 | from models.small.model import MiniVGG
17 |
18 | if not os.path.exists('weights/'):
19 | os.makedirs('weights/')
20 |
21 | weights_file = 'weights/mini_vgg.h5'
22 | model_checkpoint = ModelCheckpoint(
23 | weights_file,
24 | monitor='val_acc',
25 | save_best_only=True,
26 | save_weights_only=True,
27 | mode='max')
28 | batch_size = 128
29 | nb_classes = 10
30 | nb_epoch = 50
31 | data_augmentation = True
32 |
33 | # input image dimensions
34 | img_rows, img_cols = 32, 32
35 | # The CIFAR10 images are RGB.
36 | img_channels = 3
37 |
38 | # The data, shuffled and split between train and test sets:
39 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
40 |
41 | # Convert class vectors to binary class matrices.
42 | Y_train = np_utils.to_categorical(y_train, nb_classes)
43 | Y_test = np_utils.to_categorical(y_test, nb_classes)
44 |
45 | X_train = X_train.astype('float32')
46 | X_test = X_test.astype('float32')
47 |
48 | # preprocess input
49 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
50 | std = np.std(X_train, axis=(0, 1, 2), keepdims=True).astype('float32')
51 |
52 | print("Channel Mean : ", mean)
53 | print("Channel Std : ", std)
54 |
55 | X_train = (X_train - mean) / (std)
56 | X_test = (X_test - mean) / (std)
57 |
58 | # Learning rate finder callback setup
59 | num_samples = X_train.shape[0]
60 |
61 | # When using the validation set for LRFinder, try out values starting from 2x
62 | # the lr found there and move lower until it's good for the first few epochs
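# (Illustrative numbers only: if the finder's loss curve turns upward around
# lr = 0.05, begin one-cycle trials at max_lr = 0.1 and halve from there;
# a value like the max_lr = 0.025 below is a plausible end of that search.)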
63 | lr_manager = OneCycleLR(
64 | max_lr=0.025,
65 | end_percentage=0.2,
66 | scale_percentage=0.1,
67 | maximum_momentum=0.95,
68 | verbose=True)
69 |
70 | # Build the MiniVGG model
71 | model = MiniVGG((img_rows, img_cols, img_channels),
72 | weight_decay=1e-5,
73 | weights=None,
74 | classes=nb_classes)
75 | model.summary()
76 |
77 | # These values will be overridden by the above callback
78 | optimizer = SGD(lr=0.0025, momentum=0.95, nesterov=True)
79 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
80 |
81 | if os.path.exists(weights_file): model.load_weights(weights_file)  # resume only if a checkpoint exists
82 |
83 | if not data_augmentation:
84 | print('Not using data augmentation.')
85 | model.fit(
86 | X_train,
87 | Y_train,
88 | batch_size=batch_size,
89 | epochs=nb_epoch,
90 | validation_data=(X_test, Y_test),
91 | shuffle=True,
92 | verbose=1,
93 | callbacks=[lr_manager, model_checkpoint])
94 | else:
95 | print('Using real-time data augmentation.')
96 | # This will do preprocessing and realtime data augmentation:
97 | datagen = ImageDataGenerator(
98 | featurewise_center=False, # set input mean to 0 over the dataset
99 | samplewise_center=False, # set each sample mean to 0
100 | featurewise_std_normalization=False, # divide inputs by std of the dataset
101 | samplewise_std_normalization=False, # divide each input by its std
102 | zca_whitening=False, # apply ZCA whitening
103 | # randomly rotate images in the range (degrees, 0 to 180)
104 | rotation_range=0,
105 | # randomly shift images horizontally (fraction of total width)
106 | width_shift_range=0,
107 | # randomly shift images vertically (fraction of total height)
108 | height_shift_range=0,
109 | horizontal_flip=True, # randomly flip images
110 | vertical_flip=False) # randomly flip images
111 |
112 | # Compute quantities required for featurewise normalization
113 | # (std, mean, and principal components if ZCA whitening is applied).
114 | datagen.fit(X_train)
115 |
116 | # Fit the model on the batches generated by datagen.flow().
117 | model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
118 | steps_per_epoch=X_train.shape[0] // batch_size,
119 | validation_data=(X_test, Y_test),
120 | epochs=nb_epoch, verbose=1,
121 | callbacks=[lr_manager, model_checkpoint])
122 |
123 | scores = model.evaluate(X_test, Y_test, batch_size=batch_size)
124 | for score, metric_name in zip(scores, model.metrics_names):
125 | print("%s : %0.4f" % (metric_name, score))
126 |
--------------------------------------------------------------------------------
/models/small/weights/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/mini_vgg.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/mini_vgg.h5
--------------------------------------------------------------------------------
/models/small/weights/momentum/momentum-0.9/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.9/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/momentum/momentum-0.9/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.9/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/momentum/momentum-0.95/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.95/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/momentum/momentum-0.95/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.95/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/momentum/momentum-0.99/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.99/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/momentum/momentum-0.99/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.99/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-0.0001/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.0001/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-0.0001/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.0001/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-0.0003/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.0003/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-0.0003/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.0003/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-0.001/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.001/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-0.001/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.001/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-0.003/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.003/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-0.003/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.003/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-1e-05/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-05/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-1e-05/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-05/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-1e-06/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-06/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-1e-06/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-06/lrs.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-1e-07/losses.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-07/losses.npy
--------------------------------------------------------------------------------
/models/small/weights/weight_decay/weight_decay-1e-07/lrs.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-07/lrs.npy
--------------------------------------------------------------------------------
/plot_clr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | from keras.models import Model
5 | from keras.layers import Dense, Activation, Input
6 | from keras.optimizers import SGD, Adam
7 |
8 | from clr import OneCycleLR
9 |
10 | plt.style.use('seaborn-white')
11 |
12 | # Constants
13 | NUM_SAMPLES = 2000
14 | NUM_EPOCHS = 100
15 | BATCH_SIZE = 500
16 | MAX_LR = 0.1
17 |
18 | # Data
19 | X = np.random.rand(NUM_SAMPLES, 10)
20 | Y = np.random.randint(0, 2, size=NUM_SAMPLES)
21 |
22 | # Model
23 | inp = Input(shape=(10,))
24 | x = Dense(5, activation='relu')(inp)
25 | x = Dense(1, activation='sigmoid')(x)
26 | model = Model(inp, x)
27 |
28 | clr_triangular = OneCycleLR(NUM_SAMPLES, NUM_EPOCHS, BATCH_SIZE, MAX_LR,
29 | end_percentage=0.2, scale_percentage=0.2)
30 |
31 | model.compile(optimizer=SGD(0.1), loss='binary_crossentropy', metrics=['accuracy'])
32 |
33 | model.fit(X, Y, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS, callbacks=[clr_triangular], verbose=0)
34 |
35 |
36 | print("LR Range : ", min(clr_triangular.history['lr']), max(clr_triangular.history['lr']))
37 | print("Momentum Range : ", min(clr_triangular.history['momentum']), max(clr_triangular.history['momentum']))
38 |
39 |
40 | plt.xlabel('Training Iterations')
41 | plt.ylabel('Learning Rate')
42 | plt.title("CLR")
43 | plt.plot(clr_triangular.history['lr'])
44 | plt.show()
45 |
46 | plt.xlabel('Training Iterations')
47 | plt.ylabel('Momentum')
48 | plt.title("CLR")
49 | plt.plot(clr_triangular.history['momentum'])
50 | plt.show()
51 |
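# To keep the figures rather than only displaying them -- for example to
# regenerate the plots stored under images/ -- each plt.show() above could be
# preceded by a savefig call; the filename here is an assumption, not part of
# the original script:
# plt.savefig('images/one_cycle_lr.png', dpi=150, bbox_inches='tight')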
--------------------------------------------------------------------------------