├── .gitignore ├── LICENSE ├── README.md ├── clr.py ├── images ├── lr.png ├── momentum.png ├── one_cycle_lr.png ├── one_cycle_momentum.png └── weight_decay.png ├── models ├── mobilenet │ ├── find_lr_schedule.py │ ├── find_momentum_schedule.py │ ├── find_weight_decay_schedule.py │ ├── mobilenets.py │ ├── train_cifar_10.py │ └── weights │ │ ├── losses.npy │ │ ├── lrs.npy │ │ ├── mobilenet_v2 - 9033.h5 │ │ ├── mobilenet_v2.h5 │ │ ├── momentum │ │ ├── momentum-0.9 │ │ │ ├── losses.npy │ │ │ └── lrs.npy │ │ ├── momentum-0.95 │ │ │ ├── losses.npy │ │ │ └── lrs.npy │ │ └── momentum-0.99 │ │ │ ├── losses.npy │ │ │ └── lrs.npy │ │ └── weight_decay │ │ ├── weight_decay-1e-05 │ │ ├── losses.npy │ │ └── lrs.npy │ │ ├── weight_decay-1e-06 │ │ ├── losses.npy │ │ └── lrs.npy │ │ ├── weight_decay-1e-07 │ │ ├── losses.npy │ │ └── lrs.npy │ │ ├── weight_decay-3e-05 │ │ ├── losses.npy │ │ └── lrs.npy │ │ └── weight_decay-3e-06 │ │ ├── losses.npy │ │ └── lrs.npy └── small │ ├── find_lr_schedule.py │ ├── find_momentum_schedule.py │ ├── find_weight_decay_schedule.py │ ├── model.py │ ├── train_cifar_10.py │ └── weights │ ├── losses.npy │ ├── lrs.npy │ ├── mini_vgg.h5 │ ├── momentum │ ├── momentum-0.9 │ │ ├── losses.npy │ │ └── lrs.npy │ ├── momentum-0.95 │ │ ├── losses.npy │ │ └── lrs.npy │ └── momentum-0.99 │ │ ├── losses.npy │ │ └── lrs.npy │ └── weight_decay │ ├── weight_decay-0.0001 │ ├── losses.npy │ └── lrs.npy │ ├── weight_decay-0.0003 │ ├── losses.npy │ └── lrs.npy │ ├── weight_decay-0.001 │ ├── losses.npy │ └── lrs.npy │ ├── weight_decay-0.003 │ ├── losses.npy │ └── lrs.npy │ ├── weight_decay-1e-05 │ ├── losses.npy │ └── lrs.npy │ ├── weight_decay-1e-06 │ ├── losses.npy │ └── lrs.npy │ └── weight_decay-1e-07 │ ├── losses.npy │ └── lrs.npy └── plot_clr.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # 
Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # PyCharm 104 | .idea/* 105 | 106 | # Weights 107 | weights/*.h5 108 | models/**/*.h5 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Somshubra Majumdar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to 
deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # One Cycle Learning Rate Policy for Keras 2 | Implementation of One-Cycle Learning rate policy from the papers by Leslie N. Smith. 3 | 4 | - [A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay](https://arxiv.org/abs/1803.09820) 5 | - [Super-Convergence: Very Fast Training of Residual Networks Using Large Learning Rates](https://arxiv.org/abs/1708.07120) 6 | 7 | Contains two Keras callbacks, `LRFinder` and `OneCycleLR` which are ported from the PyTorch *Fast.ai* library. 8 | 9 | # What is One Cycle Learning Rate 10 | It is the combination of gradually increasing learning rate, and optionally, gradually decreasing the momentum during the first half of the cycle, then gradually decreasing the learning rate and optionally increasing the momentum during the latter half of the cycle. 
11 | 12 | Finally, during the last portion of training (controlled by `end_percentage`), the learning rate is sharply reduced after every batch. 13 | 14 | The learning rate schedule is visualized as : 15 | 16 | 17 | 18 | The optional momentum schedule is visualized as : 19 | 20 | 21 | 22 | # Usage 23 | 24 | ## Finding a good learning rate 25 | Use `LRFinder` to obtain a loss plot, and visually inspect it to determine a good maximum learning rate. Provided below is an example, used for the `MiniMobileNetV2` model. 26 | 27 | An example script has been provided in `find_lr_schedule.py` inside the `models/mobilenet/` folder. 28 | 29 | Essentially, 30 | 31 | ```python 32 | from clr import LRFinder 33 | 34 | lr_callback = LRFinder(num_samples, batch_size, 35 | minimum_lr, maximum_lr, 36 | # validation_data=(X_val, Y_val), 37 | lr_scale='exp', save_dir='path/to/save/directory') 38 | 39 | # Ensure that number of epochs = 1 when calling fit() 40 | model.fit(X, Y, epochs=1, batch_size=batch_size, callbacks=[lr_callback]) 41 | ``` 42 | The above callback does a few things. 43 | 44 | - Must supply number of samples in the dataset (here, 50k from CIFAR 10) and the batch size that will be used during training. 45 | - `lr_scale` is set to `exp` - useful when searching over a large range of learning rates. Set to `linear` to search a smaller space. 46 | - `save_dir` - Automatically saves the results of LRFinder to the specified directory path. This is highly encouraged. 47 | - `validation_data` - provide the validation data as a tuple to use that for the loss plot instead of the training batch loss. Since the validation dataset can be very large, we will randomly sample `k` batches (k * batch_size) from the validation set to provide a quick estimate of the validation loss. The value of `k` is set by `validation_sample_rate` (default 5). 48 | 49 | **Note : When using this, be careful about setting the learning rate, momentum and weight decay schedule.
The loss plots will be more erratic due to the sampling of the validation set.** 50 | 51 | **NOTE 2 :** 52 | 53 | - It is faster to get the learning rate without using `validation_data`, and then find the weight decay and momentum based on that learning rate while using `validation_data`. 54 | - You can also use `LRFinder` to find the optimal weight decay and momentum values using the examples `find_momentum_schedule.py` and `find_weight_decay_schedule.py` inside the `models/mobilenet/` folder. 55 | 56 | To visualize the plot, there are two ways - 57 | 58 | - Use `lr_callback.plot_schedule()` after the fit() call. This uses the current training session results. 59 | - Use class method `LRFinder.plot_schedule_from_file('path/to/save/directory')` to visualize the plot separately from the training session. This only works if you used the `save_dir` argument to save the results of the search to some location. 60 | 61 | ## Finding the optimal Momentum 62 | 63 | Use the `find_momentum_schedule.py` script inside `models/mobilenet/` for an example. 64 | 65 | Some notes : 66 | 67 | - Use a grid search over a few possible momentum values, such as `[0.8, 0.85, 0.9, 0.95, 0.99]`. Use `linear` as the `lr_scale` argument value. 68 | - Set the momentum value manually on the SGD optimizer before compiling the model. 69 | - Plot the curve at the end and visually see which momentum value yields the least noisy / lowest losses overall on the plot. The absolute loss values matter less than the shape of the curve. 70 | 71 | - It is better to supply the `validation_data` here. 72 | - The plot will be very noisy, so you may want to use a larger value of `loss_smoothing_beta` (such as `0.99` or `0.995`). 73 | - The actual curve values don't matter as much as the overall movement of the curve. Choose the value which is more steady and tries to get the lowest value even at large learning rates.
74 | 75 | ## Finding the optimal Weight Decay 76 | 77 | Use the `find_weight_decay_schedule.py` script inside `models/mobilenet/` for an example. 78 | 79 | Some notes : 80 | 81 | - Use a grid search over a few weight decay values, such as `[1e-3, 1e-4, 1e-5, 1e-6, 1e-7]`. Call this "coarse search" and use `linear` for the `lr_scale` argument. 82 | - Then use a grid search over a few values around the best coarse result, such as `[3e-7, 1e-7, 3e-6]`. Call this "fine search" and use `linear` scale for the `lr_scale` argument. 83 | - Set the weight decay value manually when building the model. 84 | - Plot the curve at the end and visually see which weight decay value yields the least noisy / lowest losses overall on the plot. The absolute loss values matter less than the shape of the curve. 85 | 86 | - It is better to supply the `validation_data` here. 87 | - The plot will be very noisy, so you may want to use a larger value of `loss_smoothing_beta` (such as `0.99` or `0.995`). 88 | - The actual curve values don't matter as much as the overall movement of the curve. Choose the value which is more steady and tries to get the lowest value even at large learning rates. 89 | 90 | 91 | ## Interpreting the plot 92 | 93 | ### Learning Rate 94 | 95 | 96 | 97 | 98 | 99 | Consider the above plot from using the `LRFinder` on the MiniMobileNetV2 model. In particular, there are a few regions above that we need to carefully interpret. 100 | 101 | **Note : The values are in log 10 scale (since `exp` was used for `lr_scale`)** ; All values discussed will be based on the x-axis (learning rate) : 102 | 103 | - After the -1.5 point on the graph, the loss becomes erratic. 104 | - After the 0.5 point on the graph, the loss is noisy but doesn't decrease any further. 105 | - **-1.7** is the last relatively smooth portion before the **-1.5** region. To be safe, we can choose to move a little more to the left, closer to -1.8, but this may reduce performance.
106 | - It is usually worth training the first 2-3 epochs of `OneCycleLR` with maximum learning rates close to these edges and comparing them to determine which is best. 107 | 108 | ### Momentum 109 | 110 | Using the above learning rate, next calculate the optimal momentum (`find_momentum_schedule.py`). 111 | 112 | 113 | 114 | 115 | 116 | See the notes in the `Finding the optimal Momentum` section on how to interpret the plot. 117 | 118 | ### Weight Decay 119 | 120 | Similarly, it is possible to use the above learning rate and momentum values to calculate the optimal weight decay (`find_weight_decay_schedule.py`). 121 | 122 | **Note : Due to large learning rates acting as a strong regularizer, other regularization techniques like weight decay and dropout should be decreased significantly to properly train the model.** 123 | 124 | 125 | 126 | 127 | 128 | It is best to first search regularization strengths from 1e-3 to 1e-7, and then fine-search the region that provided the best overall plot. 129 | 130 | See the notes in the `Finding the optimal Weight Decay` section on how to interpret the plot. 131 | 132 | ## Training with `OneCycleLR` 133 | Once we find the maximum learning rate, we can then move on to using the `OneCycleLR` callback with SGD to train our model. 134 | 135 | ```python 136 | from clr import OneCycleLR 137 | 138 | lr_manager = OneCycleLR(num_samples, batch_size, max_lr, 139 | end_percentage=0.1, scale_percentage=None, 140 | maximum_momentum=0.95, minimum_momentum=0.85) 141 | 142 | model.fit(X, Y, epochs=EPOCHS, batch_size=batch_size, callbacks=[model_checkpoint, lr_manager], 143 | ...)
144 | ``` 145 | 146 | There are many parameters, but a few of the important ones : 147 | - Must provide the training information - `number of samples`, `batch size` and `max learning rate`. The number of epochs is read from the `epochs` argument passed to `fit()`. 148 | - `end_percentage` determines what fraction of the training iterations will be used for the steep reduction in the learning rate at the end. At its minimum, the lowest learning rate will be 1/1000th of the `max_lr` provided (with the defaults `end_percentage=0.1` and `scale_percentage=None`). 149 | - `scale_percentage` is a confusing parameter. It dictates the scaling factor applied to the learning rate throughout the cycle. **It is best to test this out visually using the `plot_clr.py` script to ensure there are no mistakes**. Leaving it as None defaults to using the same percentage as the provided `end_percentage`. 150 | - `maximum/minimum_momentum` are preset according to the paper and `Fast.ai`. However, if you don't wish to scale it, set both to the same value, generally `0.9` is preferred as the momentum value for SGD. If you don't want to update the momentum / are not using SGD (not advisable) - set both to None to ignore the momentum updates. 151 | 152 | # Results 153 | 154 | - **-1.7** is chosen to be the maximum learning rate (in log10 space) for the `OneCycleLR` schedule. Since this is in log10 scale, we use `10 ^ (x)` to get the actual maximum learning rate. Here, `10 ^ -1.7 ~ 0.01995`. Therefore, we round up to a **maximum learning rate of 0.02** 155 | - **0.9** is chosen as the maximum momentum from the momentum plot. Using Cyclic Momentum updates, choose a slightly lower value (**0.85**) as the minimum for faster training. 156 | - **3e-6** is chosen as the weight decay factor. 157 | 158 | For the MiniMobileNetV2 model, 2 passes of the OneCycle LR with SGD (40 epochs - max lr = 0.02, 30 epochs - max lr = 0.005) obtained 90.33%.
This may not seem like much, but this is a model with only 650k parameters, and in comparison, the same model trained on Adam with initial learning rate 2e-3 did not converge to the same score in over 100 epochs (89.14%). 159 | 160 | # Requirements 161 | - Keras 2.1.6+ 162 | - Tensorflow (tested) / Theano / CNTK for the backend 163 | - matplotlib to visualize the plots. 164 | -------------------------------------------------------------------------------- /clr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import warnings 4 | 5 | from keras.callbacks import Callback 6 | from keras import backend as K 7 | 8 | 9 | # Code is ported from https://github.com/fastai/fastai 10 | class OneCycleLR(Callback): 11 | def __init__(self, 12 | num_samples, 13 | batch_size, 14 | max_lr, 15 | end_percentage=0.1, 16 | scale_percentage=None, 17 | maximum_momentum=0.95, 18 | minimum_momentum=0.85, 19 | verbose=True): 20 | """ This callback implements the One-Cycle learning rate policy. 21 | This is a special case of Cyclic Learning Rates, where we have only 1 cycle. 22 | After the completion of 1 cycle, the learning rate will decrease rapidly to 23 | 1/100th of its initial (lowest) value. The schedule is updated every batch. 24 | 25 | # Arguments: 26 | num_samples: Integer. Number of samples in the dataset. 27 | batch_size: Integer. Batch size during training. 28 | max_lr: Float. Reference learning rate. The schedule starts at 29 | `max_lr * scale` and peaks at `100 * scale ** 2 * max_lr`, so with 30 | the default scale of 0.1 it starts at max_lr / 10 and peaks at max_lr. 31 | end_percentage: Float. The fraction of all training iterations 32 | that will be dedicated to sharply decreasing the learning 33 | rate after the completion of 1 cycle. Must be between 0 and 1. 34 | scale_percentage: Float or None. The `scale` applied to the learning 35 | rate throughout the cycle. If float, must be between 0 and 1. 36 | If None, the value of `end_percentage` is used as the scale. 37 | maximum_momentum: Optional.
Sets the maximum momentum (initial) 38 | value, which gradually drops to its lowest value in half-cycle, 39 | then gradually increases again to stay constant at this max value. 40 | Requires an optimizer with a `momentum` attribute (e.g. SGD). 41 | minimum_momentum: Optional. Sets the minimum momentum at the end of 42 | the half-cycle. Requires an optimizer with a `momentum` attribute. 43 | verbose: Bool. Whether to print the current learning rate (and momentum, 44 | if it is being updated) after every epoch. 45 | 46 | # Reference 47 | - [A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay](https://arxiv.org/abs/1803.09820) 48 | - [Super-Convergence: Very Fast Training of Residual Networks Using Large Learning Rates](https://arxiv.org/abs/1708.07120) 49 | """ 50 | super(OneCycleLR, self).__init__() 51 | 52 | if end_percentage < 0. or end_percentage > 1.: 53 | raise ValueError("`end_percentage` must be between 0 and 1") 54 | 55 | if scale_percentage is not None and (scale_percentage < 0. or scale_percentage > 1.): 56 | raise ValueError("`scale_percentage` must be between 0 and 1") 57 | 58 | self.initial_lr = max_lr 59 | self.end_percentage = end_percentage 60 | self.scale = float(scale_percentage) if scale_percentage is not None else float(end_percentage) 61 | self.max_momentum = maximum_momentum 62 | self.min_momentum = minimum_momentum 63 | self.verbose = verbose 64 | # momentum is cycled only when both bounds are provided 65 | if self.max_momentum is not None and self.min_momentum is not None: 66 | self._update_momentum = True 67 | else: 68 | self._update_momentum = False 69 | 70 | self.clr_iterations = 0. 71 | self.history = {} 72 | # epochs / steps / num_iterations / mid_cycle_id are filled in by on_train_begin 73 | self.epochs = None 74 | self.batch_size = batch_size 75 | self.samples = num_samples 76 | self.steps = None 77 | self.num_iterations = None 78 | self.mid_cycle_id = None 79 | 80 | def _reset(self): 81 | """ 82 | Reset the iteration counter and the recorded history. 83 | """ 84 | self.clr_iterations = 0.
85 | self.history = {} 86 | 87 | def compute_lr(self): 88 | """ 89 | Compute the learning rate based on which phase of the cycle it is in. 90 | 91 | - If in the first half of training, the learning rate gradually increases. 92 | - If in the second half of training, the learning rate gradually decreases. 93 | - If in the final `end_percentage` portion of training, the learning rate 94 | is quickly reduced to near 100th of the original min learning rate. 95 | 96 | # Returns: 97 | the new learning rate 98 | """ 99 | if self.clr_iterations > 2 * self.mid_cycle_id: 100 | current_percentage = (self.clr_iterations - 2 * self.mid_cycle_id) 101 | current_percentage /= float((self.num_iterations - 2 * self.mid_cycle_id)) 102 | new_lr = self.initial_lr * (1. + (current_percentage * 103 | (1. - 100.) / 100.)) * self.scale 104 | 105 | elif self.clr_iterations > self.mid_cycle_id: 106 | current_percentage = 1. - ( 107 | self.clr_iterations - self.mid_cycle_id) / self.mid_cycle_id 108 | new_lr = self.initial_lr * (1. + current_percentage * 109 | (self.scale * 100 - 1.)) * self.scale 110 | 111 | else: 112 | current_percentage = self.clr_iterations / self.mid_cycle_id 113 | new_lr = self.initial_lr * (1. + current_percentage * 114 | (self.scale * 100 - 1.)) * self.scale 115 | 116 | if self.clr_iterations == self.num_iterations: 117 | self.clr_iterations = 0 118 | 119 | return new_lr 120 | 121 | def compute_momentum(self): 122 | """ 123 | Compute the momentum based on which phase of the cycle it is in. 124 | 125 | - If in the first half of training, the momentum gradually decreases. 126 | - If in the second half of training, the momentum gradually increases. 127 | - If in the final `end_percentage` portion of training, the momentum value 128 | is kept constant at the maximum initial value. 
129 | 130 | # Returns: 131 | the new momentum value 132 | """ 133 | if self.clr_iterations > 2 * self.mid_cycle_id: 134 | new_momentum = self.max_momentum 135 | 136 | elif self.clr_iterations > self.mid_cycle_id: 137 | current_percentage = 1. - ((self.clr_iterations - self.mid_cycle_id) / float( 138 | self.mid_cycle_id)) 139 | new_momentum = self.max_momentum - current_percentage * ( 140 | self.max_momentum - self.min_momentum) 141 | 142 | else: 143 | current_percentage = self.clr_iterations / float(self.mid_cycle_id) 144 | new_momentum = self.max_momentum - current_percentage * ( 145 | self.max_momentum - self.min_momentum) 146 | 147 | return new_momentum 148 | 149 | def on_train_begin(self, logs={}): 150 | logs = logs or {} 151 | 152 | self.epochs = self.params['epochs'] 153 | # When fit generator is used 154 | # self.params don't have the elements 'batch_size' and 'samples' 155 | # self.batch_size = self.params['batch_size'] 156 | # self.samples = self.params['samples'] 157 | self.steps = self.params['steps'] 158 | 159 | if self.steps is not None: 160 | self.num_iterations = self.epochs * self.steps 161 | else: 162 | if (self.samples % self.batch_size) == 0: 163 | remainder = 0 164 | else: 165 | remainder = 1 166 | self.num_iterations = (self.epochs + remainder) * self.samples // self.batch_size 167 | 168 | self.mid_cycle_id = int(self.num_iterations * ((1. 
- self.end_percentage)) / float(2)) 169 | 170 | self._reset() 171 | K.set_value(self.model.optimizer.lr, self.compute_lr()) 172 | 173 | if self._update_momentum: 174 | if not hasattr(self.model.optimizer, 'momentum'): 175 | raise ValueError("Momentum can be updated only on SGD optimizer !") 176 | 177 | new_momentum = self.compute_momentum() 178 | K.set_value(self.model.optimizer.momentum, new_momentum) 179 | 180 | def on_batch_end(self, epoch, logs=None): 181 | logs = logs or {} 182 | 183 | self.clr_iterations += 1 184 | new_lr = self.compute_lr() 185 | 186 | self.history.setdefault('lr', []).append( 187 | K.get_value(self.model.optimizer.lr)) 188 | K.set_value(self.model.optimizer.lr, new_lr) 189 | 190 | if self._update_momentum: 191 | if not hasattr(self.model.optimizer, 'momentum'): 192 | raise ValueError("Momentum can be updated only on SGD optimizer !") 193 | 194 | new_momentum = self.compute_momentum() 195 | 196 | self.history.setdefault('momentum', []).append( 197 | K.get_value(self.model.optimizer.momentum)) 198 | K.set_value(self.model.optimizer.momentum, new_momentum) 199 | 200 | for k, v in logs.items(): 201 | self.history.setdefault(k, []).append(v) 202 | 203 | def on_epoch_end(self, epoch, logs=None): 204 | if self.verbose: 205 | if self._update_momentum: 206 | print(" - lr: %0.5f - momentum: %0.2f " % 207 | (self.history['lr'][-1], self.history['momentum'][-1])) 208 | 209 | else: 210 | print(" - lr: %0.5f " % (self.history['lr'][-1])) 211 | 212 | 213 | class LRFinder(Callback): 214 | def __init__(self, 215 | num_samples, 216 | batch_size, 217 | minimum_lr=1e-5, 218 | maximum_lr=10., 219 | lr_scale='exp', 220 | validation_data=None, 221 | validation_sample_rate=5, 222 | stopping_criterion_factor=4., 223 | loss_smoothing_beta=0.98, 224 | save_dir=None, 225 | verbose=True): 226 | """ 227 | This class uses the Cyclic Learning Rate history to find a 228 | set of learning rates that can be good initializations for the 229 | One-Cycle training proposed by 
Leslie Smith in the paper referenced 230 | below. 231 | 232 | A port of the Fast.ai implementation for Keras. 233 | 234 | # Note 235 | This requires that the model be trained for exactly 1 epoch. If the model 236 | is trained for more epochs, then the metric calculations are only done for 237 | the first epoch. 238 | 239 | # Interpretation 240 | Upon visualizing the loss plot, check where the loss starts to increase 241 | rapidly. Choose a learning rate at somewhat prior to the corresponding 242 | position in the plot for faster convergence. This will be the maximum_lr lr. 243 | Choose the max value as this value when passing the `max_val` argument 244 | to OneCycleLR callback. 245 | 246 | Since the plot is in log-scale, you need to compute 10 ^ (-k) of the x-axis 247 | 248 | # Arguments: 249 | num_samples: Integer. Number of samples in the dataset. 250 | batch_size: Integer. Batch size during training. 251 | minimum_lr: Float. Initial learning rate (and the minimum). 252 | maximum_lr: Float. Final learning rate (and the maximum). 253 | lr_scale: Can be one of ['exp', 'linear']. Chooses the type of 254 | scaling for each update to the learning rate during subsequent 255 | batches. Choose 'exp' for large range and 'linear' for small range. 256 | validation_data: Requires the validation dataset as a tuple of 257 | (X, y) belonging to the validation set. If provided, will use the 258 | validation set to compute the loss metrics. Else uses the training 259 | batch loss. Will warn if not provided to alert the user. 260 | validation_sample_rate: Positive or Negative Integer. Number of batches to sample from the 261 | validation set per iteration of the LRFinder. Larger number of 262 | samples will reduce the variance but will take longer time to execute 263 | per batch. 264 | 265 | If Positive > 0, will sample from the validation dataset 266 | If Megative, will use the entire dataset 267 | stopping_criterion_factor: Integer or None. 
A factor which is used 268 | to measure large increase in the loss value during training. 269 | Since callbacks cannot stop training of a model, it will simply 270 | stop logging the additional values from the epochs after this 271 | stopping criterion has been met. 272 | If None, this check will not be performed. 273 | loss_smoothing_beta: Float. The smoothing factor for the moving 274 | average of the loss function. 275 | save_dir: Optional, String. If passed a directory path, the callback 276 | will save the running loss and learning rates to two separate numpy 277 | arrays inside this directory. If the directory in this path does not 278 | exist, they will be created. 279 | verbose: Whether to print the learning rate after every batch of training. 280 | 281 | # References: 282 | - [A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, weight_decay, and weight decay](https://arxiv.org/abs/1803.09820) 283 | """ 284 | super(LRFinder, self).__init__() 285 | 286 | if lr_scale not in ['exp', 'linear']: 287 | raise ValueError("`lr_scale` must be one of ['exp', 'linear']") 288 | 289 | if validation_data is not None: 290 | self.validation_data = validation_data 291 | self.use_validation_set = True 292 | 293 | if validation_sample_rate > 0 or validation_sample_rate < 0: 294 | self.validation_sample_rate = validation_sample_rate 295 | else: 296 | raise ValueError("`validation_sample_rate` must be a positive or negative integer other than o") 297 | else: 298 | self.use_validation_set = False 299 | self.validation_sample_rate = 0 300 | 301 | self.num_samples = num_samples 302 | self.batch_size = batch_size 303 | self.initial_lr = minimum_lr 304 | self.final_lr = maximum_lr 305 | self.lr_scale = lr_scale 306 | self.stopping_criterion_factor = stopping_criterion_factor 307 | self.loss_smoothing_beta = loss_smoothing_beta 308 | self.save_dir = save_dir 309 | self.verbose = verbose 310 | 311 | self.num_batches_ = num_samples // 
batch_size 312 | self.current_lr_ = minimum_lr 313 | 314 | if lr_scale == 'exp': 315 | self.lr_multiplier_ = (maximum_lr / float(minimum_lr)) ** ( 316 | 1. / float(self.num_batches_)) 317 | else: 318 | extra_batch = int((num_samples % batch_size) != 0) 319 | self.lr_multiplier_ = np.linspace( 320 | minimum_lr, maximum_lr, num=self.num_batches_ + extra_batch) 321 | 322 | # If negative, use entire validation set 323 | if self.validation_sample_rate < 0: 324 | self.validation_sample_rate = self.validation_data[0].shape[0] // batch_size 325 | 326 | self.current_batch_ = 0 327 | self.current_epoch_ = 0 328 | self.best_loss_ = 1e6 329 | self.running_loss_ = 0. 330 | 331 | self.history = {} 332 | 333 | def on_train_begin(self, logs=None): 334 | 335 | self.current_epoch_ = 1 336 | K.set_value(self.model.optimizer.lr, self.initial_lr) 337 | 338 | warnings.simplefilter("ignore") 339 | 340 | def on_epoch_begin(self, epoch, logs=None): 341 | self.current_batch_ = 0 342 | 343 | if self.current_epoch_ > 1: 344 | warnings.warn( 345 | "\n\nLearning rate finder should be used only with a single epoch. 
" 346 | "Hereafter, the callback will not measure the losses.\n\n") 347 | 348 | def on_batch_begin(self, batch, logs=None): 349 | self.current_batch_ += 1 350 | 351 | def on_batch_end(self, batch, logs=None): 352 | if self.current_epoch_ > 1: 353 | return 354 | 355 | if self.use_validation_set: 356 | X, Y = self.validation_data[0], self.validation_data[1] 357 | 358 | # use 5 random batches from test set for fast approximate of loss 359 | num_samples = self.batch_size * self.validation_sample_rate 360 | 361 | if num_samples > X.shape[0]: 362 | num_samples = X.shape[0] 363 | 364 | idx = np.random.choice(X.shape[0], num_samples, replace=False) 365 | x = X[idx] 366 | y = Y[idx] 367 | 368 | values = self.model.evaluate(x, y, batch_size=self.batch_size, verbose=False) 369 | loss = values[0] 370 | else: 371 | loss = logs['loss'] 372 | 373 | # smooth the loss value and bias correct 374 | running_loss = self.loss_smoothing_beta * loss + ( 375 | 1. - self.loss_smoothing_beta) * loss 376 | running_loss = running_loss / ( 377 | 1. 
- self.loss_smoothing_beta**self.current_batch_) 378 | 379 | # stop logging if loss is too large 380 | if self.current_batch_ > 1 and self.stopping_criterion_factor is not None and ( 381 | running_loss > 382 | self.stopping_criterion_factor * self.best_loss_): 383 | 384 | if self.verbose: 385 | print(" - LRFinder: Skipping iteration since loss is %d times as large as best loss (%0.4f)" 386 | % (self.stopping_criterion_factor, self.best_loss_)) 387 | return 388 | 389 | if running_loss < self.best_loss_ or self.current_batch_ == 1: 390 | self.best_loss_ = running_loss 391 | 392 | current_lr = K.get_value(self.model.optimizer.lr) 393 | 394 | self.history.setdefault('running_loss_', []).append(running_loss) 395 | if self.lr_scale == 'exp': 396 | self.history.setdefault('log_lrs', []).append(np.log10(current_lr)) 397 | else: 398 | self.history.setdefault('log_lrs', []).append(current_lr) 399 | 400 | # compute the lr for the next batch and update the optimizer lr 401 | if self.lr_scale == 'exp': 402 | current_lr *= self.lr_multiplier_ 403 | else: 404 | current_lr = self.lr_multiplier_[self.current_batch_ - 1] 405 | 406 | K.set_value(self.model.optimizer.lr, current_lr) 407 | 408 | # save the other metrics as well 409 | for k, v in logs.items(): 410 | self.history.setdefault(k, []).append(v) 411 | 412 | if self.verbose: 413 | if self.use_validation_set: 414 | print(" - LRFinder: val_loss: %1.4f - lr = %1.8f " % 415 | (values[0], current_lr)) 416 | else: 417 | print(" - LRFinder: lr = %1.8f " % current_lr) 418 | 419 | def on_epoch_end(self, epoch, logs=None): 420 | if self.save_dir is not None and self.current_epoch_ <= 1: 421 | if not os.path.exists(self.save_dir): 422 | os.makedirs(self.save_dir) 423 | 424 | losses_path = os.path.join(self.save_dir, 'losses.npy') 425 | lrs_path = os.path.join(self.save_dir, 'lrs.npy') 426 | 427 | np.save(losses_path, self.losses) 428 | np.save(lrs_path, self.lrs) 429 | 430 | if self.verbose: 431 | print("\tLR Finder : Saved the losses 
and learning rate values in path : {%s}" 432 | % (self.save_dir)) 433 | 434 | self.current_epoch_ += 1 435 | 436 | warnings.simplefilter("default") 437 | 438 | def plot_schedule(self, clip_beginning=None, clip_endding=None): 439 | """ 440 | Plots the schedule from the callback itself. 441 | 442 | # Arguments: 443 | clip_beginning: Integer or None. If positive integer, it will 444 | remove the specified portion of the loss graph to remove the large 445 | loss values in the beginning of the graph. 446 | clip_endding: Integer or None. If negative integer, it will 447 | remove the specified portion of the ending of the loss graph to 448 | remove the sharp increase in the loss values at high learning rates. 449 | """ 450 | try: 451 | import matplotlib.pyplot as plt 452 | plt.style.use('seaborn-white') 453 | except ImportError: 454 | print( 455 | "Matplotlib not found. Please use `pip install matplotlib` first." 456 | ) 457 | return 458 | 459 | if clip_beginning is not None and clip_beginning < 0: 460 | clip_beginning = -clip_beginning 461 | 462 | if clip_endding is not None and clip_endding > 0: 463 | clip_endding = -clip_endding 464 | 465 | losses = self.losses 466 | lrs = self.lrs 467 | 468 | if clip_beginning: 469 | losses = losses[clip_beginning:] 470 | lrs = lrs[clip_beginning:] 471 | 472 | if clip_endding: 473 | losses = losses[:clip_endding] 474 | lrs = lrs[:clip_endding] 475 | 476 | plt.plot(lrs, losses) 477 | plt.title('Learning rate vs Loss') 478 | plt.xlabel('learning rate') 479 | plt.ylabel('loss') 480 | plt.show() 481 | 482 | @classmethod 483 | def restore_schedule_from_dir(cls, 484 | directory, 485 | clip_beginning=None, 486 | clip_endding=None): 487 | """ 488 | Loads the training history from the saved numpy files in the given directory. 489 | 490 | # Arguments: 491 | directory: String. Path to the directory where the serialized numpy 492 | arrays of the loss and learning rates are saved. 493 | clip_beginning: Integer or None. 
If positive integer, it will 494 | remove the specified portion of the loss graph to remove the large 495 | loss values in the beginning of the graph. 496 | clip_endding: Integer or None. If negative integer, it will 497 | remove the specified portion of the ending of the loss graph to 498 | remove the sharp increase in the loss values at high learning rates. 499 | 500 | Returns: 501 | tuple of (losses, learning rates) 502 | """ 503 | if clip_beginning is not None and clip_beginning < 0: 504 | clip_beginning = -clip_beginning 505 | 506 | if clip_endding is not None and clip_endding > 0: 507 | clip_endding = -clip_endding 508 | 509 | losses_path = os.path.join(directory, 'losses.npy') 510 | lrs_path = os.path.join(directory, 'lrs.npy') 511 | 512 | if not os.path.exists(losses_path) or not os.path.exists(lrs_path): 513 | print("%s and %s could not be found at directory : {%s}" % 514 | (losses_path, lrs_path, directory)) 515 | 516 | losses = None 517 | lrs = None 518 | 519 | else: 520 | losses = np.load(losses_path) 521 | lrs = np.load(lrs_path) 522 | 523 | if clip_beginning: 524 | losses = losses[clip_beginning:] 525 | lrs = lrs[clip_beginning:] 526 | 527 | if clip_endding: 528 | losses = losses[:clip_endding] 529 | lrs = lrs[:clip_endding] 530 | 531 | return losses, lrs 532 | 533 | @classmethod 534 | def plot_schedule_from_file(cls, 535 | directory, 536 | clip_beginning=None, 537 | clip_endding=None): 538 | """ 539 | Plots the schedule from the saved numpy arrays of the loss and learning 540 | rate values in the specified directory. 541 | 542 | # Arguments: 543 | directory: String. Path to the directory where the serialized numpy 544 | arrays of the loss and learning rates are saved. 545 | clip_beginning: Integer or None. If positive integer, it will 546 | remove the specified portion of the loss graph to remove the large 547 | loss values in the beginning of the graph. 548 | clip_endding: Integer or None. 
If negative integer, it will 549 | remove the specified portion of the ending of the loss graph to 550 | remove the sharp increase in the loss values at high learning rates. 551 | """ 552 | try: 553 | import matplotlib.pyplot as plt 554 | plt.style.use('seaborn-white') 555 | except ImportError: 556 | print("Matplotlib not found. Please use `pip install matplotlib` first.") 557 | return 558 | 559 | losses, lrs = cls.restore_schedule_from_dir( 560 | directory, 561 | clip_beginning=clip_beginning, 562 | clip_endding=clip_endding) 563 | 564 | if losses is None or lrs is None: 565 | return 566 | else: 567 | plt.plot(lrs, losses) 568 | plt.title('Learning rate vs Loss') 569 | plt.xlabel('learning rate') 570 | plt.ylabel('loss') 571 | plt.show() 572 | 573 | @property 574 | def lrs(self): 575 | return np.array(self.history['log_lrs']) 576 | 577 | @property 578 | def losses(self): 579 | return np.array(self.history['running_loss_']) 580 | -------------------------------------------------------------------------------- /images/lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/lr.png -------------------------------------------------------------------------------- /images/momentum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/momentum.png -------------------------------------------------------------------------------- /images/one_cycle_lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/one_cycle_lr.png -------------------------------------------------------------------------------- /images/one_cycle_momentum.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/one_cycle_momentum.png -------------------------------------------------------------------------------- /images/weight_decay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/images/weight_decay.png -------------------------------------------------------------------------------- /models/mobilenet/find_lr_schedule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from keras example cifar10_cnn.py 3 | Train NASNet-CIFAR on the CIFAR10 small images dataset. 4 | """ 5 | from __future__ import print_function 6 | import os 7 | 8 | from keras.datasets import cifar10 9 | from keras.preprocessing.image import ImageDataGenerator 10 | from keras.utils import np_utils 11 | from keras.callbacks import ModelCheckpoint 12 | from keras.optimizers import SGD 13 | import numpy as np 14 | 15 | from clr import LRFinder 16 | from models.mobilenet.mobilenets import MiniMobileNetV2 17 | 18 | if not os.path.exists('weights/'): 19 | os.makedirs('weights/') 20 | 21 | weights_file = 'weights/mobilenet_v2_schedule.h5' 22 | model_checkpoint = ModelCheckpoint(weights_file, monitor='val_acc', save_best_only=True, 23 | save_weights_only=True, mode='max') 24 | 25 | batch_size = 128 26 | nb_classes = 10 27 | nb_epoch = 1 # Only finding lr 28 | data_augmentation = True 29 | 30 | # input image dimensions 31 | img_rows, img_cols = 32, 32 32 | # The CIFAR10 images are RGB. 33 | img_channels = 3 34 | 35 | # The data, shuffled and split between train and test sets: 36 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 37 | 38 | # Convert class vectors to binary class matrices. 
39 | Y_train = np_utils.to_categorical(y_train, nb_classes) 40 | Y_test = np_utils.to_categorical(y_test, nb_classes) 41 | 42 | X_train = X_train.astype('float32') 43 | X_test = X_test.astype('float32') 44 | 45 | # preprocess input 46 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 47 | std = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 48 | 49 | print("Channel Mean : ", mean) 50 | print("Channel Std : ", std) 51 | 52 | X_train = (X_train - mean) / (std) 53 | X_test = (X_test - mean) / (std) 54 | 55 | # Learning rate finder callback setup 56 | num_samples = X_train.shape[0] 57 | 58 | # Exponential lr finder 59 | # USE THIS FOR A LARGE RANGE SEARCH 60 | # Uncomment the validation_data flag to reduce speed but get a better idea of the learning rate 61 | lr_finder = LRFinder(num_samples, batch_size, minimum_lr=1e-3, maximum_lr=10., 62 | lr_scale='exp', 63 | # validation_data=(X_test, Y_test), # use the validation data for losses 64 | validation_sample_rate=5, 65 | save_dir='weights/', verbose=True) 66 | 67 | # Linear lr finder 68 | # USE THIS FOR A CLOSE SEARCH 69 | # Uncomment the validation_data flag to reduce speed but get a better idea of the learning rate 70 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=5e-4, maximum_lr=1e-2, 71 | # lr_scale='linear', 72 | # validation_data=(X_test, y_test), # use the validation data for losses 73 | # validation_sample_rate=5, 74 | # save_dir='weights/', verbose=True) 75 | 76 | # plot the previous values if present 77 | LRFinder.plot_schedule_from_file('weights/', clip_beginning=10, clip_endding=5) 78 | 79 | # For training, the auxilary branch must be used to correctly train NASNet 80 | 81 | model = MiniMobileNetV2((img_rows, img_cols, img_channels), alpha=1.4, 82 | dropout=0, weights=None, classes=nb_classes) 83 | model.summary() 84 | 85 | optimizer = SGD(lr=0.1, momentum=0.9, nesterov=True) 86 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, 
metrics=['accuracy']) 87 | 88 | # model.load_weights(weights_file) 89 | 90 | if not data_augmentation: 91 | print('Not using data augmentation.') 92 | model.fit(X_train, Y_train, 93 | batch_size=batch_size, 94 | epochs=nb_epoch, 95 | validation_data=(X_test, Y_test), 96 | shuffle=True, 97 | verbose=1, 98 | callbacks=[lr_finder, model_checkpoint]) 99 | else: 100 | print('Using real-time data augmentation.') 101 | # This will do preprocessing and realtime data augmentation: 102 | datagen = ImageDataGenerator( 103 | featurewise_center=False, # set input mean to 0 over the dataset 104 | samplewise_center=False, # set each sample mean to 0 105 | featurewise_std_normalization=False, # divide inputs by std of the dataset 106 | samplewise_std_normalization=False, # divide each input by its std 107 | zca_whitening=False, # apply ZCA whitening 108 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 109 | width_shift_range=0, # randomly shift images horizontally (fraction of total width) 110 | height_shift_range=0, # randomly shift images vertically (fraction of total height) 111 | horizontal_flip=True, # randomly flip images 112 | vertical_flip=False) # randomly flip images 113 | 114 | # Compute quantities required for featurewise normalization 115 | # (std, mean, and principal components if ZCA whitening is applied). 116 | datagen.fit(X_train) 117 | 118 | # Fit the model on the batches generated by datagen.flow(). 
119 | model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True), 120 | steps_per_epoch=X_train.shape[0] // batch_size, 121 | validation_data=(X_test, Y_test), 122 | epochs=nb_epoch, verbose=1, 123 | callbacks=[lr_finder, model_checkpoint]) 124 | 125 | lr_finder.plot_schedule(clip_beginning=10, clip_endding=5) 126 | 127 | scores = model.evaluate(X_test, Y_test, batch_size=batch_size) 128 | for score, metric_name in zip(scores, model.metrics_names): 129 | print("%s : %0.4f" % (metric_name, score)) 130 | -------------------------------------------------------------------------------- /models/mobilenet/find_momentum_schedule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from keras example cifar10_cnn.py 3 | Train NASNet-CIFAR on the CIFAR10 small images dataset. 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from keras.datasets import cifar10 11 | from keras.preprocessing.image import ImageDataGenerator 12 | from keras.utils import np_utils 13 | from keras.callbacks import ModelCheckpoint 14 | from keras.optimizers import SGD 15 | from keras import backend as K 16 | 17 | from clr import LRFinder 18 | from models.mobilenet.mobilenets import MiniMobileNetV2 19 | 20 | plt.style.use('seaborn-white') 21 | 22 | batch_size = 128 23 | nb_classes = 10 24 | nb_epoch = 1 # Only finding lr 25 | data_augmentation = True 26 | 27 | # input image dimensions 28 | img_rows, img_cols = 32, 32 29 | # The CIFAR10 images are RGB. 30 | img_channels = 3 31 | 32 | # The data, shuffled and split between train and test sets: 33 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 34 | 35 | # Convert class vectors to binary class matrices. 
36 | Y_train = np_utils.to_categorical(y_train, nb_classes) 37 | Y_test = np_utils.to_categorical(y_test, nb_classes) 38 | 39 | X_train = X_train.astype('float32') 40 | X_test = X_test.astype('float32') 41 | 42 | # preprocess input 43 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 44 | std = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 45 | 46 | print("Channel Mean : ", mean) 47 | print("Channel Std : ", std) 48 | 49 | X_train = (X_train - mean) / (std) 50 | X_test = (X_test - mean) / (std) 51 | 52 | # Learning rate finder callback setup 53 | num_samples = X_train.shape[0] 54 | 55 | MOMENTUMS = [0.9, 0.95, 0.99] 56 | 57 | # for momentum in MOMENTUMS: 58 | # K.clear_session() 59 | # 60 | # # Learning rate range obtained from `find_lr_schedule.py` 61 | # # NOTE : Minimum is 10x smaller than the max found above ! 62 | # # NOTE : It is preferable to use the validation data here to get a correct value 63 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=0.002, maximum_lr=0.02, 64 | # validation_data=(X_test, Y_test), 65 | # validation_sample_rate=5, 66 | # lr_scale='linear', save_dir='weights/momentum/momentum-%s/' % str(momentum), 67 | # verbose=True) 68 | # 69 | # model = MiniMobileNetV2((img_rows, img_cols, img_channels), alpha=1.4, 70 | # dropout=0, weights=None, classes=nb_classes) 71 | # model.summary() 72 | # 73 | # # set the weight_decay here ! 
74 | # # lr doesnt matter as it will be over written by the callback 75 | # optimizer = SGD(lr=0.002, momentum=momentum, nesterov=True) 76 | # model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) 77 | # 78 | # # model.load_weights(weights_file) 79 | # 80 | # if not data_augmentation: 81 | # print('Not using data augmentation.') 82 | # model.fit(X_train, Y_train, 83 | # batch_size=batch_size, 84 | # epochs=nb_epoch, 85 | # validation_data=(X_test, Y_test), 86 | # shuffle=True, 87 | # verbose=1, 88 | # callbacks=[lr_finder]) 89 | # else: 90 | # print('Using real-time data augmentation.') 91 | # # This will do preprocessing and realtime data augmentation: 92 | # datagen = ImageDataGenerator( 93 | # featurewise_center=False, # set input mean to 0 over the dataset 94 | # samplewise_center=False, # set each sample mean to 0 95 | # featurewise_std_normalization=False, # divide inputs by std of the dataset 96 | # samplewise_std_normalization=False, # divide each input by its std 97 | # zca_whitening=False, # apply ZCA whitening 98 | # rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 99 | # width_shift_range=0, # randomly shift images horizontally (fraction of total width) 100 | # height_shift_range=0, # randomly shift images vertically (fraction of total height) 101 | # horizontal_flip=True, # randomly flip images 102 | # vertical_flip=False) # randomly flip images 103 | # 104 | # # Compute quantities required for featurewise normalization 105 | # # (std, mean, and principal components if ZCA whitening is applied). 106 | # datagen.fit(X_train) 107 | # 108 | # # Fit the model on the batches generated by datagen.flow(). 
109 | # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True), 110 | # steps_per_epoch=X_train.shape[0] // batch_size, 111 | # validation_data=(X_test, Y_test), 112 | # epochs=nb_epoch, verbose=1, 113 | # callbacks=[lr_finder]) 114 | 115 | # from plot we see, the model isnt impacted by the weight_decay very much at all 116 | # so we can use any of them. 117 | 118 | for momentum in MOMENTUMS: 119 | directory = 'weights/momentum/momentum-%s/' % str(momentum) 120 | 121 | losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5) 122 | plt.plot(lrs, losses, label='momentum=%0.2f' % momentum) 123 | 124 | plt.title("Momentum") 125 | plt.xlabel("Learning rate") 126 | plt.ylabel("Validation Loss") 127 | plt.legend() 128 | plt.show() 129 | -------------------------------------------------------------------------------- /models/mobilenet/find_weight_decay_schedule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from keras example cifar10_cnn.py 3 | Train NASNet-CIFAR on the CIFAR10 small images dataset. 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from keras.datasets import cifar10 11 | from keras.preprocessing.image import ImageDataGenerator 12 | from keras.utils import np_utils 13 | from keras.callbacks import ModelCheckpoint 14 | from keras.optimizers import SGD 15 | from keras import backend as K 16 | 17 | from clr import LRFinder 18 | from models.mobilenet.mobilenets import MiniMobileNetV2 19 | 20 | plt.style.use('seaborn-white') 21 | 22 | batch_size = 128 23 | nb_classes = 10 24 | nb_epoch = 1 # Only finding lr 25 | data_augmentation = True 26 | 27 | # input image dimensions 28 | img_rows, img_cols = 32, 32 29 | # The CIFAR10 images are RGB. 
30 | img_channels = 3 31 | 32 | # The data, shuffled and split between train and test sets: 33 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 34 | 35 | # Convert class vectors to binary class matrices. 36 | Y_train = np_utils.to_categorical(y_train, nb_classes) 37 | Y_test = np_utils.to_categorical(y_test, nb_classes) 38 | 39 | X_train = X_train.astype('float32') 40 | X_test = X_test.astype('float32') 41 | 42 | # preprocess input 43 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 44 | std = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 45 | 46 | print("Channel Mean : ", mean) 47 | print("Channel Std : ", std) 48 | 49 | X_train = (X_train - mean) / (std) 50 | X_test = (X_test - mean) / (std) 51 | 52 | # Learning rate finder callback setup 53 | num_samples = X_train.shape[0] 54 | 55 | # INITIAL WEIGHT DECAY FACTORS 56 | # WEIGHT_DECAY_FACTORS = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7] 57 | 58 | # FINEGRAINED WEIGHT DECAY FACTORS 59 | WEIGHT_DECAY_FACTORS = [1e-7, 3e-7, 3e-6] 60 | 61 | # for weight_decay in WEIGHT_DECAY_FACTORS: 62 | # K.clear_session() 63 | # 64 | # # Learning rate range obtained from `find_lr_schedule.py` 65 | # # NOTE : Minimum is 10x smaller than the max found above ! 66 | # # NOTE : It is preferable to use the validation data here to get a correct value 67 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=0.002, maximum_lr=0.02, 68 | # validation_data=(X_test, Y_test), 69 | # validation_sample_rate=5, 70 | # lr_scale='linear', save_dir='weights/weight_decay/weight_decay-%s/' % str(weight_decay), 71 | # verbose=True) 72 | # 73 | # # SETUP THE WEIGHT DECAY IN THE MODEL 74 | # model = MiniMobileNetV2((img_rows, img_cols, img_channels), alpha=1.4, 75 | # weight_decay=weight_decay, dropout=0, 76 | # weights=None, classes=nb_classes) 77 | # model.summary() 78 | # 79 | # # set the weight_decay here ! 
80 | # # lr doesnt matter as it will be over written by the callback 81 | # optimizer = SGD(lr=0.002, momentum=0.9, nesterov=True) 82 | # model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) 83 | # 84 | # # model.load_weights(weights_file) 85 | # 86 | # if not data_augmentation: 87 | # print('Not using data augmentation.') 88 | # model.fit(X_train, Y_train, 89 | # batch_size=batch_size, 90 | # epochs=nb_epoch, 91 | # validation_data=(X_test, Y_test), 92 | # shuffle=True, 93 | # verbose=1, 94 | # callbacks=[lr_finder]) 95 | # else: 96 | # print('Using real-time data augmentation.') 97 | # # This will do preprocessing and realtime data augmentation: 98 | # datagen = ImageDataGenerator( 99 | # featurewise_center=False, # set input mean to 0 over the dataset 100 | # samplewise_center=False, # set each sample mean to 0 101 | # featurewise_std_normalization=False, # divide inputs by std of the dataset 102 | # samplewise_std_normalization=False, # divide each input by its std 103 | # zca_whitening=False, # apply ZCA whitening 104 | # rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 105 | # width_shift_range=0, # randomly shift images horizontally (fraction of total width) 106 | # height_shift_range=0, # randomly shift images vertically (fraction of total height) 107 | # horizontal_flip=True, # randomly flip images 108 | # vertical_flip=False) # randomly flip images 109 | # 110 | # # Compute quantities required for featurewise normalization 111 | # # (std, mean, and principal components if ZCA whitening is applied). 112 | # datagen.fit(X_train) 113 | # 114 | # # Fit the model on the batches generated by datagen.flow(). 
115 | # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True), 116 | # steps_per_epoch=X_train.shape[0] // batch_size, 117 | # validation_data=(X_test, Y_test), 118 | # epochs=nb_epoch, verbose=1, 119 | # callbacks=[lr_finder]) 120 | 121 | # from plot we see, the model isnt impacted by the weight_decay very much at all 122 | # so we can use any of them. 123 | 124 | for weight_decay in WEIGHT_DECAY_FACTORS: 125 | directory = 'weights/weight_decay/weight_decay-%s/' % str(weight_decay) 126 | 127 | losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5) 128 | plt.plot(lrs, losses, label='weight_decay=%0.7f' % weight_decay) 129 | 130 | plt.title("Weight Decay") 131 | plt.xlabel("Learning rate") 132 | plt.ylabel("Validation Loss") 133 | plt.legend() 134 | plt.show() 135 | -------------------------------------------------------------------------------- /models/mobilenet/mobilenets.py: -------------------------------------------------------------------------------- 1 | """MobileNet v1 models for Keras. 2 | MobileNet is a general architecture and can be used for multiple use cases. 3 | Depending on the use case, it can use different input layer size and 4 | different width factors. This allows different width models to reduce 5 | the number of multiply-adds and thereby 6 | reduce inference cost on mobile devices. 7 | MobileNets support any input size greater than 32 x 32, with larger image sizes 8 | offering better performance. 9 | The number of parameters and number of multiply-adds 10 | can be modified by using the `alpha` parameter, 11 | which increases/decreases the number of filters in each layer. 12 | By altering the image size and `alpha` parameter, 13 | all 16 models from the paper can be built, with ImageNet weights provided. 14 | The paper demonstrates the performance of MobileNets using `alpha` values of 15 | 1.0 (also called 100 % MobileNet), 0.75, 0.5 and 0.25. 
16 | For each of these `alpha` values, weights for 4 different input image sizes 17 | are provided (224, 192, 160, 128). 18 | The following table describes the size and accuracy of the 100% MobileNet 19 | on size 224 x 224: 20 | ---------------------------------------------------------------------------- 21 | Width Multiplier (alpha) | ImageNet Acc | Multiply-Adds (M) | Params (M) 22 | ---------------------------------------------------------------------------- 23 | | 1.0 MobileNet-224 | 70.6 % | 529 | 4.2 | 24 | | 0.75 MobileNet-224 | 68.4 % | 325 | 2.6 | 25 | | 0.50 MobileNet-224 | 63.7 % | 149 | 1.3 | 26 | | 0.25 MobileNet-224 | 50.6 % | 41 | 0.5 | 27 | ---------------------------------------------------------------------------- 28 | The following table describes the performance of 29 | the 100 % MobileNet on various input sizes: 30 | ------------------------------------------------------------------------ 31 | Resolution | ImageNet Acc | Multiply-Adds (M) | Params (M) 32 | ------------------------------------------------------------------------ 33 | | 1.0 MobileNet-224 | 70.6 % | 529 | 4.2 | 34 | | 1.0 MobileNet-192 | 69.1 % | 529 | 4.2 | 35 | | 1.0 MobileNet-160 | 67.2 % | 529 | 4.2 | 36 | | 1.0 MobileNet-128 | 64.4 % | 529 | 4.2 | 37 | ------------------------------------------------------------------------ 38 | The weights for all 16 models are obtained and translated 39 | from Tensorflow checkpoints found at 40 | https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md 41 | # Reference 42 | - [MobileNets: Efficient Convolutional Neural Networks for 43 | Mobile Vision Applications](https://arxiv.org/pdf/1704.04861.pdf)) 44 | """ 45 | from __future__ import print_function 46 | from __future__ import absolute_import 47 | from __future__ import division 48 | 49 | import warnings 50 | import math 51 | 52 | from keras.models import Model 53 | from keras.layers import Input 54 | from keras.layers import Activation 55 | from keras.layers import 
Dropout 56 | from keras.layers import Reshape 57 | from keras.layers import BatchNormalization 58 | from keras.layers import GlobalAveragePooling2D 59 | from keras.layers import GlobalMaxPooling2D 60 | from keras.layers import Conv2D 61 | from keras.layers import DepthwiseConv2D 62 | from keras.layers import add 63 | from keras import initializers 64 | from keras import regularizers 65 | from keras import constraints 66 | from keras.utils import conv_utils 67 | from keras.utils.data_utils import get_file 68 | from keras.engine.topology import get_source_inputs 69 | from keras.engine import InputSpec 70 | # changed the next 3 lines from keras.applications to keras_applications for Keras version 2.2.4 71 | from keras_applications.imagenet_utils import _obtain_input_shape 72 | from keras_applications.inception_v3 import preprocess_input 73 | from keras_applications.imagenet_utils import decode_predictions 74 | from keras import backend as K 75 | 76 | import tensorflow as tf 77 | 78 | BASE_WEIGHT_PATH = '' 79 | BASE_WEIGHT_PATH_V2 = '' 80 | 81 | 82 | def relu6(x): 83 | return K.relu(x, max_value=6) 84 | 85 | 86 | def MiniMobileNetV2(input_shape=None, 87 | alpha=1.0, 88 | expansion_factor=6, 89 | depth_multiplier=1, 90 | dropout=0., 91 | weight_decay=0., 92 | include_top=True, 93 | weights=None, 94 | input_tensor=None, 95 | pooling=None, 96 | classes=10): 97 | """Instantiates the MobileNet architecture. 98 | MobileNet V2 is from the paper: 99 | - [Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation](https://arxiv.org/abs/1801.04381) 100 | 101 | Note that only TensorFlow is supported for now, 102 | therefore it only works with the data format 103 | `image_data_format='channels_last'` in your Keras config 104 | at `~/.keras/keras.json`. 105 | To load a MobileNet model via `load_model`, import the custom 106 | objects `relu6` and `DepthwiseConv2D` and pass them to the 107 | `custom_objects` parameter. 108 | E.g. 
109 | model = load_model('mobilenet.h5', custom_objects={ 110 | 'relu6': mobilenet.relu6, 111 | 'DepthwiseConv2D': mobilenet.DepthwiseConv2D}) 112 | # Arguments 113 | input_shape: optional shape tuple, only to be specified 114 | if `include_top` is False (otherwise the input shape 115 | has to be `(224, 224, 3)` (with `channels_last` data format) 116 | or (3, 224, 224) (with `channels_first` data format). 117 | It should have exactly 3 inputs channels, 118 | and width and height should be no smaller than 32. 119 | E.g. `(200, 200, 3)` would be one valid value. 120 | alpha: controls the width of the network. 121 | - If `alpha` < 1.0, proportionally decreases the number 122 | of filters in each layer. 123 | - If `alpha` > 1.0, proportionally increases the number 124 | of filters in each layer. 125 | - If `alpha` = 1, default number of filters from the paper 126 | are used at each layer. 127 | expansion_factor: controls the expansion of the internal bottleneck 128 | blocks. Should be a positive integer >= 1 129 | depth_multiplier: depth multiplier for depthwise convolution 130 | (also called the resolution multiplier) 131 | dropout: dropout rate 132 | weight_decay: Weight decay factor. 133 | include_top: whether to include the fully-connected 134 | layer at the top of the network. 135 | weights: `None` (random initialization) or 136 | `imagenet` (ImageNet weights) 137 | input_tensor: optional Keras tensor (i.e. output of 138 | `layers.Input()`) 139 | to use as image input for the model. 140 | pooling: Optional pooling mode for feature extraction 141 | when `include_top` is `False`. 142 | - `None` means that the output of the model 143 | will be the 4D tensor output of the 144 | last convolutional layer. 145 | - `avg` means that global average pooling 146 | will be applied to the output of the 147 | last convolutional layer, and thus 148 | the output of the model will be a 149 | 2D tensor. 150 | - `max` means that global max pooling will 151 | be applied. 
152 | classes: optional number of classes to classify images 153 | into, only to be specified if `include_top` is True, and 154 | if no `weights` argument is specified. 155 | # Returns 156 | A Keras model instance. 157 | # Raises 158 | ValueError: in case of invalid argument for `weights`, 159 | or invalid input shape. 160 | RuntimeError: If attempting to run this model with a 161 | backend that does not support separable convolutions. 162 | """ 163 | 164 | if K.backend() != 'tensorflow': 165 | raise RuntimeError('Only Tensorflow backend is currently supported, ' 166 | 'as other backends do not support ' 167 | 'depthwise convolution.') 168 | 169 | if weights not in {'imagenet', None}: 170 | raise ValueError('The `weights` argument should be either ' 171 | '`None` (random initialization) or `imagenet` ' 172 | '(pre-training on ImageNet).') 173 | 174 | if weights == 'imagenet' and include_top and classes != 1000: 175 | raise ValueError('If using `weights` as ImageNet with `include_top` ' 176 | 'as true, `classes` should be 1000') 177 | 178 | # Determine proper input shape and default size. 
179 | if input_shape is None: 180 | default_size = 224 181 | else: 182 | if K.image_data_format() == 'channels_first': 183 | rows = input_shape[1] 184 | cols = input_shape[2] 185 | else: 186 | rows = input_shape[0] 187 | cols = input_shape[1] 188 | 189 | if rows == cols and rows in [96, 128, 160, 192, 224]: 190 | default_size = rows 191 | else: 192 | default_size = 224 193 | 194 | input_shape = _obtain_input_shape(input_shape, 195 | default_size=default_size, 196 | min_size=32, 197 | data_format=K.image_data_format(), 198 | require_flatten=include_top or weights) 199 | if K.image_data_format() == 'channels_last': 200 | row_axis, col_axis = (0, 1) 201 | else: 202 | row_axis, col_axis = (1, 2) 203 | rows = input_shape[row_axis] 204 | cols = input_shape[col_axis] 205 | 206 | if weights == 'imagenet': 207 | if depth_multiplier != 1: 208 | raise ValueError('If imagenet weights are being loaded, ' 209 | 'depth multiplier must be 1') 210 | 211 | if alpha not in [0.35, 0.50, 0.75, 1.0, 1.3, 1.4]: 212 | raise ValueError('If imagenet weights are being loaded, ' 213 | 'alpha can be one of' 214 | '`0.35`, `0.50`, `0.75`, `1.0`, `1.3` and `1.4` only.') 215 | 216 | if rows != cols or rows not in [96, 128, 160, 192, 224]: 217 | raise ValueError('If imagenet weights are being loaded, ' 218 | 'input must have a static square shape (one of ' 219 | '(06, 96), (128,128), (160,160), (192,192), or ' 220 | '(224, 224)).Input shape provided = %s' % (input_shape,)) 221 | 222 | if K.image_data_format() != 'channels_last': 223 | warnings.warn('The MobileNet family of models is only available ' 224 | 'for the input data format "channels_last" ' 225 | '(width, height, channels). ' 226 | 'However your settings specify the default ' 227 | 'data format "channels_first" (channels, width, height).' 228 | ' You should set `image_data_format="channels_last"` ' 229 | 'in your Keras config located at ~/.keras/keras.json. 
' 230 | 'The model being returned right now will expect inputs ' 231 | 'to follow the "channels_last" data format.') 232 | K.set_image_data_format('channels_last') 233 | old_data_format = 'channels_first' 234 | else: 235 | old_data_format = None 236 | 237 | if input_tensor is None: 238 | img_input = Input(shape=input_shape) 239 | else: 240 | if not K.is_keras_tensor(input_tensor): 241 | img_input = Input(tensor=input_tensor, shape=input_shape) 242 | else: 243 | img_input = input_tensor 244 | 245 | x = _conv_block(img_input, 32, alpha, bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay) 246 | x = _depthwise_conv_block_v2(x, 16, alpha, 1, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99, 247 | weight_decay=weight_decay, block_id=1) 248 | 249 | x = _depthwise_conv_block_v2(x, 24, alpha, expansion_factor, depth_multiplier, block_id=2, 250 | bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, strides=(2, 2)) 251 | x = _depthwise_conv_block_v2(x, 24, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99, 252 | weight_decay=weight_decay, block_id=3) 253 | 254 | x = _depthwise_conv_block_v2(x, 32, alpha, expansion_factor, depth_multiplier, block_id=4, 255 | bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay) 256 | x = _depthwise_conv_block_v2(x, 32, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99, 257 | weight_decay=weight_decay, block_id=5) 258 | x = _depthwise_conv_block_v2(x, 32, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99, 259 | weight_decay=weight_decay, block_id=6) 260 | 261 | x = _depthwise_conv_block_v2(x, 64, alpha, expansion_factor, depth_multiplier, block_id=7, 262 | bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, strides=(2, 2)) 263 | x = _depthwise_conv_block_v2(x, 64, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99, 264 | weight_decay=weight_decay, block_id=8) 265 | x = _depthwise_conv_block_v2(x, 64, 
alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99, 266 | weight_decay=weight_decay, block_id=9) 267 | x = _depthwise_conv_block_v2(x, 64, alpha, expansion_factor, depth_multiplier, bn_epsilon=1e-3, bn_momentum=0.99, 268 | weight_decay=weight_decay, block_id=10) 269 | 270 | if alpha <= 1.0: 271 | penultimate_filters = 1280 272 | else: 273 | penultimate_filters = int(1280 * alpha) 274 | 275 | x = _conv_block(x, penultimate_filters, alpha=1.0, kernel=(1, 1), bn_epsilon=1e-3, bn_momentum=0.99, 276 | block_id=18) 277 | 278 | if include_top: 279 | if K.image_data_format() == 'channels_first': 280 | shape = (penultimate_filters, 1, 1) 281 | else: 282 | shape = (1, 1, penultimate_filters) 283 | 284 | x = GlobalAveragePooling2D()(x) 285 | x = Reshape(shape, name='reshape_1')(x) 286 | x = Dropout(dropout, name='dropout')(x) 287 | x = Conv2D(classes, (1, 1), 288 | kernel_initializer=initializers.he_normal(), 289 | kernel_regularizer=regularizers.l2(weight_decay), 290 | padding='same', name='conv_preds')(x) 291 | x = Activation('softmax', name='act_softmax')(x) 292 | x = Reshape((classes,), name='reshape_2')(x) 293 | else: 294 | if pooling == 'avg': 295 | x = GlobalAveragePooling2D()(x) 296 | elif pooling == 'max': 297 | x = GlobalMaxPooling2D()(x) 298 | 299 | # Ensure that the model takes into account 300 | # any potential predecessors of `input_tensor`. 301 | if input_tensor is not None: 302 | inputs = get_source_inputs(input_tensor) 303 | else: 304 | inputs = img_input 305 | 306 | # Create model. 
307 | model = Model(inputs, x, name='mobilenetV2_%0.2f_%s' % (alpha, rows)) 308 | 309 | # load weights 310 | if weights == 'imagenet': 311 | if K.image_data_format() == 'channels_first': 312 | raise ValueError('Weights for "channels_last" format ' 313 | 'are not available.') 314 | if alpha == 1.0: 315 | alpha_text = '1_0' 316 | elif alpha == 1.3: 317 | alpha_text = '1_3' 318 | elif alpha == 1.4: 319 | alpha_text = '1_4' 320 | elif alpha == 0.75: 321 | alpha_text = '7_5' 322 | elif alpha == 0.50: 323 | alpha_text = '5_0' 324 | else: 325 | alpha_text = '3_5' 326 | 327 | if include_top: 328 | model_name = 'mobilenet_v2_%s_%d_tf.h5' % (alpha_text, rows) 329 | weigh_path = BASE_WEIGHT_PATH_V2 + model_name 330 | weights_path = get_file(model_name, 331 | weigh_path, 332 | cache_subdir='models') 333 | else: 334 | model_name = 'mobilenet_v2_%s_%d_tf_no_top.h5' % (alpha_text, rows) 335 | weigh_path = BASE_WEIGHT_PATH_V2 + model_name 336 | weights_path = get_file(model_name, 337 | weigh_path, 338 | cache_subdir='models') 339 | model.load_weights(weights_path) 340 | 341 | if old_data_format: 342 | K.set_image_data_format(old_data_format) 343 | return model 344 | 345 | 346 | # taken from https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/conv_blocks.py 347 | def _make_divisible(v, divisor=8, min_value=8): 348 | if min_value is None: 349 | min_value = divisor 350 | 351 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 352 | # Make sure that round down does not go down by more than 10%. 353 | if new_v < 0.9 * v: 354 | new_v += divisor 355 | return new_v 356 | 357 | 358 | def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1), bn_epsilon=1e-3, 359 | bn_momentum=0.99, weight_decay=0., block_id=1): 360 | """Adds an initial convolution layer (with batch normalization and relu6). 
361 | # Arguments 362 | inputs: Input tensor of shape `(rows, cols, 3)` 363 | (with `channels_last` data format) or 364 | (3, rows, cols) (with `channels_first` data format). 365 | It should have exactly 3 inputs channels, 366 | and width and height should be no smaller than 32. 367 | E.g. `(224, 224, 3)` would be one valid value. 368 | filters: Integer, the dimensionality of the output space 369 | (i.e. the number output of filters in the convolution). 370 | alpha: controls the width of the network. 371 | - If `alpha` < 1.0, proportionally decreases the number 372 | of filters in each layer. 373 | - If `alpha` > 1.0, proportionally increases the number 374 | of filters in each layer. 375 | - If `alpha` = 1, default number of filters from the paper 376 | are used at each layer. 377 | kernel: An integer or tuple/list of 2 integers, specifying the 378 | width and height of the 2D convolution window. 379 | Can be a single integer to specify the same value for 380 | all spatial dimensions. 381 | strides: An integer or tuple/list of 2 integers, 382 | specifying the strides of the convolution along the width and height. 383 | Can be a single integer to specify the same value for 384 | all spatial dimensions. 385 | Specifying any stride value != 1 is incompatible with specifying 386 | any `dilation_rate` value != 1. 387 | bn_epsilon: Epsilon value for BatchNormalization 388 | bn_momentum: Momentum value for BatchNormalization 389 | # Input shape 390 | 4D tensor with shape: 391 | `(samples, channels, rows, cols)` if data_format='channels_first' 392 | or 4D tensor with shape: 393 | `(samples, rows, cols, channels)` if data_format='channels_last'. 394 | # Output shape 395 | 4D tensor with shape: 396 | `(samples, filters, new_rows, new_cols)` if data_format='channels_first' 397 | or 4D tensor with shape: 398 | `(samples, new_rows, new_cols, filters)` if data_format='channels_last'. 399 | `rows` and `cols` values might have changed due to stride. 
400 | # Returns 401 | Output tensor of block. 402 | """ 403 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 404 | filters = filters * alpha 405 | filters = _make_divisible(filters) 406 | x = Conv2D(filters, kernel, 407 | padding='same', 408 | use_bias=False, 409 | strides=strides, 410 | kernel_initializer=initializers.he_normal(), 411 | kernel_regularizer=regularizers.l2(weight_decay), 412 | name='conv%d' % block_id)(inputs) 413 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon, 414 | name='conv%d_bn' % block_id)(x) 415 | return Activation(relu6, name='conv%d_relu' % block_id)(x) 416 | 417 | 418 | def _depthwise_conv_block_v2(inputs, pointwise_conv_filters, alpha, expansion_factor, 419 | depth_multiplier=1, strides=(1, 1), bn_epsilon=1e-3, 420 | bn_momentum=0.99, weight_decay=0.0, block_id=1): 421 | """Adds a depthwise convolution block V2. 422 | A depthwise convolution V2 block consists of a depthwise conv, 423 | batch normalization, relu6, pointwise convolution, 424 | batch normalization and relu6 activation. 425 | # Arguments 426 | inputs: Input tensor of shape `(rows, cols, channels)` 427 | (with `channels_last` data format) or 428 | (channels, rows, cols) (with `channels_first` data format). 429 | pointwise_conv_filters: Integer, the dimensionality of the output space 430 | (i.e. the number output of filters in the pointwise convolution). 431 | alpha: controls the width of the network. 432 | - If `alpha` < 1.0, proportionally decreases the number 433 | of filters in each layer. 434 | - If `alpha` > 1.0, proportionally increases the number 435 | of filters in each layer. 436 | - If `alpha` = 1, default number of filters from the paper 437 | are used at each layer. 438 | expansion_factor: controls the expansion of the internal bottleneck 439 | blocks. Should be a positive integer >= 1 440 | depth_multiplier: The number of depthwise convolution output channels 441 | for each input channel. 
442 | The total number of depthwise convolution output 443 | channels will be equal to `filters_in * depth_multiplier`. 444 | strides: An integer or tuple/list of 2 integers, 445 | specifying the strides of the convolution along the width and height. 446 | Can be a single integer to specify the same value for 447 | all spatial dimensions. 448 | Specifying any stride value != 1 is incompatible with specifying 449 | any `dilation_rate` value != 1. 450 | bn_epsilon: Epsilon value for BatchNormalization 451 | bn_momentum: Momentum value for BatchNormalization 452 | block_id: Integer, a unique identification designating the block number. 453 | # Input shape 454 | 4D tensor with shape: 455 | `(batch, channels, rows, cols)` if data_format='channels_first' 456 | or 4D tensor with shape: 457 | `(batch, rows, cols, channels)` if data_format='channels_last'. 458 | # Output shape 459 | 4D tensor with shape: 460 | `(batch, filters, new_rows, new_cols)` if data_format='channels_first' 461 | or 4D tensor with shape: 462 | `(batch, new_rows, new_cols, filters)` if data_format='channels_last'. 463 | `rows` and `cols` values might have changed due to stride. 464 | # Returns 465 | Output tensor of block. 
466 | """ 467 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 468 | input_shape = K.int_shape(inputs) 469 | depthwise_conv_filters = _make_divisible(input_shape[channel_axis] * expansion_factor) 470 | pointwise_conv_filters = _make_divisible(pointwise_conv_filters * alpha) 471 | 472 | if depthwise_conv_filters > input_shape[channel_axis]: 473 | x = Conv2D(depthwise_conv_filters, (1, 1), 474 | padding='same', 475 | use_bias=False, 476 | strides=(1, 1), 477 | kernel_initializer=initializers.he_normal(), 478 | kernel_regularizer=regularizers.l2(weight_decay), 479 | name='conv_expand_%d' % block_id)(inputs) 480 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon, 481 | name='conv_expand_%d_bn' % block_id)(x) 482 | x = Activation(relu6, name='conv_expand_%d_relu' % block_id)(x) 483 | else: 484 | x = inputs 485 | 486 | x = DepthwiseConv2D((3, 3), 487 | padding='same', 488 | depth_multiplier=depth_multiplier, 489 | strides=strides, 490 | use_bias=False, 491 | depthwise_initializer=initializers.he_normal(), 492 | depthwise_regularizer=regularizers.l2(weight_decay), 493 | name='conv_dw_%d' % block_id)(x) 494 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon, 495 | name='conv_dw_%d_bn' % block_id)(x) 496 | x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x) 497 | 498 | x = Conv2D(pointwise_conv_filters, (1, 1), 499 | padding='same', 500 | use_bias=False, 501 | strides=(1, 1), 502 | kernel_initializer=initializers.he_normal(), 503 | kernel_regularizer=regularizers.l2(weight_decay), 504 | name='conv_pw_%d' % block_id)(x) 505 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon, 506 | name='conv_pw_%d_bn' % block_id)(x) 507 | 508 | if strides == (2, 2): 509 | return x 510 | else: 511 | if input_shape[channel_axis] == pointwise_conv_filters: 512 | 513 | x = add([inputs, x]) 514 | 515 | return x 516 | 517 | 518 | if __name__ == '__main__': 519 | 
import tensorflow as tf 520 | from keras import backend as K 521 | 522 | run_metadata = tf.RunMetadata() 523 | 524 | with tf.Session(graph=tf.Graph()) as sess: 525 | K.set_session(sess) 526 | 527 | model = MiniMobileNetV2(input_tensor=tf.placeholder('float32', shape=(1, 224, 224, 3)), alpha=1.0) 528 | opt = tf.profiler.ProfileOptionBuilder.float_operation() 529 | flops = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt) 530 | 531 | opt = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter() 532 | param_count = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt) 533 | 534 | print('flops:', flops.total_float_ops) 535 | print('param count:', param_count.total_parameters) 536 | 537 | model.summary() 538 | -------------------------------------------------------------------------------- /models/mobilenet/train_cifar_10.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from keras example cifar10_cnn.py 3 | Train NASNet-CIFAR on the CIFAR10 small images dataset. 4 | """ 5 | from __future__ import print_function 6 | import os 7 | 8 | from keras.datasets import cifar10 9 | from keras.preprocessing.image import ImageDataGenerator 10 | from keras.utils import np_utils 11 | from keras.callbacks import ModelCheckpoint 12 | from keras.optimizers import SGD 13 | import numpy as np 14 | 15 | from clr import OneCycleLR 16 | from models.mobilenet.mobilenets import MiniMobileNetV2 17 | 18 | if not os.path.exists('weights/'): 19 | os.makedirs('weights/') 20 | 21 | weights_file = 'weights/mobilenet_v2.h5' 22 | model_checkpoint = ModelCheckpoint( 23 | weights_file, 24 | monitor='val_acc', 25 | save_best_only=True, 26 | save_weights_only=True, 27 | mode='max') 28 | batch_size = 128 29 | nb_classes = 10 30 | nb_epoch = 100 # Only finding lr 31 | data_augmentation = True 32 | 33 | # input image dimensions 34 | img_rows, img_cols = 32, 32 35 | # The CIFAR10 images are RGB. 
36 | img_channels = 3 37 | 38 | # The data, shuffled and split between train and test sets: 39 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 40 | 41 | # Convert class vectors to binary class matrices. 42 | Y_train = np_utils.to_categorical(y_train, nb_classes) 43 | Y_test = np_utils.to_categorical(y_test, nb_classes) 44 | 45 | X_train = X_train.astype('float32') 46 | X_test = X_test.astype('float32') 47 | 48 | # preprocess input 49 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 50 | std = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 51 | 52 | print("Channel Mean : ", mean) 53 | print("Channel Std : ", std) 54 | 55 | X_train = (X_train - mean) / (std) 56 | X_test = (X_test - mean) / (std) 57 | 58 | # Learning rate finder callback setup 59 | num_samples = X_train.shape[0] 60 | 61 | lr_manager = OneCycleLR(max_lr=0.02, maximum_momentum=0.9, verbose=True) 62 | 63 | # For training, the auxilary branch must be used to correctly train NASNet 64 | model = MiniMobileNetV2((img_rows, img_cols, img_channels), 65 | alpha=1.4, 66 | weight_decay=1e-6, 67 | weights=None, 68 | classes=nb_classes) 69 | model.summary() 70 | 71 | # These values will be overridden by the above callback 72 | optimizer = SGD(lr=0.002, momentum=0.9, nesterov=True) 73 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) 74 | 75 | model.load_weights(weights_file) 76 | 77 | if not data_augmentation: 78 | print('Not using data augmentation.') 79 | model.fit( 80 | X_train, 81 | Y_train, 82 | batch_size=batch_size, 83 | epochs=nb_epoch, 84 | validation_data=(X_test, Y_test), 85 | shuffle=True, 86 | verbose=1, 87 | callbacks=[lr_manager, model_checkpoint]) 88 | else: 89 | print('Using real-time data augmentation.') 90 | # This will do preprocessing and realtime data augmentation: 91 | datagen = ImageDataGenerator( 92 | featurewise_center=False, # set input mean to 0 over the dataset 93 | 
samplewise_center=False, # set each sample mean to 0 94 | featurewise_std_normalization=False, # divide inputs by std of the dataset 95 | samplewise_std_normalization=False, # divide each input by its std 96 | zca_whitening=False, # apply ZCA whitening 97 | # randomly rotate images in the range (degrees, 0 to 180) 98 | rotation_range=0, 99 | # randomly shift images horizontally (fraction of total width) 100 | width_shift_range=0, 101 | # randomly shift images vertically (fraction of total height) 102 | height_shift_range=0, 103 | horizontal_flip=True, # randomly flip images 104 | vertical_flip=False) # randomly flip images 105 | 106 | # Compute quantities required for featurewise normalization 107 | # (std, mean, and principal components if ZCA whitening is applied). 108 | datagen.fit(X_train) 109 | 110 | # Fit the model on the batches generated by datagen.flow(). 111 | model.fit_generator( 112 | datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True), 113 | steps_per_epoch=X_train.shape[0] // batch_size, 114 | validation_data=(X_test, Y_test), 115 | epochs=nb_epoch, 116 | verbose=1, 117 | callbacks=[lr_manager, model_checkpoint]) 118 | 119 | scores = model.evaluate(X_test, Y_test, batch_size=batch_size) 120 | for score, metric_name in zip(scores, model.metrics_names): 121 | print("%s : %0.4f" % (metric_name, score)) 122 | -------------------------------------------------------------------------------- /models/mobilenet/weights/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/losses.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/lrs.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/lrs.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/mobilenet_v2 - 9033.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/mobilenet_v2 - 9033.h5 -------------------------------------------------------------------------------- /models/mobilenet/weights/mobilenet_v2.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/mobilenet_v2.h5 -------------------------------------------------------------------------------- /models/mobilenet/weights/momentum/momentum-0.9/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.9/losses.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/momentum/momentum-0.9/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.9/lrs.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/momentum/momentum-0.95/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.95/losses.npy 
-------------------------------------------------------------------------------- /models/mobilenet/weights/momentum/momentum-0.95/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.95/lrs.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/momentum/momentum-0.99/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.99/losses.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/momentum/momentum-0.99/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/momentum/momentum-0.99/lrs.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-1e-05/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-05/losses.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-1e-05/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-05/lrs.npy 
-------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-1e-06/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-06/losses.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-1e-06/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-06/lrs.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-1e-07/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-07/losses.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-1e-07/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-1e-07/lrs.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-3e-05/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-3e-05/losses.npy 
-------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-3e-05/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-3e-05/lrs.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-3e-06/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-3e-06/losses.npy -------------------------------------------------------------------------------- /models/mobilenet/weights/weight_decay/weight_decay-3e-06/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/mobilenet/weights/weight_decay/weight_decay-3e-06/lrs.npy -------------------------------------------------------------------------------- /models/small/find_lr_schedule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from keras example cifar10_cnn.py 3 | Train NASNet-CIFAR on the CIFAR10 small images dataset. 
4 | """ 5 | from __future__ import print_function 6 | import os 7 | 8 | from keras.datasets import cifar10 9 | from keras.preprocessing.image import ImageDataGenerator 10 | from keras.utils import np_utils 11 | from keras.callbacks import ModelCheckpoint 12 | from keras.optimizers import SGD 13 | import numpy as np 14 | 15 | from clr import LRFinder 16 | from models.small.model import MiniVGG 17 | 18 | if not os.path.exists('weights/'): 19 | os.makedirs('weights/') 20 | 21 | weights_file = 'weights/small_vgg_v2_schedule.h5' 22 | model_checkpoint = ModelCheckpoint(weights_file, monitor='val_acc', save_best_only=True, 23 | save_weights_only=True, mode='max') 24 | 25 | batch_size = 128 26 | nb_classes = 10 27 | nb_epoch = 1 # Only finding lr 28 | data_augmentation = True 29 | 30 | # input image dimensions 31 | img_rows, img_cols = 32, 32 32 | # The CIFAR10 images are RGB. 33 | img_channels = 3 34 | 35 | # The data, shuffled and split between train and test sets: 36 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 37 | 38 | # Convert class vectors to binary class matrices. 
39 | Y_train = np_utils.to_categorical(y_train, nb_classes) 40 | Y_test = np_utils.to_categorical(y_test, nb_classes) 41 | 42 | X_train = X_train.astype('float32') 43 | X_test = X_test.astype('float32') 44 | 45 | # preprocess input 46 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 47 | std = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 48 | 49 | print("Channel Mean : ", mean) 50 | print("Channel Std : ", std) 51 | 52 | X_train = (X_train - mean) / (std) 53 | X_test = (X_test - mean) / (std) 54 | 55 | # Learning rate finder callback setup 56 | num_samples = X_train.shape[0] 57 | 58 | # Exponential lr finder 59 | # USE THIS FOR A LARGE RANGE SEARCH 60 | # Uncomment the validation_data flag to reduce speed but get a better idea of the learning rate 61 | lr_finder = LRFinder(num_samples, batch_size, minimum_lr=1e-5, maximum_lr=10., 62 | lr_scale='exp', 63 | validation_data=(X_test, Y_test), # use the validation data for losses 64 | validation_sample_rate=5, 65 | save_dir='weights/', verbose=True) 66 | 67 | # Linear lr finder 68 | # USE THIS FOR A CLOSE SEARCH 69 | # Uncomment the validation_data flag to reduce speed but get a better idea of the learning rate 70 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=5e-4, maximum_lr=1e-2, 71 | # lr_scale='linear', 72 | # validation_data=(X_test, y_test), # use the validation data for losses 73 | # validation_sample_rate=5, 74 | # save_dir='weights/', verbose=True) 75 | 76 | # plot the previous values if present 77 | LRFinder.plot_schedule_from_file('weights/', clip_beginning=10, clip_endding=5) 78 | 79 | # For training, the auxilary branch must be used to correctly train NASNet 80 | 81 | model = MiniVGG((img_rows, img_cols, img_channels), 82 | dropout=0, weights=None, classes=nb_classes) 83 | model.summary() 84 | 85 | optimizer = SGD(lr=0.1, momentum=0.9, nesterov=True) 86 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) 87 
| 88 | # model.load_weights(weights_file) 89 | 90 | if not data_augmentation: 91 | print('Not using data augmentation.') 92 | model.fit(X_train, Y_train, 93 | batch_size=batch_size, 94 | epochs=nb_epoch, 95 | validation_data=(X_test, Y_test), 96 | shuffle=True, 97 | verbose=1, 98 | callbacks=[lr_finder, model_checkpoint]) 99 | else: 100 | print('Using real-time data augmentation.') 101 | # This will do preprocessing and realtime data augmentation: 102 | datagen = ImageDataGenerator( 103 | featurewise_center=False, # set input mean to 0 over the dataset 104 | samplewise_center=False, # set each sample mean to 0 105 | featurewise_std_normalization=False, # divide inputs by std of the dataset 106 | samplewise_std_normalization=False, # divide each input by its std 107 | zca_whitening=False, # apply ZCA whitening 108 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 109 | width_shift_range=0, # randomly shift images horizontally (fraction of total width) 110 | height_shift_range=0, # randomly shift images vertically (fraction of total height) 111 | horizontal_flip=True, # randomly flip images 112 | vertical_flip=False) # randomly flip images 113 | 114 | # Compute quantities required for featurewise normalization 115 | # (std, mean, and principal components if ZCA whitening is applied). 116 | datagen.fit(X_train) 117 | 118 | # Fit the model on the batches generated by datagen.flow(). 
119 | model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True), 120 | steps_per_epoch=X_train.shape[0] // batch_size, 121 | validation_data=(X_test, Y_test), 122 | epochs=nb_epoch, verbose=1, 123 | callbacks=[lr_finder, model_checkpoint]) 124 | 125 | lr_finder.plot_schedule(clip_beginning=10, clip_endding=5) 126 | 127 | scores = model.evaluate(X_test, Y_test, batch_size=batch_size) 128 | for score, metric_name in zip(scores, model.metrics_names): 129 | print("%s : %0.4f" % (metric_name, score)) 130 | -------------------------------------------------------------------------------- /models/small/find_momentum_schedule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from keras example cifar10_cnn.py 3 | Train NASNet-CIFAR on the CIFAR10 small images dataset. 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from keras.datasets import cifar10 11 | from keras.preprocessing.image import ImageDataGenerator 12 | from keras.utils import np_utils 13 | from keras.callbacks import ModelCheckpoint 14 | from keras.optimizers import SGD 15 | from keras import backend as K 16 | 17 | from clr import LRFinder 18 | from models.small.model import MiniVGG 19 | 20 | plt.style.use('seaborn-white') 21 | 22 | batch_size = 128 23 | nb_classes = 10 24 | nb_epoch = 1 # Only finding lr 25 | data_augmentation = True 26 | 27 | # input image dimensions 28 | img_rows, img_cols = 32, 32 29 | # The CIFAR10 images are RGB. 30 | img_channels = 3 31 | 32 | # The data, shuffled and split between train and test sets: 33 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 34 | 35 | # Convert class vectors to binary class matrices. 
36 | Y_train = np_utils.to_categorical(y_train, nb_classes) 37 | Y_test = np_utils.to_categorical(y_test, nb_classes) 38 | 39 | X_train = X_train.astype('float32') 40 | X_test = X_test.astype('float32') 41 | 42 | # preprocess input 43 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 44 | std = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 45 | 46 | print("Channel Mean : ", mean) 47 | print("Channel Std : ", std) 48 | 49 | X_train = (X_train - mean) / (std) 50 | X_test = (X_test - mean) / (std) 51 | 52 | # Learning rate finder callback setup 53 | num_samples = X_train.shape[0] 54 | 55 | MOMENTUMS = [0.9, 0.95, 0.99] 56 | 57 | # for momentum in MOMENTUMS: 58 | # K.clear_session() 59 | # 60 | # # Learning rate range obtained from `find_lr_schedule.py` 61 | # # NOTE : Minimum is 10x smaller than the max found above ! 62 | # # NOTE : It is preferable to use the validation data here to get a correct value 63 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=0.00125, maximum_lr=0.0125, 64 | # validation_data=(X_test, Y_test), 65 | # validation_sample_rate=5, 66 | # lr_scale='linear', save_dir='weights/momentum/momentum-%s/' % str(momentum), 67 | # verbose=True) 68 | # 69 | # model = MiniVGG((img_rows, img_cols, img_channels), 70 | # dropout=0, weights=None, classes=nb_classes) 71 | # model.summary() 72 | # 73 | # # set the weight_decay here ! 
74 | # # lr doesnt matter as it will be over written by the callback 75 | # optimizer = SGD(lr=0.00125, momentum=momentum, nesterov=True) 76 | # model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) 77 | # 78 | # # model.load_weights(weights_file) 79 | # 80 | # if not data_augmentation: 81 | # print('Not using data augmentation.') 82 | # model.fit(X_train, Y_train, 83 | # batch_size=batch_size, 84 | # epochs=nb_epoch, 85 | # validation_data=(X_test, Y_test), 86 | # shuffle=True, 87 | # verbose=1, 88 | # callbacks=[lr_finder]) 89 | # else: 90 | # print('Using real-time data augmentation.') 91 | # # This will do preprocessing and realtime data augmentation: 92 | # datagen = ImageDataGenerator( 93 | # featurewise_center=False, # set input mean to 0 over the dataset 94 | # samplewise_center=False, # set each sample mean to 0 95 | # featurewise_std_normalization=False, # divide inputs by std of the dataset 96 | # samplewise_std_normalization=False, # divide each input by its std 97 | # zca_whitening=False, # apply ZCA whitening 98 | # rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 99 | # width_shift_range=0, # randomly shift images horizontally (fraction of total width) 100 | # height_shift_range=0, # randomly shift images vertically (fraction of total height) 101 | # horizontal_flip=True, # randomly flip images 102 | # vertical_flip=False) # randomly flip images 103 | # 104 | # # Compute quantities required for featurewise normalization 105 | # # (std, mean, and principal components if ZCA whitening is applied). 106 | # datagen.fit(X_train) 107 | # 108 | # # Fit the model on the batches generated by datagen.flow(). 
109 | # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True), 110 | # steps_per_epoch=X_train.shape[0] // batch_size, 111 | # validation_data=(X_test, Y_test), 112 | # epochs=nb_epoch, verbose=1, 113 | # callbacks=[lr_finder]) 114 | 115 | # from plot we see, the model isnt impacted by the weight_decay very much at all 116 | # so we can use any of them. 117 | 118 | for momentum in MOMENTUMS: 119 | directory = 'weights/momentum/momentum-%s/' % str(momentum) 120 | 121 | losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5) 122 | plt.plot(lrs, losses, label='momentum=%0.2f' % momentum) 123 | 124 | plt.title("Momentum") 125 | plt.xlabel("Learning rate") 126 | plt.ylabel("Validation Loss") 127 | plt.legend() 128 | plt.show() 129 | -------------------------------------------------------------------------------- /models/small/find_weight_decay_schedule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from keras example cifar10_cnn.py 3 | Train NASNet-CIFAR on the CIFAR10 small images dataset. 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from keras.datasets import cifar10 11 | from keras.preprocessing.image import ImageDataGenerator 12 | from keras.utils import np_utils 13 | from keras.callbacks import ModelCheckpoint 14 | from keras.optimizers import SGD 15 | from keras import backend as K 16 | 17 | from clr import LRFinder 18 | from models.small.model import MiniVGG 19 | 20 | plt.style.use('seaborn-white') 21 | 22 | batch_size = 128 23 | nb_classes = 10 24 | nb_epoch = 1 # Only finding lr 25 | data_augmentation = True 26 | 27 | # input image dimensions 28 | img_rows, img_cols = 32, 32 29 | # The CIFAR10 images are RGB. 
30 | img_channels = 3 31 | 32 | # The data, shuffled and split between train and test sets: 33 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 34 | 35 | # Convert class vectors to binary class matrices. 36 | Y_train = np_utils.to_categorical(y_train, nb_classes) 37 | Y_test = np_utils.to_categorical(y_test, nb_classes) 38 | 39 | X_train = X_train.astype('float32') 40 | X_test = X_test.astype('float32') 41 | 42 | # preprocess input 43 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 44 | std = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 45 | 46 | print("Channel Mean : ", mean) 47 | print("Channel Std : ", std) 48 | 49 | X_train = (X_train - mean) / (std) 50 | X_test = (X_test - mean) / (std) 51 | 52 | # Learning rate finder callback setup 53 | num_samples = X_train.shape[0] 54 | 55 | # INITIAL WEIGHT DECAY FACTORS 56 | WEIGHT_DECAY_FACTORS = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7] 57 | 58 | # FINEGRAINED WEIGHT DECAY FACTORS 59 | # WEIGHT_DECAY_FACTORS = [3e-3, 3e-4] 60 | 61 | # for weight_decay in WEIGHT_DECAY_FACTORS: 62 | # K.clear_session() 63 | # 64 | # # Learning rate range obtained from `find_lr_schedule.py` 65 | # # NOTE : Minimum is 10x smaller than the max found above ! 66 | # # NOTE : It is preferable to use the validation data here to get a correct value 67 | # lr_finder = LRFinder(num_samples, batch_size, minimum_lr=0.00125, maximum_lr=0.0125, 68 | # validation_data=(X_test, Y_test), 69 | # validation_sample_rate=5, 70 | # lr_scale='linear', save_dir='weights/weight_decay/weight_decay-%s/' % str(weight_decay), 71 | # verbose=True) 72 | # 73 | # # SETUP THE WEIGHT DECAY IN THE MODEL 74 | # model = MiniVGG((img_rows, img_cols, img_channels), 75 | # weight_decay=weight_decay, dropout=0, 76 | # weights=None, classes=nb_classes) 77 | # model.summary() 78 | # 79 | # # set the weight_decay here ! 
80 | # # lr doesnt matter as it will be over written by the callback 81 | # optimizer = SGD(lr=0.002, momentum=0.95, nesterov=True) 82 | # model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) 83 | # 84 | # # model.load_weights(weights_file) 85 | # 86 | # if not data_augmentation: 87 | # print('Not using data augmentation.') 88 | # model.fit(X_train, Y_train, 89 | # batch_size=batch_size, 90 | # epochs=nb_epoch, 91 | # validation_data=(X_test, Y_test), 92 | # shuffle=True, 93 | # verbose=1, 94 | # callbacks=[lr_finder]) 95 | # else: 96 | # print('Using real-time data augmentation.') 97 | # # This will do preprocessing and realtime data augmentation: 98 | # datagen = ImageDataGenerator( 99 | # featurewise_center=False, # set input mean to 0 over the dataset 100 | # samplewise_center=False, # set each sample mean to 0 101 | # featurewise_std_normalization=False, # divide inputs by std of the dataset 102 | # samplewise_std_normalization=False, # divide each input by its std 103 | # zca_whitening=False, # apply ZCA whitening 104 | # rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 105 | # width_shift_range=0, # randomly shift images horizontally (fraction of total width) 106 | # height_shift_range=0, # randomly shift images vertically (fraction of total height) 107 | # horizontal_flip=True, # randomly flip images 108 | # vertical_flip=False) # randomly flip images 109 | # 110 | # # Compute quantities required for featurewise normalization 111 | # # (std, mean, and principal components if ZCA whitening is applied). 112 | # datagen.fit(X_train) 113 | # 114 | # # Fit the model on the batches generated by datagen.flow(). 
115 | # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True), 116 | # steps_per_epoch=X_train.shape[0] // batch_size, 117 | # validation_data=(X_test, Y_test), 118 | # epochs=nb_epoch, verbose=1, 119 | # callbacks=[lr_finder]) 120 | 121 | # from plot we see, the model isnt impacted by the weight_decay very much at all 122 | # so we can use any of them. 123 | 124 | WEIGHT_DECAY_FACTORS = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7] + [3e-3, 3e-4] 125 | 126 | for weight_decay in WEIGHT_DECAY_FACTORS: 127 | directory = 'weights/weight_decay/weight_decay-%s/' % str(weight_decay) 128 | 129 | losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5) 130 | plt.plot(lrs, losses, label='weight_decay=%0.7f' % weight_decay) 131 | 132 | plt.title("Weight Decay") 133 | plt.xlabel("Learning rate") 134 | plt.ylabel("Validation Loss") 135 | plt.legend() 136 | plt.show() 137 | -------------------------------------------------------------------------------- /models/small/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import warnings 6 | import math 7 | 8 | from keras.models import Model 9 | from keras.layers import Input 10 | from keras.layers import Activation 11 | from keras.layers import Dropout 12 | from keras.layers import Reshape 13 | from keras.layers import BatchNormalization 14 | from keras.layers import GlobalAveragePooling2D 15 | from keras.layers import GlobalMaxPooling2D 16 | from keras.layers import Conv2D 17 | from keras.layers import Dense 18 | from keras.layers import add, concatenate 19 | from keras import initializers 20 | from keras import regularizers 21 | from keras import constraints 22 | from keras.utils import conv_utils 23 | from keras.utils.data_utils import get_file 24 | from keras.engine.topology import get_source_inputs 25 | from keras.engine import 
InputSpec 26 | from keras.applications.imagenet_utils import _obtain_input_shape 27 | from keras.applications.inception_v3 import preprocess_input 28 | from keras.applications.imagenet_utils import decode_predictions 29 | from keras import backend as K 30 | 31 | import tensorflow as tf 32 | 33 | BASE_WEIGHT_PATH = '' 34 | BASE_WEIGHT_PATH_V2 = '' 35 | 36 | 37 | def relu6(x): 38 | return K.relu(x, max_value=6) 39 | 40 | 41 | def MiniVGG(input_shape=None, 42 | dropout=0., 43 | weight_decay=0., 44 | include_top=True, 45 | weights=None, 46 | input_tensor=None, 47 | pooling=None, 48 | classes=10): 49 | """Mini VGG is a 3 layer small CNN network 50 | 51 | # Arguments 52 | input_shape: optional shape tuple, only to be specified 53 | if `include_top` is False (otherwise the input shape 54 | has to be `(224, 224, 3)` (with `channels_last` data format) 55 | or (3, 224, 224) (with `channels_first` data format). 56 | It should have exactly 3 inputs channels, 57 | and width and height should be no smaller than 32. 58 | E.g. `(200, 200, 3)` would be one valid value. 59 | dropout: dropout rate 60 | weight_decay: Weight decay factor. 61 | include_top: whether to include the fully-connected 62 | layer at the top of the network. 63 | weights: `None` (random initialization) or 64 | `imagenet` (ImageNet weights) 65 | input_tensor: optional Keras tensor (i.e. output of 66 | `layers.Input()`) 67 | to use as image input for the model. 68 | pooling: Optional pooling mode for feature extraction 69 | when `include_top` is `False`. 70 | - `None` means that the output of the model 71 | will be the 4D tensor output of the 72 | last convolutional layer. 73 | - `avg` means that global average pooling 74 | will be applied to the output of the 75 | last convolutional layer, and thus 76 | the output of the model will be a 77 | 2D tensor. 78 | - `max` means that global max pooling will 79 | be applied. 
80 | classes: optional number of classes to classify images 81 | into, only to be specified if `include_top` is True, and 82 | if no `weights` argument is specified. 83 | # Returns 84 | A Keras model instance. 85 | # Raises 86 | ValueError: in case of invalid argument for `weights`, 87 | or invalid input shape. 88 | RuntimeError: If attempting to run this model with a 89 | backend that does not support separable convolutions. 90 | """ 91 | 92 | if K.backend() != 'tensorflow': 93 | raise RuntimeError('Only Tensorflow backend is currently supported, ' 94 | 'as other backends do not support ' 95 | 'depthwise convolution.') 96 | 97 | if weights not in {'imagenet', None}: 98 | raise ValueError('The `weights` argument should be either ' 99 | '`None` (random initialization) or `imagenet` ' 100 | '(pre-training on ImageNet).') 101 | 102 | if weights == 'imagenet' and include_top and classes != 1000: 103 | raise ValueError('If using `weights` as ImageNet with `include_top` ' 104 | 'as true, `classes` should be 1000') 105 | 106 | # Determine proper input shape and default size. 
107 | if input_shape is None: 108 | default_size = 224 109 | else: 110 | if K.image_data_format() == 'channels_first': 111 | rows = input_shape[1] 112 | cols = input_shape[2] 113 | else: 114 | rows = input_shape[0] 115 | cols = input_shape[1] 116 | 117 | if rows == cols and rows in [96, 128, 160, 192, 224]: 118 | default_size = rows 119 | else: 120 | default_size = 224 121 | 122 | input_shape = _obtain_input_shape(input_shape, 123 | default_size=default_size, 124 | min_size=32, 125 | data_format=K.image_data_format(), 126 | require_flatten=include_top or weights) 127 | if K.image_data_format() == 'channels_last': 128 | row_axis, col_axis = (0, 1) 129 | channel_axis = -1 130 | else: 131 | row_axis, col_axis = (1, 2) 132 | channel_axis = 1 133 | 134 | rows = input_shape[row_axis] 135 | cols = input_shape[col_axis] 136 | 137 | if input_tensor is None: 138 | img_input = Input(shape=input_shape) 139 | else: 140 | if not K.is_keras_tensor(input_tensor): 141 | img_input = Input(tensor=input_tensor, shape=input_shape) 142 | else: 143 | img_input = input_tensor 144 | 145 | x = _conv_block(img_input, 32, bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, block_id=1) 146 | x = _conv_block(x, 64, bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, block_id=2) 147 | x = _conv_block(x, 96, bn_epsilon=1e-3, bn_momentum=0.99, weight_decay=weight_decay, block_id=3) 148 | 149 | if include_top: 150 | 151 | # Fast.ai's Concat Pooling 152 | a = GlobalAveragePooling2D()(x) 153 | b = GlobalMaxPooling2D()(x) 154 | 155 | x = concatenate([a, b], axis=channel_axis) 156 | 157 | x = Dropout(dropout, name='dropout')(x) 158 | x = Dense(classes, activation='softmax', name='conv_preds')(x) 159 | else: 160 | if pooling == 'avg': 161 | x = GlobalAveragePooling2D()(x) 162 | elif pooling == 'max': 163 | x = GlobalMaxPooling2D()(x) 164 | 165 | # Ensure that the model takes into account 166 | # any potential predecessors of `input_tensor`. 
167 | if input_tensor is not None: 168 | inputs = get_source_inputs(input_tensor) 169 | else: 170 | inputs = img_input 171 | 172 | # Create model. 173 | model = Model(inputs, x, name='mini_vgg_%0.2f_%s') 174 | return model 175 | 176 | 177 | # taken from https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/conv_blocks.py 178 | def _make_divisible(v, divisor=8, min_value=8): 179 | if min_value is None: 180 | min_value = divisor 181 | 182 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 183 | # Make sure that round down does not go down by more than 10%. 184 | if new_v < 0.9 * v: 185 | new_v += divisor 186 | return new_v 187 | 188 | 189 | def _conv_block(inputs, filters, kernel=(3, 3), strides=(1, 1), bn_epsilon=1e-3, 190 | bn_momentum=0.99, weight_decay=0., block_id=1): 191 | """Adds an initial convolution layer (with batch normalization and relu6). 192 | # Arguments 193 | inputs: Input tensor of shape `(rows, cols, 3)` 194 | (with `channels_last` data format) or 195 | (3, rows, cols) (with `channels_first` data format). 196 | It should have exactly 3 inputs channels, 197 | and width and height should be no smaller than 32. 198 | E.g. `(224, 224, 3)` would be one valid value. 199 | filters: Integer, the dimensionality of the output space 200 | (i.e. the number output of filters in the convolution). 201 | alpha: controls the width of the network. 202 | - If `alpha` < 1.0, proportionally decreases the number 203 | of filters in each layer. 204 | - If `alpha` > 1.0, proportionally increases the number 205 | of filters in each layer. 206 | - If `alpha` = 1, default number of filters from the paper 207 | are used at each layer. 208 | kernel: An integer or tuple/list of 2 integers, specifying the 209 | width and height of the 2D convolution window. 210 | Can be a single integer to specify the same value for 211 | all spatial dimensions. 
212 | strides: An integer or tuple/list of 2 integers, 213 | specifying the strides of the convolution along the width and height. 214 | Can be a single integer to specify the same value for 215 | all spatial dimensions. 216 | Specifying any stride value != 1 is incompatible with specifying 217 | any `dilation_rate` value != 1. 218 | bn_epsilon: Epsilon value for BatchNormalization 219 | bn_momentum: Momentum value for BatchNormalization 220 | # Input shape 221 | 4D tensor with shape: 222 | `(samples, channels, rows, cols)` if data_format='channels_first' 223 | or 4D tensor with shape: 224 | `(samples, rows, cols, channels)` if data_format='channels_last'. 225 | # Output shape 226 | 4D tensor with shape: 227 | `(samples, filters, new_rows, new_cols)` if data_format='channels_first' 228 | or 4D tensor with shape: 229 | `(samples, new_rows, new_cols, filters)` if data_format='channels_last'. 230 | `rows` and `cols` values might have changed due to stride. 231 | # Returns 232 | Output tensor of block. 
233 | """ 234 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 235 | filters = _make_divisible(filters) 236 | x = Conv2D(filters, kernel, 237 | padding='same', 238 | use_bias=False, 239 | strides=strides, 240 | kernel_initializer=initializers.he_normal(), 241 | kernel_regularizer=regularizers.l2(weight_decay), 242 | name='conv%d' % block_id)(inputs) 243 | x = BatchNormalization(axis=channel_axis, momentum=bn_momentum, epsilon=bn_epsilon, 244 | name='conv%d_bn' % block_id)(x) 245 | return Activation(relu6, name='conv%d_relu' % block_id)(x) 246 | 247 | 248 | if __name__ == '__main__': 249 | import tensorflow as tf 250 | from keras import backend as K 251 | 252 | run_metadata = tf.RunMetadata() 253 | 254 | with tf.Session(graph=tf.Graph()) as sess: 255 | K.set_session(sess) 256 | 257 | model = MiniVGG(input_tensor=tf.placeholder('float32', shape=(1, 32, 32, 3))) 258 | opt = tf.profiler.ProfileOptionBuilder.float_operation() 259 | flops = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt) 260 | 261 | opt = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter() 262 | param_count = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt) 263 | 264 | print('flops:', flops.total_float_ops) 265 | print('param count:', param_count.total_parameters) 266 | 267 | model.summary() 268 | -------------------------------------------------------------------------------- /models/small/train_cifar_10.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from keras example cifar10_cnn.py 3 | Train NASNet-CIFAR on the CIFAR10 small images dataset. 
4 | """ 5 | from __future__ import print_function 6 | import os 7 | 8 | from keras.datasets import cifar10 9 | from keras.preprocessing.image import ImageDataGenerator 10 | from keras.utils import np_utils 11 | from keras.callbacks import ModelCheckpoint 12 | from keras.optimizers import SGD 13 | import numpy as np 14 | 15 | from clr import OneCycleLR 16 | from models.small.model import MiniVGG 17 | 18 | if not os.path.exists('weights/'): 19 | os.makedirs('weights/') 20 | 21 | weights_file = 'weights/mini_vgg.h5' 22 | model_checkpoint = ModelCheckpoint( 23 | weights_file, 24 | monitor='val_acc', 25 | save_best_only=True, 26 | save_weights_only=True, 27 | mode='max') 28 | batch_size = 128 29 | nb_classes = 10 30 | nb_epoch = 50 # Only finding lr 31 | data_augmentation = True 32 | 33 | # input image dimensions 34 | img_rows, img_cols = 32, 32 35 | # The CIFAR10 images are RGB. 36 | img_channels = 3 37 | 38 | # The data, shuffled and split between train and test sets: 39 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 40 | 41 | # Convert class vectors to binary class matrices. 
42 | Y_train = np_utils.to_categorical(y_train, nb_classes) 43 | Y_test = np_utils.to_categorical(y_test, nb_classes) 44 | 45 | X_train = X_train.astype('float32') 46 | X_test = X_test.astype('float32') 47 | 48 | # preprocess input 49 | mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 50 | std = np.mean(X_train, axis=(0, 1, 2), keepdims=True).astype('float32') 51 | 52 | print("Channel Mean : ", mean) 53 | print("Channel Std : ", std) 54 | 55 | X_train = (X_train - mean) / (std) 56 | X_test = (X_test - mean) / (std) 57 | 58 | # Learning rate finder callback setup 59 | num_samples = X_train.shape[0] 60 | 61 | # When using the validation set for LRFinder, try out values starting from 2x 62 | # the lr found there and move lower until its good for the first few epochs 63 | lr_manager = OneCycleLR( 64 | max_lr=0.025, 65 | end_percentage=0.2, 66 | scale_percentage=0.1, 67 | maximum_momentum=0.95, 68 | verbose=True) 69 | 70 | # For training, the auxilary branch must be used to correctly train NASNet 71 | model = MiniVGG((img_rows, img_cols, img_channels), 72 | weight_decay=1e-5, 73 | weights=None, 74 | classes=nb_classes) 75 | model.summary() 76 | 77 | # These values will be overridden by the above callback 78 | optimizer = SGD(lr=0.0025, momentum=0.95, nesterov=True) 79 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) 80 | 81 | model.load_weights(weights_file) 82 | 83 | if not data_augmentation: 84 | print('Not using data augmentation.') 85 | model.fit( 86 | X_train, 87 | Y_train, 88 | batch_size=batch_size, 89 | epochs=nb_epoch, 90 | validation_data=(X_test, Y_test), 91 | shuffle=True, 92 | verbose=1, 93 | callbacks=[lr_manager, model_checkpoint]) 94 | else: 95 | print('Using real-time data augmentation.') 96 | # This will do preprocessing and realtime data augmentation: 97 | datagen = ImageDataGenerator( 98 | featurewise_center=False, # set input mean to 0 over the dataset 99 | samplewise_center=False, # 
set each sample mean to 0 100 | featurewise_std_normalization=False, # divide inputs by std of the dataset 101 | samplewise_std_normalization=False, # divide each input by its std 102 | zca_whitening=False, # apply ZCA whitening 103 | # randomly rotate images in the range (degrees, 0 to 180) 104 | rotation_range=0, 105 | # randomly shift images horizontally (fraction of total width) 106 | width_shift_range=0, 107 | # randomly shift images vertically (fraction of total height) 108 | height_shift_range=0, 109 | horizontal_flip=True, # randomly flip images 110 | vertical_flip=False) # randomly flip images 111 | 112 | # Compute quantities required for featurewise normalization 113 | # (std, mean, and principal components if ZCA whitening is applied). 114 | datagen.fit(X_train) 115 | 116 | # Fit the model on the batches generated by datagen.flow(). 117 | # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True), 118 | # steps_per_epoch=X_train.shape[0] // batch_size, 119 | # validation_data=(X_test, Y_test), 120 | # epochs=nb_epoch, verbose=1, 121 | # callbacks=[lr_manager, model_checkpoint]) 122 | 123 | scores = model.evaluate(X_test, Y_test, batch_size=batch_size) 124 | for score, metric_name in zip(scores, model.metrics_names): 125 | print("%s : %0.4f" % (metric_name, score)) 126 | -------------------------------------------------------------------------------- /models/small/weights/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/losses.npy -------------------------------------------------------------------------------- /models/small/weights/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/lrs.npy 
-------------------------------------------------------------------------------- /models/small/weights/mini_vgg.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/mini_vgg.h5 -------------------------------------------------------------------------------- /models/small/weights/momentum/momentum-0.9/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.9/losses.npy -------------------------------------------------------------------------------- /models/small/weights/momentum/momentum-0.9/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.9/lrs.npy -------------------------------------------------------------------------------- /models/small/weights/momentum/momentum-0.95/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.95/losses.npy -------------------------------------------------------------------------------- /models/small/weights/momentum/momentum-0.95/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.95/lrs.npy -------------------------------------------------------------------------------- /models/small/weights/momentum/momentum-0.99/losses.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.99/losses.npy -------------------------------------------------------------------------------- /models/small/weights/momentum/momentum-0.99/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/momentum/momentum-0.99/lrs.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-0.0001/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.0001/losses.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-0.0001/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.0001/lrs.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-0.0003/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.0003/losses.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-0.0003/lrs.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.0003/lrs.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-0.001/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.001/losses.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-0.001/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.001/lrs.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-0.003/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.003/losses.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-0.003/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-0.003/lrs.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-1e-05/losses.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-05/losses.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-1e-05/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-05/lrs.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-1e-06/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-06/losses.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-1e-06/lrs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-06/lrs.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-1e-07/losses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-07/losses.npy -------------------------------------------------------------------------------- /models/small/weights/weight_decay/weight_decay-1e-07/lrs.npy: 
# -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/keras-one-cycle/06c202996d71491e624ddef53a57858152e93564/models/small/weights/weight_decay/weight_decay-1e-07/lrs.npy
# -------------------------------------------------------------------------------- /plot_clr.py:
# --------------------------------------------------------------------------------
"""Visualise the learning-rate and momentum schedules recorded by OneCycleLR.

A tiny dense network is fitted on random binary-classification data only to
drive the callback through NUM_EPOCHS epochs. The min/max of the recorded
'lr' and 'momentum' histories are then printed, and each history is plotted
against training iteration.
"""
import numpy as np
import matplotlib.pyplot as plt

from keras.models import Model
from keras.layers import Dense, Activation, Input
from keras.optimizers import SGD, Adam

from clr import OneCycleLR

# Constants
NUM_SAMPLES = 2000
NUM_EPOCHS = 100
BATCH_SIZE = 500
MAX_LR = 0.1


def _use_seaborn_white_style():
    """Apply the seaborn "white" style under whichever name this matplotlib knows.

    matplotlib 3.6 renamed the bundled seaborn styles to ``seaborn-v0_8-*`` and
    later releases dropped the old names, so a hard-coded
    ``plt.style.use('seaborn-white')`` raises OSError on current matplotlib.
    The new name is tried first to avoid the deprecation warning on 3.6/3.7.
    """
    for style in ('seaborn-v0_8-white', 'seaborn-white'):
        if style in plt.style.available:
            plt.style.use(style)
            return
    # Neither name is available: keep matplotlib's default style instead of crashing.


def _plot_history(values, ylabel):
    """Plot one recorded schedule against training iteration and show the figure."""
    plt.xlabel('Training Iterations')
    plt.ylabel(ylabel)
    plt.title("CLR")
    plt.plot(values)
    plt.show()


def main():
    """Fit the toy model with OneCycleLR attached, then report and plot its schedules."""
    _use_seaborn_white_style()

    # Data: random inputs/labels; only the callback's recorded schedule matters here.
    X = np.random.rand(NUM_SAMPLES, 10)
    Y = np.random.randint(0, 2, size=NUM_SAMPLES)

    # Model
    inp = Input(shape=(10,))
    x = Dense(5, activation='relu')(inp)
    x = Dense(1, activation='sigmoid')(x)
    model = Model(inp, x)

    clr_triangular = OneCycleLR(NUM_SAMPLES, NUM_EPOCHS, BATCH_SIZE, MAX_LR,
                                end_percentage=0.2, scale_percentage=0.2)

    # NOTE(review): SGD's initial lr of 0.1 is presumably overridden per batch by
    # the OneCycleLR callback; confirm against clr.py.
    model.compile(optimizer=SGD(0.1), loss='binary_crossentropy', metrics=['accuracy'])

    model.fit(X, Y, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS,
              callbacks=[clr_triangular], verbose=0)

    lrs = clr_triangular.history['lr']
    momenta = clr_triangular.history['momentum']

    print("LR Range : ", min(lrs), max(lrs))
    print("Momentum Range : ", min(momenta), max(momenta))

    _plot_history(lrs, 'Learning Rate')
    _plot_history(momenta, 'Momentum')


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------