├── .gitignore ├── .gitlab-ci.yml ├── .travis.yml ├── LICENSE ├── README.rst ├── architecture.svg ├── docs ├── contributing.md ├── css │ └── extra.css ├── datasets.md ├── examples.md ├── img │ ├── canevet_icml2016_smooth10.png │ ├── canevet_icml2016_smooth100.png │ ├── canevet_icml2016_smooth40.png │ ├── cifar_cnn_test_acc.png │ ├── cifar_cnn_test_loss.png │ ├── cifar_cnn_train_loss.png │ ├── imdb_fasttext_test_acc_epochs.png │ ├── imdb_fasttext_test_acc_time.png │ ├── mnist_cnn_test_acc.png │ ├── mnist_cnn_train_loss.png │ ├── mnist_mlp_test_acc.png │ ├── mnist_mlp_train_loss.png │ ├── mnist_svrg_test.png │ └── mnist_svrg_training.png ├── index.md └── training.md ├── examples ├── README.md ├── cifar10_resnet.py ├── example_utils.py ├── importance_sampling ├── mnist_cnn.py ├── mnist_mlp.py ├── mnist_svrg.py └── mnist_svrg_cnn.py ├── headers.py ├── importance_sampling ├── __init__.py ├── datasets.py ├── layers │ ├── __init__.py │ ├── metrics.py │ ├── normalization.py │ └── scores.py ├── model_wrappers.py ├── models.py ├── pretrained.py ├── reweighting.py ├── samplers.py ├── training.py └── utils │ ├── __init__.py │ ├── functional.py │ ├── keras_utils.py │ ├── tf.py │ └── tf_config.py ├── mkdocs.yml ├── scripts ├── compute_scores.py ├── importance_sampling ├── importance_sampling.py ├── lfw_evaluate.py ├── lfw_forward_pass.py ├── lsexp.py ├── plot_distribution.py ├── plot_loss_evolution.py └── variance_reduction.py ├── setup.py └── tests ├── importance_sampling ├── test_datasets.py ├── test_finetuning.py ├── test_keras_utils.py ├── test_model_wrappers.py ├── test_models.py ├── test_reweighting.py ├── test_samplers.py ├── test_save_load.py ├── test_seq2seq.py └── test_training.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | build 3 | dist 4 | *.egg-info 5 | site/ 6 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | 2 | stages: 3 | - test 4 | 5 | test: 6 | script: 7 | - virtualenv -p python3 venv3 8 | - source venv3/bin/activate 9 | - python -V 10 | - pip install --upgrade pip 11 | - pip install tensorflow 12 | - pip install h5py 13 | - pip install -e . 14 | - python -m unittest discover -s tests/ -v 15 | - deactivate 16 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | - "3.7" 5 | 6 | install: 7 | - pip install --upgrade pip 8 | - pip install tensorflow 9 | - pip install h5py 10 | - pip install -e . 
11 | 12 | script: 13 | - python -m unittest discover -s tests/ -v 14 | 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | Written by Angelos Katharopoulos 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | this software and associated documentation files (the "Software"), to deal in 8 | the Software without restriction, including without limitation the rights to 9 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 10 | of the Software, and to permit persons to whom the Software is furnished to do 11 | so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Importance Sampling 2 | ==================== 3 | 4 | This python package provides a library that accelerates the training of 5 | arbitrary neural networks created with `Keras `__ using 6 | **importance sampling**. 7 | 8 | .. code:: python 9 | 10 | # Keras imports 11 | 12 | from importance_sampling.training import ImportanceTraining 13 | 14 | x_train, y_train, x_val, y_val = load_data() 15 | model = create_keras_model() 16 | model.compile( 17 | optimizer="adam", 18 | loss="categorical_crossentropy", 19 | metrics=["accuracy"] 20 | ) 21 | 22 | ImportanceTraining(model).fit( 23 | x_train, y_train, 24 | batch_size=32, 25 | epochs=10, 26 | verbose=1, 27 | validation_data=(x_val, y_val) 28 | ) 29 | 30 | model.evaluate(x_val, y_val) 31 | 32 | Importance sampling for Deep Learning is an active research field and this 33 | library is undergoing development so your mileage may vary. 34 | 35 | Relevant Research 36 | ----------------- 37 | 38 | **Ours** 39 | 40 | * Not All Samples Are Created Equal: Deep Learning with Importance Sampling [`preprint `__] 41 | * Biased Importance Sampling for Deep Neural Network Training [`preprint `__] 42 | 43 | **By others** 44 | 45 | * Stochastic optimization with importance sampling for regularized loss 46 | minimization [`pdf `__] 47 | * Variance reduction in SGD by distributed importance sampling [`pdf `__] 48 | 49 | Dependencies & Installation 50 | --------------------------- 51 | 52 | Normally if you already have a functional Keras installation you just need to 53 | ``pip install keras-importance-sampling``. 
54 | 55 | * ``Keras`` > 2 56 | * A Keras backend among *Tensorflow*, *Theano* and *CNTK* 57 | * ``blinker`` 58 | * ``numpy`` 59 | * ``matplotlib``, ``seaborn``, ``scikit-learn`` are optional (used by the plot 60 | scripts) 61 | 62 | Documentation 63 | ------------- 64 | 65 | The module has a dedicated `documentation site 66 | `__ but you can also read the 67 | `source code `__ and the `examples 68 | `__ to get an 69 | idea of how the library should be used and extended. 70 | 71 | Examples 72 | --------- 73 | 74 | In the ``examples`` folder you can find some Keras examples that have been edited 75 | to use importance sampling. 76 | 77 | Code examples 78 | ************* 79 | 80 | In this section we will showcase part of the API that can be used to train 81 | neural networks with importance sampling. 82 | 83 | .. code:: python 84 | 85 | # Import what is needed to build the Keras model 86 | from keras import backend as K 87 | from keras.layers import Dense, Activation, Flatten 88 | from keras.models import Sequential 89 | 90 | # Import a toy dataset and the importance training 91 | from importance_sampling.datasets import MNIST 92 | from importance_sampling.training import ImportanceTraining 93 | 94 | 95 | def create_nn(): 96 | """Build a simple fully connected NN""" 97 | model = Sequential([ 98 | Flatten(input_shape=(28, 28, 1)), 99 | Dense(40, activation="tanh"), 100 | Dense(40, activation="tanh"), 101 | Dense(10), 102 | Activation("softmax") # Needs to be separate to automatically 103 | # get the preactivation outputs 104 | ]) 105 | 106 | model.compile( 107 | optimizer="adam", 108 | loss="categorical_crossentropy", 109 | metrics=["accuracy"] 110 | ) 111 | 112 | return model 113 | 114 | 115 | if __name__ == "__main__": 116 | # Load the data 117 | dataset = MNIST() 118 | x_train, y_train = dataset.train_data[:] 119 | x_test, y_test = dataset.test_data[:] 120 | 121 | # Create the NN and keep the initial weights 122 | model = create_nn() 123 | weights = model.get_weights() 124 | 125 | # Train with uniform sampling 126 | K.set_value(model.optimizer.lr, 0.01) 127 | model.fit( 128 | x_train, y_train, 129 | batch_size=64, epochs=10, 130 | validation_data=(x_test, y_test) 131 | ) 132 | 133 | # Train with importance sampling 134 | model.set_weights(weights) 135 | K.set_value(model.optimizer.lr, 0.01) 136 | ImportanceTraining(model).fit( 137 | x_train, y_train, 138 | batch_size=64, epochs=2, 139 | validation_data=(x_test, y_test) 140 | ) 141 | 142 | Using the script 143 | **************** 144 | 145 | The following terminal commands train a small VGG-like network to ~0.65% error 146 | on MNIST (the numbers are from a CPU). 147 | .. code:: 148 | 149 | $ # Train a small cnn with mnist for 500 mini-batches using importance 150 | $ # sampling with bias to achieve ~ 0.65% error (on the CPU). 151 | $ time ./importance_sampling.py \ 152 | > small_cnn \ 153 | > oracle-gnorm \ 154 | > model \ 155 | > predicted \ 156 | > mnist \ 157 | > /tmp/is \ 158 | > --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \ 159 | > --train_for 500 --validate_every 500 160 | real 1m41.985s 161 | user 8m14.400s 162 | sys 0m35.900s 163 | $ 164 | $ # And with uniform sampling to achieve ~ 0.9% error. 
165 | $ time ./importance_sampling.py \ 166 | > small_cnn \ 167 | > oracle-loss \ 168 | > uniform \ 169 | > unweighted \ 170 | > mnist \ 171 | > /tmp/uniform \ 172 | > --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \ 173 | > --train_for 3000 --validate_every 3000 174 | real 9m23.971s 175 | user 47m32.600s 176 | sys 3m4.188s 177 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | The development of Importance Sampling for Keras is mirrored on [GitHub][repo]. 4 | If you have found a bug or you want a new feature just create a [pull 5 | request][pr_url] or an [issue][issue_url]. Before submitting code for a PR make 6 | sure that: 7 | 8 | * You have **rebased** your local branch on top of origin/master 9 | * Your code passes **PEP8** 10 | * You have added **tests** covering the code you added 11 | 12 | 13 | [repo]: https://github.com/idiap/importance-sampling 14 | [pr_url]: https://github.com/idiap/importance-sampling/pulls 15 | [issue_url]: https://github.com/idiap/importance-sampling/issues 16 | -------------------------------------------------------------------------------- /docs/css/extra.css: -------------------------------------------------------------------------------- 1 | body { 2 | counter-reset: figure; 3 | } 4 | .fig { 5 | counter-increment: figure; 6 | clear: both; 7 | margin-bottom: 25px; 8 | } 9 | .fig.col-2 > img { 10 | width: 50%; 11 | float: left; 12 | } 13 | .fig.col-3 > img { 14 | width: 33%; 15 | float: left; 16 | } 17 | .fig.col-4 > img { 18 | width: 25%; 19 | float: left; 20 | } 21 | .fig > span { 22 | display: block; 23 | font-size: 85%; 24 | font-style: italic; 25 | } 26 | .fig > span::before { 27 | font-weight: bold; 28 | content: "Figure " counter(figure) ": "; 29 | } 30 | -------------------------------------------------------------------------------- /docs/datasets.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | Importance sampling entails the process of accessing random samples from a 4 | dataset in a specific manner. To achieve this we introduce an interface for a 5 | random access `Dataset` in `importance_sampling.datasets`. 6 | 7 | Initially, we will present the Dataset interface and subsequently several 8 | implementations both reusable for your own datasets and wrapping some well 9 | known benchmark datasets. 10 | 11 | ## BaseDataset 12 | 13 | `BaseDataset` provides the interface that the rest of the components rely on to 14 | perform importance sampling. The main idea is that it provides two properties 15 | `train_data` and `test_data` that return a proxy object that can be accessed as 16 | a random access container returning tuples of (data, targets). The concept will 17 | be better illustrated by the following code example. 18 | 19 | ```python 20 | # Let's assume dataset is an instance of BaseDataset 21 | 22 | x, y = dataset.train_data[0] # Just the first sample 23 | x, y = dataset.train_data[::5] # Every fifth sample 24 | x, y = dataset.train_data[[0, 33, 1]] # 1st, 33rd and 2nd samples 25 | N = len(dataset.train_data) # How many samples are there? 26 | ``` 27 | 28 | To implement this behaviour, subclasses of `BaseDataset` have to extend 4 29 | functions, `_train_data`, `_train_size`, `_test_data` and `_test_size`, see 30 | `importance_sampling.datasets.InMemoryDataset` for a simple implementation of 31 | this API. 
The complete API of BaseDataset is given below: 32 | 33 | ```python 34 | class BaseDataset(object): 35 | def _train_data(self, idxs=slice(None)): 36 | raise NotImplementedError() 37 | 38 | def _train_size(self): 39 | raise NotImplementedError() 40 | 41 | def _test_data(self, idxs=slice(None)): 42 | raise NotImplementedError() 43 | 44 | def _test_size(self): 45 | raise NotImplementedError() 46 | 47 | @property 48 | def shape(self): 49 | """Return the shape of the data without the batch axis""" 50 | raise NotImplementedError() 51 | 52 | @property 53 | def output_size(self): 54 | """Return the dimensions of the targets""" 55 | raise NotImplementedError() 56 | ``` 57 | 58 | ## InMemoryDataset 59 | 60 | ```python 61 | importance_sampling.dataset.InMemoryDataset(X_train, y_train, X_test, y_test, categorical=True) 62 | ``` 63 | 64 | `InMemoryDataset` simply wraps 4 Numpy arrays and implements the interface. If 65 | `categorical` is True `y_*` are transformed to one-hot dense vectors to 66 | indicate the class. 67 | 68 | **Arguments** 69 | 70 | * **X_train:** At least 2-dimensional Numpy array that contains the training 71 | data 72 | * **y_train:** Numpy array that contains the training targets 73 | * **X_test:** At least 2-dimensional Numpy array that contains the 74 | testing/validation data 75 | * **y_test:** Numpy array that contains the testing/validation targets 76 | * **categorical:** Controls whether the targets will be transformed into 77 | one-hot dense vectors 78 | 79 | **Static methods** 80 | 81 | ```python 82 | from_loadable(dataset) 83 | ``` 84 | 85 | Creates a dataset from an object that returns the four Numpy arrays with a 86 | `load_data()` method. Used, for instance, with the *Keras* datasets. 87 | 88 | **Example** 89 | 90 | ```python 91 | from importance_sampling.datasets import InMemoryDataset 92 | import numpy as np 93 | 94 | dset = InMemoryDataset( 95 | np.random.rand(100, 10), 96 | np.random.rand(100, 1), 97 | np.random.rand(100, 10), 98 | np.random.rand(100, 1), 99 | categorical=False 100 | ) 101 | 102 | assert dset.shape == (10,) 103 | assert dset.output_size == 1 104 | assert len(dset.train_data) == 100 105 | ``` 106 | 107 | 108 | ## InMemoryImageDataset 109 | 110 | ```python 111 | importance_sampling.dataset.InMemoryImageDataset(X_train, y_train, X_test, y_test) 112 | ``` 113 | 114 | `InMemoryImageDataset` asserts that the passed arrays are 4-dimensional and 115 | normalizes them as `float32` in the range `[0, 1]`. 116 | 117 | **Arguments** 118 | 119 | * **X_train:** At least 2-dimensional Numpy array that contains the training 120 | data 121 | * **y_train:** Numpy array that contains the training targets 122 | * **X_test:** At least 2-dimensional Numpy array that contains the 123 | testing/validation data 124 | * **y_test:** Numpy array that contains the testing/validation targets 125 | 126 | **Static methods** 127 | 128 | ```python 129 | from_loadable(dataset) 130 | ``` 131 | 132 | Creates a dataset from an object that returns the four Numpy arrays with a 133 | `load_data()` method. Used, for instance, with the *Keras* datasets. 
134 | 135 | **Example** 136 | 137 | ```python 138 | from keras.datasets import mnist 139 | from importance_sampling.datasets import InMemoryImageDataset 140 | 141 | dset = InMemoryImageDataset.from_loadable(mnist) 142 | 143 | assert dset.shape == (28, 28, 1) 144 | assert dset.output_size == 10 145 | assert len(dset.train_data) == 60000 146 | ``` 147 | 148 | ## OntheflyAugmentedImages 149 | 150 | ```python 151 | importance_sampling.dataset.OntheflyAugmentedImages(dataset, augmentation_params, N=None, random_state=0, cache_size=None) 152 | ``` 153 | 154 | `OntheflyAugmentedImages` uses *Keras* `ImageDataGenerator` to augment an image 155 | dataset deterministically producing `N` images without explicitly storing them 156 | in memory. 157 | 158 | **Arguments** 159 | 160 | * **dataset:** Another instance of `BaseDataset` that this class will decorate 161 | * **augmentation\_params:** A dictionary of keyword arguments to pass to 162 | `ImageDataGenerator` 163 | * **N:** The size of the augmented dataset, if not given it defaults to 10 164 | times the decorated dataset 165 | * **random\_state:** A seed for the pseudo random number generator so that the 166 | augmented datasets are reproducible 167 | * **cache\_size:** The number of samples to cache using an LRU policy in order 168 | to reduce the time spent augmenting the same images (defaults to 169 | `len(dataset.train_data)`) 170 | 171 | **Example** 172 | 173 | ```python 174 | from keras.datasets import cifar10 175 | from importance_sampling.datasets import InMemoryImageDataset, \ 176 | OntheflyAugmentedImages 177 | 178 | dset = OntheflyAugmentedImages( 179 | InMemoryImageDataset.from_loadable(cifar10), 180 | dict( 181 | width_shift_range=0.1, 182 | height_shift_range=0.1, 183 | horizontal_flip=True 184 | ) 185 | ) 186 | 187 | assert dset.shape == (32, 32, 3) 188 | assert dset.output_size == 10 189 | assert len(dset.train_data) == 10 * 50000 190 | ``` 191 | 192 | ## GeneratorDataset 193 | 194 | ```python 195 | importance_sampling.dataset.GeneratorDataset(train_data, test_data=None, test_data_length=None, cache_size=5) 196 | ``` 197 | 198 | The `GeneratorDataset` wraps one or two generators and partially implements the 199 | `BaseDataset` interface. The `test_data` can be a generator or in memory data. 200 | The generators are consumed in background threads and at most `cache_size` 201 | return values are saved from each at any given time. 202 | 203 | **Arguments** 204 | 205 | * **train\_data**: A normal Keras compatible data generator. It should be infinite and 206 | return both inputs and targets 207 | * **test\_data**: Either a Keras compatible data generator or a list, numpy 208 | array etc. 209 | * **test\_data\_length**: When `test_data` is a generator then the number of 210 | points in the test set should be given. 
211 | * **cach\_size**: The maximum return values cached in the backgound threads 212 | from the generators, equivalent to Keras's `max_queue_size` 213 | 214 | **Example** 215 | 216 | ```python 217 | from keras.datasets import cifar10 218 | from keras.preprocessing.image import ImageDataGenerator 219 | from keras.utils import to_categorical 220 | from importance_sampling.datasets import GeneratorDataset 221 | 222 | # Load cifar into x_train, y_train, x_test, y_test 223 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 224 | y_train = to_categorical(y_train, 10) 225 | y_test = to_categorical(y_test, 10) 226 | x_train = x_train.astype('float32') 227 | x_test = x_test.astype('float32') 228 | x_train /= 255 229 | x_test /= 255 230 | 231 | # Create a data augmentation pipeline 232 | datagen = ImageDataGenerator( 233 | featurewise_center=False, # set input mean to 0 over the dataset 234 | samplewise_center=False, # set each sample mean to 0 235 | featurewise_std_normalization=False, # divide inputs by std of the dataset 236 | samplewise_std_normalization=False, # divide each input by its std 237 | zca_whitening=False, # apply ZCA whitening 238 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 239 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 240 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 241 | horizontal_flip=True, # randomly flip images 242 | vertical_flip=False) # randomly flip images 243 | datagen.fit(x_train) 244 | 245 | dset = GeneratorDataset( 246 | datagen.flow(x_train, y_train, batch_size=32), 247 | (x_test, y_test) 248 | ) 249 | 250 | assert dset.shape == (32, 32, 3) 251 | assert dset.output_size == 10 252 | assert len(dset.test_data) == 10000 253 | ``` 254 | 255 | ## Provided dataset classes 256 | 257 | `MNIST`, `CIFAR10` and `CIFAR100` are already provided as dataset classes with 258 | no constructor parameters. 259 | 260 | ### CanevetICML2016 261 | 262 | ```python 263 | importance_sampling.datasets.CanevetICML2016(N=8192, test_split=0.33, smooth=40) 264 | ``` 265 | 266 | This dataset is an artificial 2-dimensional binary classification dataset that 267 | is suitable for importance sampling and was introduced by [Canevet et al. ICML 268 | 2016][canevet_et_al]. 269 | 270 |
271 | Canevet dataset with smooth 10 273 | Canevet dataset with smooth 40 275 | Canevet dataset with smooth 100 277 | The effect of the smooth argument on the artificial dataset. From left to 278 | right smooth is 10, 40, 100. 279 |
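Since `CanevetICML2016` implements the `BaseDataset` interface described above, it is used like any other dataset. A minimal sketch follows (the constructor arguments are documented right below; the exact train/test sizes follow from `N` and `test_split`):

```python
from importance_sampling.datasets import CanevetICML2016

# Build a small instance of the artificial dataset; smooth controls the
# jitter illustrated in the figure above
dset = CanevetICML2016(N=1024, smooth=40)

x, y = dset.train_data[:16]        # a few 2-d points with binary targets
print(dset.shape, dset.output_size)
print(len(dset.train_data), len(dset.test_data))
```
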
280 | 281 | **Arguments** 282 | 283 | * **N:** The dataset is going to have N^2 points 284 | * **test\_split:** The percentage of the points to keep as a test/validation 285 | set 286 | * **smooth:** A jitter controlling parameter whose effects are seen in the 287 | previous figure. 288 | 289 | [canevet_et_al]: http://fleuret.org/papers/canevet-et-al-icml2016.pdf 290 | -------------------------------------------------------------------------------- /docs/img/canevet_icml2016_smooth10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/canevet_icml2016_smooth10.png -------------------------------------------------------------------------------- /docs/img/canevet_icml2016_smooth100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/canevet_icml2016_smooth100.png -------------------------------------------------------------------------------- /docs/img/canevet_icml2016_smooth40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/canevet_icml2016_smooth40.png -------------------------------------------------------------------------------- /docs/img/cifar_cnn_test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/cifar_cnn_test_acc.png -------------------------------------------------------------------------------- /docs/img/cifar_cnn_test_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/cifar_cnn_test_loss.png -------------------------------------------------------------------------------- /docs/img/cifar_cnn_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/cifar_cnn_train_loss.png -------------------------------------------------------------------------------- /docs/img/imdb_fasttext_test_acc_epochs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/imdb_fasttext_test_acc_epochs.png -------------------------------------------------------------------------------- /docs/img/imdb_fasttext_test_acc_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/imdb_fasttext_test_acc_time.png -------------------------------------------------------------------------------- /docs/img/mnist_cnn_test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/mnist_cnn_test_acc.png -------------------------------------------------------------------------------- /docs/img/mnist_cnn_train_loss.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/mnist_cnn_train_loss.png -------------------------------------------------------------------------------- /docs/img/mnist_mlp_test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/mnist_mlp_test_acc.png -------------------------------------------------------------------------------- /docs/img/mnist_mlp_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/mnist_mlp_train_loss.png -------------------------------------------------------------------------------- /docs/img/mnist_svrg_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/mnist_svrg_test.png -------------------------------------------------------------------------------- /docs/img/mnist_svrg_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/importance-sampling/9c9cab2ac91081ae2b64f99891504155057c09e3/docs/img/mnist_svrg_training.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Importance Sampling for Keras 2 | 3 | Deep learning models spend countless GPU/CPU cycles on trivial, correctly 4 | classified examples that do not individually affect the parameters. For 5 | instance, even a very simple neural network achieves ~98% accuracy on MNIST 6 | after a single epoch. 7 | 8 | Importance sampling focuses the computation to informative/important samples 9 | (by sampling mini-batches from a distribution other than uniform) thus 10 | accelerating the convergence. 11 | 12 | This library: 13 | 14 | * wraps Keras models requiring just **one line changed** to try out *Importance Sampling* 15 | * comes with modified Keras examples for quick and dirty comparison 16 | * is the result of ongoing research which means that *your mileage may vary* 17 | 18 | ## Quick-start 19 | 20 | The main API that is provided is that of 21 | `importance_sampling.training.ImportanceTraining`. The library uses composition 22 | to seamlessly wrap your Keras models and perform importance sampling behind the 23 | scenes. 24 | 25 | The example that follows is the minimal working example of importance 26 | sampling. Note the use of a **separate final activation layer** in order for the 27 | library to be able to get the pre-activation outputs. 
28 | 29 | ```python 30 | from keras.datasets import mnist 31 | from keras.models import Sequential 32 | from keras.layers import Dense, Activation 33 | import numpy as np 34 | 35 | from importance_sampling.training import ImportanceTraining 36 | 37 | # Load mnist 38 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 39 | x_train = x_train.reshape(-1, 784).astype(np.float32) / 255 40 | y_train = np.eye(10).astype(np.float32)[y_train] 41 | x_test = x_test.reshape(-1, 784).astype(np.float32) / 255 42 | y_test = np.eye(10).astype(np.float32)[y_test] 43 | 44 | # Build your NN normally 45 | model = Sequential() 46 | model.add(Dense(512, activation='relu', input_shape=(784,))) 47 | model.add(Dense(512, activation='relu')) 48 | model.add(Dense(10)) 49 | model.add(Activation("softmax")) 50 | model.compile("adam", "categorical_crossentropy", metrics=["accuracy"]) 51 | 52 | # Train with importance sampling 53 | history = ImportanceTraining(model).fit( 54 | x_train, y_train, 55 | batch_size=128, epochs=5, 56 | verbose=1, 57 | validation_data=(x_test, y_test) 58 | ) 59 | ``` 60 | 61 | ## Installation 62 | 63 | Importance sampling has the following dependencies: 64 | 65 | * Keras >= 2 66 | * numpy 67 | * blinker 68 | 69 | You can install it from PyPI with: 70 | 71 | ```bash 72 | pip install --user keras-importance-sampling 73 | ``` 74 | 75 | ## Research 76 | 77 | In case you want theoretical and empirical evidence regarding Importance 78 | Sampling and Deep Learning we encourage you to follow our research. 79 | 80 | 1. [Not All Samples Are Created Equal: Deep Learning with Importance Sampling (2018)][nasace] 81 | 2. [Biased Importance Sampling for Deep Neural Network Training (2017)][biased_is] 82 | 83 | ```bibtex 84 | @article{katharopoulos2018is, 85 | Author = {Katharopoulos, Angelos and Fleuret, Fran\c{c}ois}, 86 | Journal = {arXiv preprint arXiv:1803.00942}, 87 | Title = {Not All Samples Are Created Equal: Deep Learning with Importance 88 | Sampling}, 89 | Year = {2018} 90 | } 91 | ``` 92 | 93 | Moreover we suggest you look into the following highly related and influential papers: 94 | 95 | * Stochastic optimization with importance sampling for regularized loss 96 | minimization [[pdf][zhao_zhang]] 97 | * Variance reduction in SGD by distributed importance sampling [[pdf][distributed_is]] 98 | 99 | ## Support, License and Copyright 100 | 101 | This software is distributed with the **MIT** license which pretty much means 102 | that you can use it however you want and for whatever reason you want. All the 103 | information regarding support, copyright and the license can be found in the 104 | [LICENSE][lic] file in the repository. 105 | 106 | [github_examples]: https://github.com/idiap/importance-sampling/tree/master/examples 107 | [nasace]: https://arxiv.org/abs/1803.00942 108 | [biased_is]: https://arxiv.org/abs/1706.00043 109 | [zhao_zhang]: http://www.jmlr.org/proceedings/papers/v37/zhaoa15.pdf 110 | [distributed_is]: https://arxiv.org/pdf/1511.06481 111 | [lic]: https://github.com/idiap/importance-sampling/blob/master/LICENSE 112 | -------------------------------------------------------------------------------- /docs/training.md: -------------------------------------------------------------------------------- 1 | # How to 2 | 3 | The `training` module provides several implementations of `ImportanceTraining` 4 | that can wrap a *Keras* model and train it with importance sampling. 
5 | 6 | ```python 7 | from importance_sampling.training import ImportanceTraining, BiasedImportanceTraining 8 | 9 | # assuming model is a Keras model 10 | wrapped_model = ImportanceTraining(model) 11 | wrapped_model = BiasedImportanceTraining(model, k=1.0, smooth=0.5) 12 | 13 | wrapped_model.fit(x_train, y_train, epochs=10) 14 | model.evaluate(x_test, y_test) 15 | ``` 16 | 17 | ## Sampling probabilites and sample weights 18 | 19 | All of the `fit` methods accept two extra keyword arguments `on_sample` and 20 | `on_scores`. They are callables that allow the user of the library to have read 21 | access to the sampling probabilities weights and scores from the performed 22 | importance sampling. Their API is the following, 23 | 24 | ```python 25 | on_sample(sampler, idxs, w, predicted_scores) 26 | ``` 27 | 28 | **Arguments** 29 | 30 | * **sampler**: The instance of BaseSampler being currently used 31 | * **idxs**: A numpy array containing the indices that were sampled 32 | * **w**: A numpy array containing the computed sample weights 33 | * **predicted_scores**: A numpy array containing the unnormalized importance 34 | scores 35 | 36 | ```python 37 | on_scores(sampler, scores) 38 | ``` 39 | 40 | **Arguments** 41 | 42 | * **sampler**: The instance of BaseSampler being currently used 43 | * **scores**: A numpy array containing all the importance scores from the 44 | presampled data 45 | 46 | ## Bias 47 | 48 | `BiasedImportanceTraining` and `ApproximateImportanceTraining` classes accept a 49 | constructor parameter \(k \in (-\infty, 1]\). \(k\) biases the gradient 50 | estimator to focus more on hard examples, the smaller the value the closer to 51 | max-loss minimization the algorithm is. By default `k=0.5` which is found to 52 | often improve the generalization performance of the final model. 53 | 54 | ## Smoothing 55 | 56 | Modern deep networks often have innate sources of randomness (e.g. dropout, 57 | batch normalization) that can result in noisy importance predictions. To 58 | alleviate this noise one can smooth the importance using additive smoothing. 59 | The proposed `ImportanceTraining` class does not use smoothing and we propose 60 | to replace *Dropout* and *BatchNormalization* with \(L_2\) regularization and 61 | *LayerNormalization*. 62 | 63 | The classes that accept smoothing do so in the following way, the 64 | \(\text{smooth} \in \mathbb{R}\) parameter is added to all importance 65 | predictions before computing the sampling distribution. In addition, they 66 | accept the \(\text{adaptive_smoothing}\) parameter which when set to `True` 67 | multiplies \(\text{smooth}\) with \(\bar{L} \approx 68 | \mathbb{E}\left[\frac{1}{\|B\|} \sum_{i \in B} L(x_i, y_i)\right]\) as computed 69 | by the moving average of the mini-batch losses. 70 | 71 | Although, smooth is initialized at `smooth=0.0`, if instability is observed 72 | during training, it can be set to small values (e.g. `[0.05, 0.1, 0.5]`) or one 73 | can use adaptive smoothing with a sane default value for smooth being 74 | `smooth=0.5`. 75 | 76 | ## Methods 77 | 78 | The wrapped models aim to expose the same `fit` methods as the original *Keras* 79 | models in order to make their use as simple as possible. 
The following is a 80 | list of deviations or additions: 81 | 82 | * `class_weights`, `sample_weights` are **not** supported 83 | * `fit_generator` accepts a `batch_size` argument 84 | * `fit_generator` is not supported by all `ImportanceTraining` classes 85 | * `fit_dataset` has been added as a method (see [Datasets](datasets.md)) 86 | 87 | Below, follows the list of methods with their arguments. 88 | 89 | ### fit 90 | 91 | ``` 92 | fit(x, y, batch_size=32, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, steps_per_epoch=None, on_sample=None, on_scores=None) 93 | ``` 94 | 95 | **Arguments** 96 | 97 | * **x**: Numpy array of training data, lists and dictionaries are not supported 98 | * **y**: Numpy array of target data, lists and dictionaries are not supported 99 | * **batch\_size**: The number of samples per gradient update 100 | * **epochs**: Multiplied by `steps_per_epoch` defines the total number of 101 | parameter updates 102 | * **verbose**: When set `>0` the *Keras* progress callback is added to the list 103 | of callbacks 104 | * **callbacks**: A list of *Keras* callbacks for logging, changing training 105 | parameters, monitoring, etc. 106 | * **validation\_split**: A float in `[0, 1)` that defines the percentage of the 107 | training data to use for evaluation 108 | * **validation\_data**: A tuple of numpy arrays containing data and targets to 109 | evaluate the network on 110 | * **steps\_per\_epoch**: The number of gradient updates to do in order to 111 | assume that an epoch has passed 112 | * **on_sample**: A callable that accepts the sampler, idxs, w, scores 113 | * **on_scores**: A callable that accepts the sampler and scores 114 | 115 | **Returns** 116 | 117 | A *Keras* `History` instance. 118 | 119 | ### fit\_generator 120 | 121 | ``` 122 | fit_generator(train, steps_per_epoch, batch_size=32, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, on_sample=None, on_scores=None) 123 | ``` 124 | 125 | **Arguments** 126 | 127 | * **train**: A generator yielding tuples of (data, targets) 128 | * **steps\_per\_epoch**: The number of gradient updates to do in order to 129 | assume that an epoch has passed 130 | * **batch\_size**: The number of samples per gradient update (in contrast to 131 | *Keras* this can be variable) 132 | * **epochs**: Multiplied by `steps_per_epoch` defines the total number of 133 | parameter updates 134 | * **verbose**: When set `>0` the *Keras* progress callback is added to the list 135 | of callbacks 136 | * **callbacks**: A list of *Keras* callbacks for logging, changing training 137 | parameters, monitoring, etc. 138 | * **validation\_data**: A tuple of numpy arrays containing data and targets to 139 | evaluate the network on or a generator yielding tuples of (data, targets) 140 | * **validation\_steps**: The number of tuples to extract from the validation 141 | data generator (if a generator is given) 142 | * **on_sample**: A callable that accepts the sampler, idxs, w, scores 143 | * **on_scores**: A callable that accepts the sampler and scores 144 | 145 | **Returns** 146 | 147 | A *Keras* `History` instance. 148 | 149 | ### fit\_dataset 150 | 151 | ``` 152 | fit_dataset(dataset, steps_per_epoch=None, batch_size=32, epochs=1, verbose=1, callbacks=None, on_sample=None, on_scores=None) 153 | ``` 154 | 155 | The calls to the other `fit*` methods are delegated to this one after a 156 | `Dataset` instance has been created. 
See [Datasets]() for details on how to 157 | create a `Dataset` and what datasets are available by default. 158 | 159 | **Arguments** 160 | 161 | * **dataset**: Instance of the `Dataset` class 162 | * **steps\_per\_epoch**: The number of gradient updates to do in order to 163 | assume that an epoch has passed (if not given equals the number of training 164 | samples) 165 | * **batch\_size**: The number of samples per gradient update (in contrast to 166 | *Keras* this can be variable) 167 | * **epochs**: Multiplied by `steps_per_epoch` defines the total number of 168 | parameter updates 169 | * **verbose**: When set `>0` the *Keras* progress callback is added to the list 170 | of callbacks 171 | * **callbacks**: A list of *Keras* callbacks for logging, changing training 172 | parameters, monitoring, etc. 173 | * **on_sample**: A callable that accepts the sampler, idxs, w, scores 174 | * **on_scores**: A callable that accepts the sampler and scores 175 | 176 | **Returns** 177 | 178 | A *Keras* `History` instance. 179 | 180 | ## ImportanceTraining 181 | 182 | ``` 183 | importance_sampling.training.ImportanceTraining(model, presample=3.0, tau_th=None, forward_batch_size=None, score="gnorm", layer=None) 184 | ``` 185 | 186 | `ImportanceTraining` uses the passed model to compute the importance of the 187 | samples. It computes the variance reduction and enables importance sampling 188 | only when the variance will be reduced more than `tau_th`. When importance sampling is enabled, it 189 | samples uniformly `presample*batch_size` samples, then it runs a 190 | **forward pass** for all of them to compute the `score` and **resamples 191 | according to the importance**. 192 | 193 | See our [paper](https://arxiv.org/abs/1803.00942) for a precise definition of 194 | the algorithm. 195 | 196 | **Arguments** 197 | 198 | * **model**: The Keras model to train 199 | * **presample**: The number of samples to presample for scoring, given as a 200 | factor of the batch size 201 | * **tau\_th**: The variance reduction threshold after which we enable 202 | importance sampling, when not given it is computed from eq. 29 (it is given 203 | in units of batch size increment) 204 | * **forward\_batch\_size**: Define the batch size when running the forward pass 205 | to compute the importance 206 | * **score**: Choose the importance score among \(\{\text{gnorm}, \text{loss}, 207 | \text{full_gnorm}\}\). `gnorm` computes an upper bound to the full gradient norm 208 | that requires only one forward pass. 209 | * **layer**: Defines which layer will be used to compute the upper bound (if 210 | not given it is automatically inferred). It can also be given as an index in 211 | the model's layers property. 212 | 213 | ## BiasedImportanceTraining 214 | 215 | ``` 216 | importance_sampling.training.BiasedImportanceTraining(model, k=0.5, smooth=0.0, adaptive_smoothing=False, presample=256, forward_batch_size=128) 217 | ``` 218 | 219 | `BiasedImportanceTraining` uses the model and the loss to compute the per 220 | sample importance. `presample` data points are sampled uniformly and after a 221 | forward pass on all of them the importance distribution is calculated and we 222 | resample the mini batch. 223 | 224 | See the corresponding [paper](https://arxiv.org/abs/1706.00043) for details. 
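As a rough sketch of how the constructor parameters described below combine with the `on_scores` hook from earlier (`model`, `x_train` and `y_train` are assumed to exist already and the parameter values are only illustrative):

```python
from importance_sampling.training import BiasedImportanceTraining

def log_scores(sampler, scores):
    # read-only access to the importance scores of the presampled points
    print("mean importance score:", scores.mean())

wrapped = BiasedImportanceTraining(
    model,                    # an already compiled Keras model
    k=0.5,                    # bias towards the hard examples
    smooth=0.5,               # additive smoothing ...
    adaptive_smoothing=True,  # ... scaled by the average training loss
    presample=256,
    forward_batch_size=128
)
wrapped.fit(x_train, y_train, batch_size=32, epochs=10, on_scores=log_scores)
```
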
225 | 226 | **Arguments** 227 | 228 | * **model**: The Keras model to train 229 | * **k**: Controls the bias of the sampling that focuses the network on the hard 230 | examples 231 | * **smooth**: Influences the sampling distribution towards uniform by additive 232 | smoothing 233 | * **adaptive\_smoothing**: When set to `True` multiplies `smooth` with the 234 | average training loss 235 | * **presample**: Defines the number of samples to compute the importance for 236 | before creating each batch 237 | * **forward\_batch\_size**: Define the batch size when running the forward pass 238 | to compute the importance 239 | 240 | ## ApproximateImportanceTraining 241 | 242 | ``` 243 | importance_sampling.training.ApproximateImportanceTraining(model, k=0.5, smooth=0.0, adaptive_smoothing=False, presample=2048) 244 | ``` 245 | 246 | `ApproximateImportanceTraining` creates a small model that uses the per sample 247 | history of the loss and the class to predict the importance for each sample. It 248 | can be faster than `BiasedImportanceTraining` but less effective. 249 | 250 | See the corresponding [paper](https://arxiv.org/abs/1706.00043) for details. 251 | 252 | **Arguments** 253 | 254 | * **model**: The Keras model to train 255 | * **k**: Controls the bias of the sampling that focuses the network on the hard 256 | examples 257 | * **smooth**: Influences the sampling distribution towards uniform by additive 258 | smoothing 259 | * **adaptive\_smoothing**: When set to `True` multiplies `smooth` with the 260 | average training loss 261 | * **presample**: Defines the number of samples to compute the importance for 262 | before creating each batch 263 | 264 | ## SVRG 265 | 266 | ``` 267 | importance_sampling.training.SVRG(model, B=10., B_rate=1.0, B_over_b=128) 268 | ``` 269 | 270 | `SVRG` trains a Keras model with stochastic variance reduced gradient. 271 | Specifically it implements the following two variants of SVRG 272 | 273 | * SVRG - [Accelerating stochastic gradient descent using predictive variance 274 | reduction][svrg] by Johnson R. and Zhang T. 275 | * SCSG - [Less than a single pass: Stochastically controlled stochastic 276 | gradient][scsg] by Lei L. and Jordan M. 277 | 278 | **Arguments** 279 | 280 | * **model**: The Keras model to train 281 | * **B**: The number of batches to use to compute the full batch gradient. For 282 | SVRG this should be either a very large number or 0. For SCSG it can be any 283 | number larger than 1 284 | * **B\_rate**: A factor to multiply `B` with after every update 285 | * **B\_over\_b**: Compute a batch gradient after every `B_over_b` gradient 286 | updates. 287 | 288 | [svrg]: https://papers.nips.cc/paper/4937-accelerating-stochastic-gradient-descent-using-predictive-variance-reduction.pdf 289 | [scsg]: https://arxiv.org/abs/1609.03261 290 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | Importance sampling examples 2 | ============================ 3 | 4 | In this directory we copy some of the Keras examples and apply importance 5 | sampling. The change is minimal and in most cases amounts to the following 6 | code: 7 | 8 | from importance_sampling.training import ImportanceTraining 9 | 10 | ... 11 | ... 12 | 13 | model.compile(...) 14 | 15 | # instead of model.fit(....) 16 | ImportanceTraining(model).fit(....) 
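
Each example also builds its command line arguments with `example_utils.get_parser`, so importance sampling can be switched off again by passing `--uniform`. A condensed sketch of the pattern used in the scripts (with `model` and the data assumed to be defined as in the full examples):

    from importance_sampling.training import ImportanceTraining
    from example_utils import get_parser

    parser = get_parser("Train a CNN on MNIST")
    args = parser.parse_args()

    # args.importance_training defaults to True and becomes False
    # when the script is run with --uniform
    wrapped = ImportanceTraining(model) if args.importance_training else model
    wrapped.fit(x_train, y_train, batch_size=128, epochs=10,
                validation_data=(x_test, y_test))
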
17 | -------------------------------------------------------------------------------- /examples/cifar10_resnet.py: -------------------------------------------------------------------------------- 1 | """Trains a ResNet on the CIFAR10 dataset.""" 2 | 3 | from __future__ import print_function 4 | 5 | import time 6 | 7 | from keras import backend as K 8 | from keras.callbacks import LearningRateScheduler, Callback 9 | from keras.datasets import cifar10 10 | from keras.layers import Activation, BatchNormalization, Conv2D, Dense, \ 11 | GlobalAveragePooling2D, Input, add 12 | from keras.models import Model 13 | from keras.optimizers import SGD 14 | from keras.preprocessing.image import ImageDataGenerator 15 | from keras.regularizers import l2 16 | from keras.utils import to_categorical 17 | import numpy as np 18 | 19 | from importance_sampling.datasets import CIFAR10, ZCAWhitening 20 | from importance_sampling.models import wide_resnet 21 | from importance_sampling.training import ImportanceTraining 22 | from example_utils import get_parser 23 | 24 | 25 | class TrainingSchedule(Callback): 26 | """Implement the training schedule for training a resnet on CIFAR10 for a 27 | given time budget.""" 28 | def __init__(self, total_time): 29 | self._total_time = total_time 30 | self._lr = self._get_lr(0.0) 31 | 32 | def _get_lr(self, progress): 33 | if progress > 0.8: 34 | return 0.004 35 | elif progress > 0.5: 36 | return 0.02 37 | else: 38 | return 0.1 39 | 40 | def on_train_begin(self, logs={}): 41 | self._start = time.time() 42 | self._lr = self._get_lr(0.0) 43 | K.set_value(self.model.optimizer.lr, self._lr) 44 | 45 | def on_batch_end(self, batch, logs): 46 | t = time.time() - self._start 47 | 48 | if t >= self._total_time: 49 | self.model.stop_training = True 50 | 51 | lr = self._get_lr(t / self._total_time) 52 | if lr != self._lr: 53 | self._lr = lr 54 | K.set_value(self.model.optimizer.lr, self._lr) 55 | 56 | @property 57 | def lr(self): 58 | return self._lr 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = get_parser("Train a ResNet on CIFAR10") 63 | parser.add_argument( 64 | "--depth", 65 | type=int, 66 | default=28, 67 | help="Choose the depth of the resnet" 68 | ) 69 | parser.add_argument( 70 | "--width", 71 | type=int, 72 | default=2, 73 | help="Choose the width of the resnet" 74 | ) 75 | parser.add_argument( 76 | "--presample", 77 | type=float, 78 | default=3.0, 79 | help="Presample that many times the batch size for importance sampling" 80 | ) 81 | parser.add_argument( 82 | "--batch_size", 83 | type=int, 84 | default=128, 85 | help="Choose the size of the minibatch" 86 | ) 87 | parser.add_argument( 88 | "--time_budget", 89 | type=int, 90 | default=3600*3, 91 | help="How many seconds to train for" 92 | ) 93 | args = parser.parse_args() 94 | 95 | # Load the data 96 | dset = ZCAWhitening(CIFAR10()) 97 | x_train, y_train = dset.train_data[:] 98 | x_test, y_test = dset.test_data[:] 99 | 100 | # Build the model 101 | training_schedule = TrainingSchedule(args.time_budget) 102 | model = wide_resnet(args.depth, args.width)(dset.shape, dset.output_size) 103 | model.compile( 104 | loss="categorical_crossentropy", 105 | optimizer=SGD(lr=training_schedule.lr, momentum=0.9), 106 | metrics=["accuracy"] 107 | ) 108 | model.summary() 109 | 110 | # Create the data augmentation generator 111 | datagen = ImageDataGenerator( 112 | # set input mean to 0 over the dataset 113 | featurewise_center=False, 114 | # set each sample mean to 0 115 | samplewise_center=False, 116 | # divide inputs by std of dataset 
117 | featurewise_std_normalization=False, 118 | # divide each input by its std 119 | samplewise_std_normalization=False, 120 | # apply ZCA whitening 121 | zca_whitening=False, 122 | # randomly rotate images in the range (deg 0 to 180) 123 | rotation_range=0, 124 | # randomly shift images horizontally 125 | width_shift_range=0.1, 126 | # randomly shift images vertically 127 | height_shift_range=0.1, 128 | # randomly flip images 129 | horizontal_flip=True, 130 | # randomly flip images 131 | vertical_flip=False) 132 | datagen.fit(x_train) 133 | 134 | # Train the model 135 | if args.importance_training: 136 | ImportanceTraining(model).fit_generator( 137 | datagen.flow(x_train, y_train, batch_size=args.batch_size), 138 | validation_data=(x_test, y_test), 139 | epochs=10**6, 140 | verbose=1, 141 | callbacks=[training_schedule], 142 | batch_size=args.batch_size, 143 | steps_per_epoch=int(np.ceil(float(len(x_train)) / args.batch_size)) 144 | ) 145 | else: 146 | model.fit_generator( 147 | datagen.flow(x_train, y_train, batch_size=args.batch_size), 148 | validation_data=(x_test, y_test), 149 | epochs=10**6, 150 | verbose=1, 151 | callbacks=[training_schedule] 152 | ) 153 | 154 | # Score trained model. 155 | scores = model.evaluate(x_test, y_test, verbose=1) 156 | print('Test loss:', scores[0]) 157 | print('Test accuracy:', scores[1]) 158 | -------------------------------------------------------------------------------- /examples/example_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | 4 | def get_parser(desc): 5 | parser = argparse.ArgumentParser(description=desc) 6 | parser.add_argument( 7 | "--uniform", 8 | action="store_false", 9 | dest="importance_training", 10 | help="Enable uniform sampling" 11 | ) 12 | 13 | return parser 14 | -------------------------------------------------------------------------------- /examples/importance_sampling: -------------------------------------------------------------------------------- 1 | ../importance_sampling/ -------------------------------------------------------------------------------- /examples/mnist_cnn.py: -------------------------------------------------------------------------------- 1 | """Trains a simple convnet on the MNIST dataset.""" 2 | 3 | from __future__ import print_function 4 | import keras 5 | from keras.datasets import mnist 6 | from keras.models import Sequential 7 | from keras.layers import Activation, Dense, Flatten 8 | from keras.layers import Conv2D, MaxPooling2D 9 | from keras.regularizers import l2 10 | from keras import backend as K 11 | 12 | from importance_sampling.training import ConstantTimeImportanceTraining 13 | from example_utils import get_parser 14 | 15 | if __name__ == "__main__": 16 | parser = get_parser("Train a CNN on MNIST") 17 | args = parser.parse_args() 18 | 19 | batch_size = 128 20 | num_classes = 10 21 | epochs = 10 22 | 23 | # input image dimensions 24 | img_rows, img_cols = 28, 28 25 | 26 | # the data, shuffled and split between train and test sets 27 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 28 | 29 | if K.image_data_format() == 'channels_first': 30 | x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 31 | x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 32 | input_shape = (1, img_rows, img_cols) 33 | else: 34 | x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 35 | x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 36 | input_shape = (img_rows, img_cols, 1) 37 | 
38 | x_train = x_train.astype('float32') 39 | x_test = x_test.astype('float32') 40 | x_train /= 255 41 | x_test /= 255 42 | print('x_train shape:', x_train.shape) 43 | print(x_train.shape[0], 'train samples') 44 | print(x_test.shape[0], 'test samples') 45 | 46 | # convert class vectors to binary class matrices 47 | y_train = keras.utils.to_categorical(y_train, num_classes) 48 | y_test = keras.utils.to_categorical(y_test, num_classes) 49 | 50 | model = Sequential() 51 | model.add(Conv2D(32, kernel_size=(3, 3), 52 | activation='relu', 53 | kernel_regularizer=l2(1e-5), 54 | input_shape=input_shape)) 55 | model.add(Conv2D(64, (3, 3), activation='relu', kernel_regularizer=l2(1e-5))) 56 | model.add(MaxPooling2D(pool_size=(2, 2))) 57 | model.add(Flatten()) 58 | model.add(Dense(128, activation='relu', kernel_regularizer=l2(1e-5))) 59 | model.add(Dense(num_classes, kernel_regularizer=l2(1e-5))) 60 | model.add(Activation('softmax')) 61 | 62 | model.compile(loss=keras.losses.categorical_crossentropy, 63 | optimizer=keras.optimizers.Adadelta(), 64 | metrics=['accuracy']) 65 | 66 | if args.importance_training: 67 | wrapped = ConstantTimeImportanceTraining(model) 68 | else: 69 | wrapped = model 70 | wrapped.fit(x_train, y_train, 71 | batch_size=batch_size, 72 | epochs=epochs, 73 | verbose=1, 74 | validation_data=(x_test, y_test)) 75 | score = model.evaluate(x_test, y_test, verbose=0) 76 | print('Test loss:', score[0]) 77 | print('Test accuracy:', score[1]) 78 | -------------------------------------------------------------------------------- /examples/mnist_mlp.py: -------------------------------------------------------------------------------- 1 | """Trains a simple fully connected NN on the MNIST dataset.""" 2 | 3 | from __future__ import print_function 4 | 5 | import keras 6 | from keras.datasets import mnist 7 | from keras.models import Sequential 8 | from keras.layers import Activation, Dense 9 | from keras.optimizers import RMSprop 10 | from keras.regularizers import l2 11 | 12 | from importance_sampling.training import ImportanceTraining 13 | from example_utils import get_parser 14 | 15 | if __name__ == "__main__": 16 | parser = get_parser("Train an MLP on MNIST") 17 | args = parser.parse_args() 18 | 19 | batch_size = 128 20 | num_classes = 10 21 | epochs = 10 22 | 23 | # the data, shuffled and split between train and test sets 24 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 25 | 26 | x_train = x_train.reshape(60000, 784) 27 | x_test = x_test.reshape(10000, 784) 28 | x_train = x_train.astype('float32') 29 | x_test = x_test.astype('float32') 30 | x_train /= 255 31 | x_test /= 255 32 | print(x_train.shape[0], 'train samples') 33 | print(x_test.shape[0], 'test samples') 34 | 35 | # convert class vectors to binary class matrices 36 | y_train = keras.utils.to_categorical(y_train, num_classes) 37 | y_test = keras.utils.to_categorical(y_test, num_classes) 38 | 39 | model = Sequential() 40 | model.add(Dense(512, activation='relu', kernel_regularizer=l2(1e-5), 41 | input_shape=(784,))) 42 | model.add(Dense(512, activation='relu', kernel_regularizer=l2(1e-5))) 43 | model.add(Dense(10, kernel_regularizer=l2(1e-5))) 44 | model.add(Activation('softmax')) 45 | 46 | model.summary() 47 | 48 | model.compile(loss='categorical_crossentropy', 49 | optimizer=RMSprop(), 50 | metrics=['accuracy']) 51 | 52 | if args.importance_training: 53 | wrapped = ImportanceTraining(model, presample=5) 54 | else: 55 | wrapped = model 56 | history = wrapped.fit( 57 | x_train, y_train, 58 | batch_size=batch_size, 59 | 
epochs=epochs, 60 | verbose=1, 61 | validation_data=(x_test, y_test) 62 | ) 63 | score = model.evaluate(x_test, y_test, verbose=0) 64 | print('Test loss:', score[0]) 65 | print('Test accuracy:', score[1]) 66 | -------------------------------------------------------------------------------- /examples/mnist_svrg.py: -------------------------------------------------------------------------------- 1 | """Trains a simple logistic regression on the MNIST dataset.""" 2 | 3 | from __future__ import print_function 4 | 5 | import keras 6 | from keras import backend as K 7 | from keras.datasets import mnist 8 | from keras.models import Sequential 9 | from keras.layers import Activation, Dense 10 | from keras.optimizers import SGD 11 | 12 | from importance_sampling.training import SVRG 13 | from example_utils import get_parser 14 | 15 | if __name__ == "__main__": 16 | parser = get_parser("Train logistic regression with SVRG on MNIST") 17 | args = parser.parse_args() 18 | 19 | batch_size = 16 20 | num_classes = 10 21 | epochs = 100 22 | 23 | # the data, shuffled and split between train and test sets 24 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 25 | 26 | x_train = x_train.reshape(60000, 784) 27 | x_test = x_test.reshape(10000, 784) 28 | x_train = x_train.astype('float32') 29 | x_test = x_test.astype('float32') 30 | x_train /= 255 31 | x_test /= 255 32 | print(x_train.shape[0], 'train samples') 33 | print(x_test.shape[0], 'test samples') 34 | 35 | # convert class vectors to binary class matrices 36 | y_train = keras.utils.to_categorical(y_train, num_classes) 37 | y_test = keras.utils.to_categorical(y_test, num_classes) 38 | 39 | model = Sequential() 40 | model.add(Dense(10, input_shape=(784,))) 41 | model.add(Activation('softmax')) 42 | 43 | model.summary() 44 | 45 | model.compile(loss='categorical_crossentropy', 46 | optimizer=SGD(lr=0.1, momentum=0), 47 | metrics=['accuracy']) 48 | 49 | sgd_epochs = epochs 50 | svrg_epochs = 0 51 | svrg_wrapped = SVRG(model, B=0, B_over_b=300) 52 | if args.importance_training: 53 | sgd_epochs = 20 54 | svrg_epochs = (epochs-20) // (2 + len(x_train)//(batch_size*300)) 55 | history = model.fit( 56 | x_train, y_train, 57 | batch_size=batch_size, 58 | epochs=sgd_epochs, 59 | verbose=1, 60 | validation_data=(x_test, y_test) 61 | ) 62 | history = svrg_wrapped.fit( 63 | x_train, y_train, 64 | batch_size=batch_size, 65 | epochs=svrg_epochs, 66 | verbose=1, 67 | validation_data=(x_test, y_test) 68 | ) 69 | score = model.evaluate(x_test, y_test, verbose=0) 70 | print('Test loss:', score[0]) 71 | print('Test accuracy:', score[1]) 72 | -------------------------------------------------------------------------------- /examples/mnist_svrg_cnn.py: -------------------------------------------------------------------------------- 1 | """Trains a simple convnet on the MNIST dataset.""" 2 | 3 | from __future__ import print_function 4 | 5 | import time 6 | 7 | import keras 8 | from keras.datasets import mnist 9 | from keras.models import Sequential 10 | from keras.layers import Activation, Dense, Flatten 11 | from keras.layers import Conv2D, MaxPooling2D 12 | from keras.regularizers import l2 13 | from keras import backend as K 14 | 15 | from importance_sampling.training import ConstantTimeImportanceTraining, SVRG 16 | from example_utils import get_parser 17 | 18 | if __name__ == "__main__": 19 | batch_size = 128 20 | num_classes = 10 21 | epochs = 10 22 | 23 | # input image dimensions 24 | img_rows, img_cols = 28, 28 25 | 26 | # the data, shuffled and split between train 
and test sets 27 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 28 | 29 | if K.image_data_format() == 'channels_first': 30 | x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 31 | x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 32 | input_shape = (1, img_rows, img_cols) 33 | else: 34 | x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 35 | x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 36 | input_shape = (img_rows, img_cols, 1) 37 | 38 | x_train = x_train.astype('float32') 39 | x_test = x_test.astype('float32') 40 | x_train /= 255 41 | x_test /= 255 42 | print('x_train shape:', x_train.shape) 43 | print(x_train.shape[0], 'train samples') 44 | print(x_test.shape[0], 'test samples') 45 | 46 | # convert class vectors to binary class matrices 47 | y_train = keras.utils.to_categorical(y_train, num_classes) 48 | y_test = keras.utils.to_categorical(y_test, num_classes) 49 | 50 | model = Sequential() 51 | model.add(Conv2D(32, kernel_size=(3, 3), 52 | activation='relu', 53 | kernel_regularizer=l2(1e-5), 54 | input_shape=input_shape)) 55 | model.add(Conv2D(64, (3, 3), activation='relu', kernel_regularizer=l2(1e-5))) 56 | model.add(MaxPooling2D(pool_size=(2, 2))) 57 | model.add(Flatten()) 58 | model.add(Dense(128, activation='relu', kernel_regularizer=l2(1e-5))) 59 | model.add(Dense(num_classes, kernel_regularizer=l2(1e-5))) 60 | model.add(Activation('softmax')) 61 | 62 | model.compile(loss=keras.losses.categorical_crossentropy, 63 | optimizer=keras.optimizers.SGD(lr=0.01, momentum=0.9), 64 | metrics=['accuracy']) 65 | 66 | # Keep the initial weights to compare 67 | W = model.get_weights() 68 | 69 | # Train with SVRG 70 | s_svrg = time.time() 71 | model.set_weights(W) 72 | SVRG(model, B=0, B_over_b=len(x_train) // batch_size).fit( 73 | x_train, y_train, 74 | batch_size=batch_size, 75 | epochs=epochs, 76 | verbose=1, 77 | validation_data=(x_test, y_test) 78 | ) 79 | e_svrg = time.time() 80 | score_svrg = model.evaluate(x_test, y_test, verbose=0) 81 | 82 | # Train with uniform 83 | s_uniform = time.time() 84 | model.set_weights(W) 85 | model.fit( 86 | x_train, y_train, 87 | batch_size=batch_size, 88 | epochs=epochs, 89 | verbose=1, 90 | validation_data=(x_test, y_test) 91 | ) 92 | e_uniform = time.time() 93 | score_uniform = model.evaluate(x_test, y_test, verbose=0) 94 | 95 | # Train with IS 96 | s_is = time.time() 97 | model.set_weights(W) 98 | ConstantTimeImportanceTraining(model).fit( 99 | x_train, y_train, 100 | batch_size=batch_size, 101 | epochs=epochs, 102 | verbose=1, 103 | validation_data=(x_test, y_test) 104 | ) 105 | e_is = time.time() 106 | score_is = model.evaluate(x_test, y_test, verbose=0) 107 | 108 | # Print the results 109 | print("SVRG: ", score_svrg[1], " in ", e_svrg - s_svrg, "s") 110 | print("Uniform: ", score_uniform[1], " in ", e_uniform - s_uniform, "s") 111 | print("IS: ", score_is[1], " in ", e_is - s_is, "s") 112 | -------------------------------------------------------------------------------- /headers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | import argparse 8 | from itertools import chain, ifilter 9 | import os 10 | from os import path 11 | from subprocess import PIPE, Popen 12 | 13 | 14 | class Header(object): 15 | """Represents the copyright header for a source file""" 16 | COPY = """# 17 | # 
Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 18 | # Written by """ 19 | 20 | def __init__(self, start=-1, stop=-1, content=None): 21 | self.start = start 22 | self.stop = stop 23 | self.content = content 24 | 25 | def update(self, filepath, dry_run=False): 26 | def peek(f): 27 | pos = f.tell() 28 | c = f.read(1) 29 | f.seek(pos) 30 | 31 | return c 32 | 33 | new_content = Header.get_content_for_file(filepath) 34 | needs_update = new_content != self.content or self.start < 0 35 | 36 | if not needs_update: 37 | return 38 | 39 | # Do the dry run 40 | if dry_run: 41 | print new_content 42 | return 43 | 44 | # Open both files and do the copy while updating the header 45 | with open(filepath) as f_in, open(filepath+".header", "w") as f_out: 46 | # Copy the comments that appear on top 47 | while peek(f_in) == "#": 48 | f_out.write(f_in.readline()) 49 | 50 | # Consume one new line 51 | if peek(f_in) in ["\r", "\n"]: 52 | f_in.readline() 53 | 54 | # Add the new header 55 | start = f_out.tell() 56 | f_out.write(new_content) 57 | 58 | # If the file had a header skip it while writing the rest of the data 59 | if self.start > 0: 60 | f_out.write(f_in.read(max(0, self.start - f_in.tell()))) 61 | f_in.seek(max(f_in.tell(), self.stop)) 62 | f_out.write(f_in.read()) 63 | 64 | stat = os.stat(filepath) 65 | os.chmod(filepath+".header", stat.st_mode) 66 | os.rename(filepath+".header", filepath) 67 | 68 | 69 | @classmethod 70 | def from_file(cls, filepath): 71 | # Create an empty object to be filled with contents 72 | header = cls() 73 | 74 | # Read the file contents into memory 75 | with open(filepath) as f: 76 | contents = f.read() 77 | 78 | # Find the copyright disclaimer 79 | start = contents.find("#\n# Copyright") 80 | if start < 0: 81 | return header 82 | end = contents.find("\n#\n\n", start) + 4 83 | 84 | # Fill in the header 85 | header.start = start 86 | header.end = end 87 | header.content = contents[start:end] 88 | 89 | return header 90 | 91 | @staticmethod 92 | def get_content_for_file(filepath): 93 | """Return the generated header for the file""" 94 | # Call into git to get the list of authors 95 | p = Popen(["git", "shortlog", "-se", "--", filepath], stdout=PIPE) 96 | out, _ = p.communicate() 97 | authors = [ 98 | l.split("\t")[1].strip() 99 | for l in out.splitlines() 100 | if l != "" 101 | ] 102 | 103 | return Header.COPY + ",\n# ".join(authors) + "\n#\n\n" 104 | 105 | 106 | def is_python_file(path): 107 | return path.endswith(".py") 108 | 109 | 110 | def in_directory(directory): 111 | if directory[0] == path.sep: 112 | directory = directory[1:] 113 | if directory[-1] == path.sep: 114 | directory = directory[:-1] 115 | def inner(x): 116 | return path.sep + directory + path.sep in x 117 | return inner 118 | 119 | 120 | def _all(*predicates): 121 | def inner(x): 122 | return all(p(x) for p in predicates) 123 | return inner 124 | 125 | 126 | def _not(predicate): 127 | def inner(x): 128 | return not predicate(x) 129 | return inner 130 | 131 | 132 | def walk_directories(root): 133 | """'find' in a generator function.""" 134 | for child in os.listdir(root): 135 | if child.startswith("."): 136 | continue 137 | 138 | full_path = path.join(root, child) 139 | if path.isfile(full_path): 140 | yield full_path 141 | elif full_path.endswith((path.sep+".", path.sep+"..")): 142 | continue 143 | elif path.islink(full_path): 144 | continue 145 | else: 146 | for fp in walk_directories(full_path): 147 | yield fp 148 | 149 | 150 | if __name__ == "__main__": 151 | parser = 
argparse.ArgumentParser( 152 | description=("Generate file copywrite headers and prepend them to " 153 | "the files in the repository") 154 | ) 155 | parser.add_argument( 156 | "--dry_run", 157 | action="store_true", 158 | help="Don't actually change anything just write the headers to STDOUT" 159 | ) 160 | parser.add_argument( 161 | "--blacklist", 162 | type=lambda x: x.split(":"), 163 | default=[], 164 | help="A colon separated list of directories to blacklist" 165 | ) 166 | 167 | args = parser.parse_args() 168 | 169 | # Loop over all python files 170 | predicate = _all( 171 | is_python_file, 172 | _all(*map(_not, map(in_directory, args.blacklist))) 173 | ) 174 | for source_file in ifilter(predicate, walk_directories(".")): 175 | print source_file 176 | header = Header.from_file(source_file) 177 | header.update(source_file, args.dry_run) 178 | -------------------------------------------------------------------------------- /importance_sampling/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Accelerate training of neural networks using importance sampling.""" 7 | 8 | __author__ = "Angelos Katharopoulos" 9 | __copyright__ = "Copyright (c) 2017 Idiap Research Institute" 10 | __license__ = "MIT" 11 | __maintainer__ = "Angelos Katharopoulos" 12 | __email__ = "angelos.katharopoulos@idiap.ch" 13 | __url__ = "http://www.idiap.ch/~katharas/importance-sampling/" 14 | __version__ = "0.10" 15 | 16 | from .training import ImportanceTraining 17 | -------------------------------------------------------------------------------- /importance_sampling/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from .metrics import MetricLayer, TripletLossLayer 7 | from .normalization import BatchRenormalization, LayerNormalization, \ 8 | StatsBatchNorm, GroupNormalization 9 | from .scores import GradientNormLayer, LossLayer 10 | -------------------------------------------------------------------------------- /importance_sampling/layers/metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from functools import reduce 7 | 8 | from keras import backend as K 9 | from keras.engine import Layer 10 | from keras.metrics import categorical_accuracy, binary_accuracy, \ 11 | get as get_metric, sparse_categorical_accuracy 12 | 13 | from ..utils.functional import compose 14 | 15 | 16 | class MetricLayer(Layer): 17 | """Create a layer that computes a metric taking into account masks.""" 18 | def __init__(self, metric_func, **kwargs): 19 | self.supports_masking = True 20 | self.metric_func = metric_func 21 | 22 | super(MetricLayer, self).__init__(**kwargs) 23 | 24 | def compute_mask(self, inputs, input_mask): 25 | return None 26 | 27 | def build(self, input_shape): 28 | # Special care for accuracy because keras treats it specially 29 | try: 30 | if "acc" in self.metric_func: 31 | self.metric_func = self._generic_accuracy 32 | except TypeError: 33 | pass # metric_func is not a string 34 | self.metric_func = compose(K.expand_dims, get_metric(self.metric_func)) 35 | 36 | super(MetricLayer, 
self).build(input_shape) 37 | 38 | def compute_output_shape(self, input_shape): 39 | # We need two inputs y_true, y_pred 40 | assert len(input_shape) == 2 41 | 42 | return (input_shape[0][0], 1) 43 | 44 | def call(self, inputs, mask=None): 45 | # Compute the metric 46 | metric = self.metric_func(*inputs) 47 | if K.int_shape(metric)[-1] == 1: 48 | metric = K.squeeze(metric, axis=-1) 49 | if len(K.int_shape(metric)) == 0: 50 | metric = K.ones(K.shape(inputs[0])[0]) * metric 51 | 52 | # Apply the mask if needed 53 | if mask is not None: 54 | if not isinstance(mask, list): 55 | mask = [mask] 56 | mask = [K.cast(m, K.floatx()) for m in mask if m is not None] 57 | mask = reduce(lambda a, b: a*b, mask) 58 | metric *= mask 59 | metric /= K.mean(mask, axis=-1, keepdims=True) 60 | 61 | # Make sure that the tensor returned is (None, 1) 62 | dims = len(K.int_shape(metric)) 63 | if dims > 1: 64 | metric = K.mean(metric, axis=list(range(1, dims))) 65 | 66 | return K.expand_dims(metric) 67 | 68 | @staticmethod 69 | def _generic_accuracy(y_true, y_pred): 70 | if K.int_shape(y_pred)[1] == 1: 71 | return binary_accuracy(y_true, y_pred) 72 | if K.int_shape(y_true)[-1] == 1: 73 | return sparse_categorical_accuracy(y_true, y_pred) 74 | 75 | return categorical_accuracy(y_true, y_pred) 76 | 77 | 78 | class TripletLossLayer(Layer): 79 | """A bit of an unorthodox layer that implements the triplet loss with L2 80 | normalization. 81 | 82 | It receives 1 vector which is the concatenation of the three 83 | representations and performs the following operations. 84 | x = concat(x_a, x_p, x_n) 85 | N = x.shape[1]/3 86 | return ||x[:N] - x[2*N:]||_2^2 - ||x[:N] - x[N:2*N]||_2^2 87 | """ 88 | def __init__(self, **kwargs): 89 | super(TripletLossLayer, self).__init__(**kwargs) 90 | 91 | def build(self, input_shape): 92 | assert not isinstance(input_shape, list) 93 | self.N = input_shape[1] // 3 94 | self.built = True 95 | 96 | def compute_output_shape(self, input_shape): 97 | assert not isinstance(input_shape, list) 98 | return (input_shape[0], 1) 99 | 100 | def call(self, x): 101 | N = self.N 102 | 103 | xa = x[:, :N] 104 | xp = x[:, N:2*N] 105 | xn = x[:, 2*N:] 106 | 107 | xa = xa / K.sqrt(K.sum(xa**2, axis=1, keepdims=True)) 108 | xp = xp / K.sqrt(K.sum(xp**2, axis=1, keepdims=True)) 109 | xn = xn / K.sqrt(K.sum(xn**2, axis=1, keepdims=True)) 110 | 111 | dn = K.sum(K.square(xa - xn), axis=1, keepdims=True) 112 | dp = K.sum(K.square(xa - xp), axis=1, keepdims=True) 113 | 114 | return dn - dp 115 | -------------------------------------------------------------------------------- /importance_sampling/layers/normalization.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from keras import backend as K 7 | from keras.engine import Layer 8 | from keras import initializers 9 | 10 | from ..utils.tf import tf 11 | 12 | 13 | class _BaseNormalization(Layer): 14 | """Implement utility functions for the normalization layers.""" 15 | def _moments(self, x, axes): 16 | return ( 17 | K.mean(x, axis=axes, keepdims=True), 18 | K.var(x, axis=axes, keepdims=True) 19 | ) 20 | 21 | 22 | class BatchRenormalization(_BaseNormalization): 23 | """Batch renormalization layer (Sergey Ioffe, 2017). 
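    In short, the forward pass (see call() below) normalizes with the
    minibatch statistics and corrects towards the moving averages; writing
    mu_b, sigma_b for the minibatch mean/std and mu, sigma for the moving
    averages:

        sigma_b = sqrt(var_b + epsilon)
        r = clip(sigma_b / sigma, 1/rmax, rmax)      # no gradient through r
        d = clip((mu_b - mu) / sigma, -dmax, dmax)   # no gradient through d
        y = gamma * ((x - mu_b) / sigma_b * r + d) + beta

    rmax and dmax are annealed linearly from rmax_0, dmax_0 to rmax_inf,
    dmax_inf over rmax_duration and dmax_duration steps respectively.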
24 | 25 | # Arguments 26 | momentum: Momentum for the moving average 27 | epsilon: Added to variance to avoid divide by 0 28 | rmax: Maximum correction for the variance 29 | dmax: Maximum correction for the bias 30 | """ 31 | def __init__(self, momentum=0.99, epsilon=1e-3, rmax_0=1., rmax_inf=3., 32 | dmax_0=0., dmax_inf=5., rmax_duration=40000, 33 | dmax_duration=25000, **kwargs): 34 | super(BatchRenormalization, self).__init__(**kwargs) 35 | 36 | self.momentum = momentum 37 | self.epsilon = epsilon 38 | self.rmax_0 = rmax_0 39 | self.rmax_inf = rmax_inf 40 | self.rmax_dur = rmax_duration 41 | self.dmax_0 = dmax_0 42 | self.dmax_inf = dmax_inf 43 | self.dmax_dur = dmax_duration 44 | 45 | def build(self, input_shape): 46 | dim = input_shape[-1] 47 | if dim is None: 48 | raise ValueError(("The normalization axis should have a " 49 | "defined dimension")) 50 | self.dim = dim 51 | 52 | # Trainable part 53 | self.gamma = self.add_weight( 54 | shape=(dim,), 55 | name="gamma", 56 | initializer=initializers.get("ones") 57 | ) 58 | self.beta = self.add_weight( 59 | shape=(dim,), 60 | name="beta", 61 | initializer=initializers.get("zeros") 62 | ) 63 | 64 | # Statistics 65 | self.moving_mean = self.add_weight( 66 | shape=(dim,), 67 | name="moving_mean", 68 | initializer=initializers.get("zeros"), 69 | trainable=False 70 | ) 71 | self.moving_sigma = self.add_weight( 72 | shape=(dim,), 73 | name="moving_sigma", 74 | initializer=initializers.get("ones"), 75 | trainable=False 76 | ) 77 | 78 | # rmax, dmax and steps 79 | self.steps = self.add_weight( 80 | shape=tuple(), 81 | name="steps", 82 | initializer=initializers.get("zeros"), 83 | trainable=False 84 | ) 85 | self.rmax = self.add_weight( 86 | shape=tuple(), 87 | name="rmax", 88 | initializer=initializers.Constant(self.rmax_0), 89 | trainable=False 90 | ) 91 | self.dmax = self.add_weight( 92 | shape=tuple(), 93 | name="dmax", 94 | initializer=initializers.Constant(self.dmax_0), 95 | trainable=False 96 | ) 97 | 98 | self.built = True 99 | 100 | def _moments(self, x): 101 | axes = range(len(K.int_shape(x))-1) 102 | if K.backend() == "tensorflow": 103 | return tf.nn.moments(x, axes) 104 | else: 105 | # TODO: Maybe the following can be optimized a bit? 
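            # Fallback for backends other than TensorFlow: collapse every axis
            # except the channel axis and compute per-channel mean/variance
            # over the resulting (-1, dim) matrix.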
106 | mean = K.mean(K.reshape(x, (-1, self.dim)), axis=0) 107 | var = K.var(K.reshape(x, (-1, self.dim)), axis=0) 108 | 109 | return mean, var 110 | 111 | def _clip(self, x, x_min, x_max): 112 | if K.backend() == "tensorflow": 113 | return tf.clip_by_value(x, x_min, x_max) 114 | else: 115 | return K.maximum(K.minimum(x, x_max), x_min) 116 | 117 | def call(self, inputs, training=None): 118 | x = inputs 119 | assert not isinstance(x, list) 120 | 121 | # Compute the minibatch statistics 122 | mean, var = self._moments(x) 123 | sigma = K.sqrt(var + self.epsilon) 124 | 125 | # If in training phase set rmax, dmax large so that we use the moving 126 | # averages to do the normalization 127 | rmax = K.in_train_phase(self.rmax, K.constant(1e5), training) 128 | dmax = K.in_train_phase(self.dmax, K.constant(1e5), training) 129 | 130 | # Compute the corrections based on rmax, dmax 131 | r = K.stop_gradient(self._clip( 132 | sigma/self.moving_sigma, 133 | 1./rmax, 134 | rmax 135 | )) 136 | d = K.stop_gradient(self._clip( 137 | (mean - self.moving_mean)/self.moving_sigma, 138 | -dmax, 139 | dmax 140 | )) 141 | 142 | # Actually do the normalization and the rescaling 143 | xnorm = ((x-mean)/sigma)*r + d 144 | y = self.gamma * xnorm + self.beta 145 | 146 | # Add the moving average updates 147 | self.add_update([ 148 | K.moving_average_update(self.moving_mean, mean, self.momentum), 149 | K.moving_average_update(self.moving_sigma, sigma, self.momentum) 150 | ], x) 151 | 152 | # Add the r, d updates 153 | rmax_prog = K.minimum(1., self.steps/self.rmax_dur) 154 | dmax_prog = K.minimum(1., self.steps/self.dmax_dur) 155 | self.add_update([ 156 | K.update_add(self.steps, 1), 157 | K.update( 158 | self.rmax, 159 | self.rmax_0 + rmax_prog*(self.rmax_inf-self.rmax_0) 160 | ), 161 | K.update( 162 | self.dmax, 163 | self.dmax_0 + dmax_prog*(self.dmax_inf-self.dmax_0) 164 | ) 165 | ]) 166 | 167 | # Fix the output's uses learning phase 168 | y._uses_learning_phase = rmax._uses_learning_phase 169 | 170 | return y 171 | 172 | 173 | class LayerNormalization(_BaseNormalization): 174 | """LayerNormalization is a determenistic normalization layer to replace 175 | BN's stochasticity. 
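    For every sample the mean and variance are computed over all axes except
    the batch axis and any axis listed in `axes`, and the output is

        y = gamma * (x - mean) / sqrt(var + epsilon) + beta

    with gamma shaped to broadcast over `axes` and beta over `bias_axes`.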
176 | 177 | # Arguments 178 | axes: list of axes that won't be aggregated over 179 | """ 180 | def __init__(self, axes=None, bias_axes=[-1], epsilon=1e-3, **kwargs): 181 | super(LayerNormalization, self).__init__(**kwargs) 182 | self.axes = axes 183 | self.bias_axes = bias_axes 184 | self.epsilon = epsilon 185 | 186 | def build(self, input_shape): 187 | # Get the number of dimensions and the axes that won't be aggregated 188 | # over 189 | ndims = len(input_shape) 190 | axes = self.axes or [] 191 | bias_axes = self.bias_axes or [-1] 192 | 193 | # Figure out the shape of the statistics 194 | gamma_shape = [1]*ndims 195 | beta_shape = [1]*ndims 196 | for ax in axes: 197 | gamma_shape[ax] = input_shape[ax] 198 | for ax in bias_axes: 199 | beta_shape[ax] = input_shape[ax] 200 | 201 | # Figure out the axes we will aggregate over accounting for negative 202 | # axes 203 | self.reduction_axes = [ 204 | ax for ax in range(ndims) 205 | if ax > 0 and (ax+ndims) % ndims not in axes 206 | ] 207 | 208 | # Create trainable variables 209 | self.gamma = self.add_weight( 210 | shape=gamma_shape, 211 | name="gamma", 212 | initializer=initializers.get("ones") 213 | ) 214 | self.beta = self.add_weight( 215 | shape=beta_shape, 216 | name="beta", 217 | initializer=initializers.get("zeros") 218 | ) 219 | 220 | self.built = True 221 | 222 | def call(self, inputs): 223 | x = inputs 224 | assert not isinstance(x, list) 225 | 226 | # Compute the per sample statistics 227 | mean, var = self._moments(x, self.reduction_axes) 228 | std = K.sqrt(var + self.epsilon) 229 | 230 | return self.gamma*(x-mean)/std + self.beta 231 | 232 | 233 | class StatsBatchNorm(_BaseNormalization): 234 | """Use the accumulated statistics for batch norm instead of computing them 235 | for each minibatch. 
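    The normalization always uses the stored moving statistics, during
    training as well as at test time:

        y = gamma * (x - moving_mean) / sqrt(moving_variance + epsilon) + beta

    If update_stats is True the moving statistics are still updated from each
    minibatch, but the minibatch statistics are never used for the
    normalization itself.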
236 | 237 | # Arguments 238 | momentum: Momentum for the moving average 239 | epsilon: Added to variance to avoid divide by 0 240 | """ 241 | def __init__(self, momentum=0.99, epsilon=1e-3, update_stats=False, 242 | **kwargs): 243 | super(StatsBatchNorm, self).__init__(**kwargs) 244 | 245 | self.momentum = momentum 246 | self.epsilon = epsilon 247 | self.update_stats = update_stats 248 | 249 | def build(self, input_shape): 250 | dim = input_shape[-1] 251 | if dim is None: 252 | raise ValueError(("The normalization axis should have a " 253 | "defined dimension")) 254 | self.dim = dim 255 | 256 | # Trainable part 257 | self.gamma = self.add_weight( 258 | shape=(dim,), 259 | name="gamma", 260 | initializer=initializers.get("ones") 261 | ) 262 | self.beta = self.add_weight( 263 | shape=(dim,), 264 | name="beta", 265 | initializer=initializers.get("zeros") 266 | ) 267 | 268 | # Statistics 269 | self.moving_mean = self.add_weight( 270 | shape=(dim,), 271 | name="moving_mean", 272 | initializer=initializers.get("zeros"), 273 | trainable=False 274 | ) 275 | self.moving_variance = self.add_weight( 276 | shape=(dim,), 277 | name="moving_variance", 278 | initializer=initializers.get("ones"), 279 | trainable=False 280 | ) 281 | 282 | self.built = True 283 | 284 | def call(self, inputs, training=None): 285 | x = inputs 286 | assert not isinstance(x, list) 287 | 288 | # Do the normalization and the rescaling 289 | xnorm = K.batch_normalization( 290 | x, 291 | self.moving_mean, 292 | self.moving_variance, 293 | self.beta, 294 | self.gamma, 295 | epsilon=self.epsilon 296 | ) 297 | 298 | # Compute and update the minibatch statistics 299 | if self.update_stats: 300 | mean, var = self._moments(x, axes=range(len(K.int_shape(x))-1)) 301 | self.add_update([ 302 | K.moving_average_update(self.moving_mean, mean, self.momentum), 303 | K.moving_average_update(self.moving_variance, var, self.momentum) 304 | ], x) 305 | 306 | return xnorm 307 | 308 | 309 | class GroupNormalization(_BaseNormalization): 310 | """GroupNormalization is an improvement to LayerNormalization presented in 311 | https://arxiv.org/abs/1803.08494. 
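    The channels are split into groups and each group is normalized with its
    own per-sample statistics. A rough NumPy sketch of what call() computes
    for a channels_last input of shape (N, H, W, C), where N, H, W, C, eps,
    gamma and beta are illustrative stand-ins:

        g = x.reshape(N, H, W, C // G, G)
        mu = g.mean(axis=(1, 2, 3), keepdims=True)
        var = g.var(axis=(1, 2, 3), keepdims=True)
        y = ((g - mu) / np.sqrt(var + eps)).reshape(N, H, W, C)
        y = gamma * y + beta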
312 | 313 | #Arguments 314 | G: The channels per group in the normalization 315 | epsilon: Added to the variance to avoid division with 0 316 | """ 317 | def __init__(self, G=32, epsilon=1e-3, **kwargs): 318 | super(GroupNormalization, self).__init__(**kwargs) 319 | 320 | self.G = G 321 | self.epsilon = epsilon 322 | 323 | def build(self, input_shape): 324 | # Get the number of dimensions and the channel axis 325 | ndims = len(input_shape) 326 | channel_axis = -1 if K.image_data_format() == "channels_last" else -3 327 | shape = [1]*ndims 328 | shape[channel_axis] = input_shape[channel_axis] 329 | 330 | # Make sure everything is in order 331 | assert None not in shape, "The channel axis must be defined" 332 | assert shape[channel_axis] % self.G == 0, ("The channels must be " 333 | "divisible by the number " 334 | "of groups") 335 | 336 | # Create the trainable weights 337 | self.gamma = self.add_weight( 338 | shape=shape, 339 | name="gamma", 340 | initializer=initializers.get("ones") 341 | ) 342 | self.beta = self.add_weight( 343 | shape=shape, 344 | name="beta", 345 | initializer=initializers.get("zeros") 346 | ) 347 | 348 | # Keep the channel axis for later use 349 | self.channel_axis = channel_axis 350 | self.ndims = ndims 351 | 352 | # We 're done 353 | self.built = True 354 | 355 | def call(self, inputs): 356 | x = inputs 357 | assert not isinstance(x, list) 358 | 359 | # Get the shapes 360 | G = self.G 361 | channel_axis = self.channel_axis 362 | ndims = self.ndims 363 | original_shape = K.shape(x) 364 | shape = [ 365 | original_shape[i] 366 | for i in range(ndims) 367 | ] 368 | if channel_axis == -1: 369 | shape[channel_axis] //= G 370 | shape.append(G) 371 | axes = sorted([ 372 | ndims + channel_axis - i 373 | for i in range(3) 374 | ]) 375 | else: 376 | shape[channel_axis] //= G 377 | shape.insert(channel_axis, G) 378 | axes = sorted([ 379 | ndims + channel_axis + i 380 | for i in range(3) 381 | ]) 382 | 383 | # Do the group norm 384 | x = K.reshape(x, shape) 385 | mean, var = self._moments(x, axes) 386 | std = K.sqrt(var + self.epsilon) 387 | x = (x - mean)/std 388 | x = K.reshape(x, original_shape) 389 | 390 | return self.gamma*x + self.beta 391 | -------------------------------------------------------------------------------- /importance_sampling/layers/scores.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from functools import reduce 7 | 8 | from keras import backend as K 9 | from keras import objectives 10 | from keras.layers import Layer 11 | 12 | 13 | def _per_sample_loss(loss_function, mask, x): 14 | """Compute the per sample loss supporting masking and returning always 15 | tensor of shape (batch_size, 1) 16 | 17 | Arguments 18 | --------- 19 | loss_function: callable 20 | mask: boolean tensor or None 21 | x: list/tuple of inputs to the loss 22 | """ 23 | # Compute the loss 24 | loss = loss_function(*x) 25 | 26 | # Apply masking if needed 27 | if mask is not None: 28 | if not isinstance(mask, list): 29 | mask = [mask] 30 | mask = [K.cast(m, K.floatx()) for m in mask if m is not None] 31 | mask = reduce(lambda a, b: a*b, mask) 32 | mask_dims = len(K.int_shape(mask)) 33 | mask_mean = K.mean(mask, axis=list(range(1, mask_dims)), keepdims=True) 34 | loss *= mask 35 | loss /= mask_mean 36 | 37 | # If the loss has more than 1 dimensions then aggregate the last dimension 38 | dims = len(K.int_shape(loss)) 39 | if 
dims > 1: 40 | loss = K.mean(loss, axis=list(range(1, dims))) 41 | 42 | return K.expand_dims(loss) 43 | 44 | 45 | class LossLayer(Layer): 46 | """LossLayer outputs the loss per sample 47 | 48 | # Arguments 49 | loss: The loss function to use to combine the model output and the 50 | target 51 | """ 52 | def __init__(self, loss, **kwargs): 53 | self.supports_masking = True 54 | self.loss = objectives.get(loss) 55 | 56 | super(LossLayer, self).__init__(**kwargs) 57 | 58 | def compute_mask(self, inputs, input_mask): 59 | return None 60 | 61 | def build(self, input_shape): 62 | pass # Nothing to do 63 | 64 | super(LossLayer, self).build(input_shape) 65 | 66 | def compute_output_shape(self, input_shape): 67 | # We need two inputs X and y 68 | assert len(input_shape) == 2 69 | 70 | # (None, 1) because all losses should be scalar 71 | return (input_shape[0][0], 1) 72 | 73 | def call(self, x, mask=None): 74 | return _per_sample_loss(self.loss, mask, x) 75 | 76 | 77 | class GradientNormLayer(Layer): 78 | """GradientNormLayer aims to output the gradient norm given a list of 79 | parameters (whose gradient to compute) and a loss function to combine the 80 | two inputs. 81 | 82 | # Arguments 83 | parameter_list: A list of Keras variables to compute the gradient 84 | norm for 85 | loss: The loss function to use to combine the model output and the 86 | target into a scalar and then compute the gradient norm 87 | fast: If set to True it means we know that the gradient with respect to 88 | each sample only affects one part of the parameter list so we can 89 | use the batch mode to compute the gradient 90 | """ 91 | def __init__(self, parameter_list, loss, fast=False, **kwargs): 92 | self.supports_masking = True 93 | self.parameter_list = parameter_list 94 | self.loss = objectives.get(loss) 95 | self.fast = fast 96 | 97 | super(GradientNormLayer, self).__init__(**kwargs) 98 | 99 | def compute_mask(self, inputs, input_mask): 100 | return None 101 | 102 | def build(self, input_shape): 103 | pass # Nothing to do 104 | 105 | super(GradientNormLayer, self).build(input_shape) 106 | 107 | def compute_output_shape(self, input_shape): 108 | # We get two inputs 109 | assert len(input_shape) == 2 110 | 111 | return (input_shape[0][0], 1) 112 | 113 | def call(self, x, mask=None): 114 | # x should be an output and a target 115 | assert len(x) == 2 116 | 117 | losses = _per_sample_loss(self.loss, mask, x) 118 | if self.fast: 119 | grads = K.sqrt(sum([ 120 | self._sum_per_sample(K.square(g)) 121 | for g in K.gradients(losses, self.parameter_list) 122 | ])) 123 | else: 124 | nb_samples = K.shape(losses)[0] 125 | grads = K.map_fn( 126 | lambda i: self._grad_norm(losses[i]), 127 | K.arange(0, nb_samples), 128 | dtype=K.floatx() 129 | ) 130 | 131 | return K.reshape(grads, (-1, 1)) 132 | 133 | def _sum_per_sample(self, x): 134 | """Sum across all the dimensions except the batch dim""" 135 | # Instead we might be able to use x.ndims but there have been problems 136 | # with ndims and Keras so I think len(int_shape()) is more reliable 137 | dims = len(K.int_shape(x)) 138 | return K.sum(x, axis=list(range(1, dims))) 139 | 140 | def _grad_norm(self, loss): 141 | grads = K.gradients(loss, self.parameter_list) 142 | return K.sqrt( 143 | sum([ 144 | K.sum(K.square(g)) 145 | for g in grads 146 | ]) 147 | ) 148 | 149 | -------------------------------------------------------------------------------- /importance_sampling/model_wrappers.py: -------------------------------------------------------------------------------- 1 | # 2 | # 
Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from functools import partial 7 | import sys 8 | 9 | from blinker import signal 10 | from keras import backend as K 11 | from keras.layers import Input, Layer, multiply 12 | from keras.models import Model, clone_model 13 | import numpy as np 14 | 15 | from .layers import GradientNormLayer, LossLayer, MetricLayer 16 | from .reweighting import UNBIASED 17 | from .utils.functional import compose 18 | 19 | 20 | def _tolist(x, acceptable_iterables=(list, tuple)): 21 | if not isinstance(x, acceptable_iterables): 22 | x = [x] 23 | return x 24 | 25 | 26 | def _get_scoring_layer(score, y_true, y_pred, loss="categorical_crossentropy", 27 | layer=None, model=None): 28 | """Get a scoring layer that computes the score for each pair of y_true, 29 | y_pred""" 30 | assert score in ["loss", "gnorm", "full_gnorm", "acc"] 31 | 32 | if score == "loss": 33 | return LossLayer(loss)([ 34 | y_true, 35 | y_pred 36 | ]) 37 | elif score == "gnorm": 38 | return GradientNormLayer( 39 | layer.output, 40 | loss, 41 | fast=True 42 | )([ 43 | y_true, 44 | y_pred 45 | ]) 46 | elif score == "full_gnorm": 47 | return GradientNormLayer( 48 | model.trainable_weights, 49 | loss, 50 | fast=False 51 | )([ 52 | y_true, 53 | y_pred 54 | ]) 55 | elif score == "acc": 56 | return LossLayer(categorical_accuracy)([ 57 | y_true, 58 | y_pred 59 | ]) 60 | 61 | 62 | class ModelWrapper(object): 63 | """The goal of the ModelWrapper is to take a NN and add some extra layers 64 | that produce a score, a loss and the sample weights to perform importance 65 | sampling.""" 66 | def _iterate_batches(self, x, y, batch_size): 67 | bs = batch_size 68 | for s in range(0, len(y), bs): 69 | yield [xi[s:s+bs] for xi in _tolist(x)], y[s:s+bs] 70 | 71 | def evaluate(self, x, y, batch_size=128): 72 | result = np.mean( 73 | np.vstack([ 74 | self.evaluate_batch(xi, yi) 75 | for xi, yi in self._iterate_batches(x, y, batch_size) 76 | ]), 77 | axis=0 78 | ) 79 | 80 | signal("is.evaluation").send(result) 81 | return result 82 | 83 | def score(self, x, y, batch_size=128): 84 | bs = batch_size 85 | result = np.hstack([ 86 | self.score_batch(xi, yi).T 87 | for xi, yi in self._iterate_batches(x, y, batch_size) 88 | ]).T 89 | 90 | signal("is.score").send(result) 91 | return result 92 | 93 | def set_lr(self, lr): 94 | """Set the learning rate of the wrapped models. 95 | 96 | We try to set the learning rate on a member variable model and a member 97 | variable small. 
If we do not find a member variable model we raise a 98 | NotImplementedError 99 | """ 100 | try: 101 | K.set_value( 102 | self.optimizer.lr, 103 | lr 104 | ) 105 | except AttributeError: 106 | try: 107 | K.set_value( 108 | self.model.optimizer.lr, 109 | lr 110 | ) 111 | except AttributeError: 112 | raise NotImplementedError() 113 | 114 | try: 115 | K.set_value( 116 | self.small.optimizer.lr, 117 | lr 118 | ) 119 | except AttributeError: 120 | pass 121 | 122 | def evaluate_batch(self, x, y): 123 | raise NotImplementedError() 124 | 125 | def score_batch(self, x, y): 126 | raise NotImplementedError() 127 | 128 | def train_batch(self, x, y, w): 129 | raise NotImplementedError() 130 | 131 | 132 | class ModelWrapperDecorator(ModelWrapper): 133 | def __init__(self, model_wrapper, implemented_attributes=set()): 134 | self.model_wrapper = model_wrapper 135 | self.implemented_attributes = ( 136 | implemented_attributes | set(["model_wrapper"]) 137 | ) 138 | 139 | def __getattribute__(self, name): 140 | _getattr = object.__getattribute__ 141 | implemented_attributes = _getattr(self, "implemented_attributes") 142 | if name in implemented_attributes: 143 | return _getattr(self, name) 144 | else: 145 | model_wrapper = _getattr(self, "model_wrapper") 146 | return getattr(model_wrapper, name) 147 | 148 | 149 | class OracleWrapper(ModelWrapper): 150 | AVG_LOSS = 0 151 | LOSS = 1 152 | WEIGHTED_LOSS = 2 153 | SCORE = 3 154 | METRIC0 = 4 155 | 156 | FUSED_ACTIVATION_WARNING = ("[WARNING]: The last layer has a fused " 157 | "activation i.e. Dense(..., " 158 | "activation=\"sigmoid\").\nIn order for the " 159 | "preactivation to be automatically extracted " 160 | "use a separate activation layer (see " 161 | "examples).\n") 162 | 163 | def __init__(self, model, reweighting, score="loss", layer=None): 164 | self.reweighting = reweighting 165 | self.layer = self._gnorm_layer(model, layer) 166 | 167 | # Augment the model with reweighting, scoring etc 168 | # Save the new model and the training functions in member variables 169 | self._augment_model(model, score, reweighting) 170 | 171 | def _gnorm_layer(self, model, layer): 172 | # If we were given a layer then use it directly 173 | if isinstance(layer, Layer): 174 | return layer 175 | 176 | # If we were given a layer index extract the layer 177 | if isinstance(layer, int): 178 | return model.layers[layer] 179 | 180 | try: 181 | # Get the last or the previous to last layer depending on wether 182 | # the last has trainable weights 183 | skip_one = not bool(model.layers[-1].trainable_weights) 184 | last_layer = -2 if skip_one else -1 185 | 186 | # If the last layer has trainable weights that means that we cannot 187 | # automatically extract the preactivation tensor so we have to warn 188 | # them because they might be missing out or they might not even 189 | # have noticed 190 | if last_layer == -1: 191 | config = model.layers[-1].get_config() 192 | if config.get("activation", "linear") != "linear": 193 | sys.stderr.write(self.FUSED_ACTIVATION_WARNING) 194 | 195 | return model.layers[last_layer] 196 | except: 197 | # In case of an error then probably we are not using the gnorm 198 | # importance 199 | return None 200 | 201 | def _augment_model(self, model, score, reweighting): 202 | # Extract some info from the model 203 | loss = model.loss 204 | optimizer = model.optimizer.__class__(**model.optimizer.get_config()) 205 | output_shape = model.get_output_shape_at(0)[1:] 206 | if isinstance(loss, str) and loss.startswith("sparse"): 207 | output_shape = 
output_shape[:-1] + (1,) 208 | 209 | # Make sure that some stuff look ok 210 | assert not isinstance(loss, list) 211 | 212 | # We need to create two more inputs 213 | # 1. the targets 214 | # 2. the predicted scores 215 | y_true = Input(shape=output_shape) 216 | pred_score = Input(shape=(reweighting.weight_size,)) 217 | 218 | # Create a loss layer and a score layer 219 | loss_tensor = LossLayer(loss)([y_true, model.get_output_at(0)]) 220 | score_tensor = _get_scoring_layer( 221 | score, 222 | y_true, 223 | model.get_output_at(0), 224 | loss, 225 | self.layer, 226 | model 227 | ) 228 | 229 | # Create the sample weights 230 | weights = reweighting.weight_layer()([score_tensor, pred_score]) 231 | 232 | # Create the output 233 | weighted_loss = weighted_loss_model = multiply([loss_tensor, weights]) 234 | for l in model.losses: 235 | weighted_loss += l 236 | weighted_loss_mean = K.mean(weighted_loss) 237 | 238 | # Create the metric layers 239 | metrics = model.metrics or [] 240 | metrics = [ 241 | MetricLayer(metric)([y_true, model.get_output_at(0)]) 242 | for metric in metrics 243 | ] 244 | 245 | # Create a model for plotting and providing access to things such as 246 | # trainable_weights etc. 247 | new_model = Model( 248 | inputs=_tolist(model.get_input_at(0)) + [y_true, pred_score], 249 | outputs=[weighted_loss_model] 250 | ) 251 | 252 | # Build separate on_batch keras functions for scoring and training 253 | updates = optimizer.get_updates( 254 | weighted_loss_mean, 255 | new_model.trainable_weights 256 | ) 257 | metrics_updates = [] 258 | if hasattr(model, "metrics_updates"): 259 | metrics_updates = model.metrics_updates 260 | learning_phase = [] 261 | if weighted_loss_model._uses_learning_phase: 262 | learning_phase.append(K.learning_phase()) 263 | inputs = _tolist(model.get_input_at(0)) + [y_true, pred_score] + \ 264 | learning_phase 265 | outputs = [ 266 | weighted_loss_mean, 267 | loss_tensor, 268 | weighted_loss, 269 | score_tensor 270 | ] + metrics 271 | 272 | train_on_batch = K.function( 273 | inputs=inputs, 274 | outputs=outputs, 275 | updates=updates + model.updates + metrics_updates 276 | ) 277 | evaluate_on_batch = K.function( 278 | inputs=inputs, 279 | outputs=outputs, 280 | updates=model.state_updates + metrics_updates 281 | ) 282 | 283 | self.model = new_model 284 | self.optimizer = optimizer 285 | self.model.optimizer = optimizer 286 | self._train_on_batch = train_on_batch 287 | self._evaluate_on_batch = evaluate_on_batch 288 | 289 | def evaluate_batch(self, x, y): 290 | if len(y.shape) == 1: 291 | y = np.expand_dims(y, axis=1) 292 | dummy_weights = np.ones((y.shape[0], self.reweighting.weight_size)) 293 | inputs = _tolist(x) + [y, dummy_weights] + [0] 294 | outputs = self._evaluate_on_batch(inputs) 295 | 296 | signal("is.evaluate_batch").send(outputs) 297 | 298 | return np.hstack([outputs[self.LOSS]] + outputs[self.METRIC0:]) 299 | 300 | def score_batch(self, x, y): 301 | if len(y.shape) == 1: 302 | y = np.expand_dims(y, axis=1) 303 | dummy_weights = np.ones((y.shape[0], self.reweighting.weight_size)) 304 | inputs = _tolist(x) + [y, dummy_weights] + [0] 305 | outputs = self._evaluate_on_batch(inputs) 306 | 307 | return outputs[self.SCORE].ravel() 308 | 309 | def train_batch(self, x, y, w): 310 | if len(y.shape) == 1: 311 | y = np.expand_dims(y, axis=1) 312 | 313 | # train on a single batch 314 | outputs = self._train_on_batch(_tolist(x) + [y, w, 1]) 315 | 316 | # Add the outputs in a tuple to send to whoever is listening 317 | result = ( 318 | outputs[self.WEIGHTED_LOSS], 
319 | outputs[self.METRIC0:], 320 | outputs[self.SCORE] 321 | ) 322 | signal("is.training").send(result) 323 | 324 | return result 325 | 326 | 327 | class SVRGWrapper(ModelWrapper): 328 | """Train using SVRG.""" 329 | def __init__(self, model): 330 | self._augment(model) 331 | 332 | def _augment(self, model): 333 | # TODO: There is a lot of overlap with the OracleWrapper, merge some 334 | # functionality into a separate function or a parent class 335 | 336 | # Extract info from the model 337 | loss_function = model.loss 338 | output_shape = model.get_output_shape_at(0)[1:] 339 | 340 | # Create two identical models one with the current weights and one with 341 | # the snapshot of the weights 342 | self.model = model 343 | self._snapshot = clone_model(model) 344 | 345 | # Create the target variable and compute the losses and the metrics 346 | inputs = [ 347 | Input(shape=K.int_shape(x)[1:]) 348 | for x in _tolist(model.get_input_at(0)) 349 | ] 350 | model_output = self.model(inputs) 351 | snapshot_output = self._snapshot(inputs) 352 | y_true = Input(shape=output_shape) 353 | loss = LossLayer(loss_function)([y_true, model_output]) 354 | loss_snapshot = LossLayer(loss_function)([y_true, snapshot_output]) 355 | metrics = self.model.metrics or [] 356 | metrics = [ 357 | MetricLayer(metric)([y_true, model_output]) 358 | for metric in metrics 359 | ] 360 | 361 | # Make a set of variables that will be holding the batch gradient of 362 | # the snapshot 363 | self._batch_grad = [ 364 | K.zeros(K.int_shape(p)) 365 | for p in self.model.trainable_weights 366 | ] 367 | 368 | # Create an optimizer that computes the variance reduced gradients and 369 | # get the updates 370 | loss_mean = K.mean(loss) 371 | loss_snapshot_mean = K.mean(loss_snapshot) 372 | optimizer, updates = self._get_updates( 373 | loss_mean, 374 | loss_snapshot_mean, 375 | self._batch_grad 376 | ) 377 | 378 | # Create the training function and gradient computation function 379 | metrics_updates = [] 380 | if hasattr(self.model, "metrics_updates"): 381 | metrics_updates = self.model.metrics_updates 382 | learning_phase = [] 383 | if loss._uses_learning_phase: 384 | learning_phase.append(K.learning_phase()) 385 | inputs = inputs + [y_true] + learning_phase 386 | outputs = [loss_mean, loss] + metrics 387 | 388 | train_on_batch = K.function( 389 | inputs=inputs, 390 | outputs=outputs, 391 | updates=updates + self.model.updates + metrics_updates 392 | ) 393 | evaluate_on_batch = K.function( 394 | inputs=inputs, 395 | outputs=outputs, 396 | updates=self.model.state_updates + metrics_updates 397 | ) 398 | get_grad = K.function( 399 | inputs=inputs, 400 | outputs=K.gradients(loss_mean, self.model.trainable_weights), 401 | updates=self.model.updates 402 | ) 403 | 404 | self.optimizer = optimizer 405 | self._train_on_batch = train_on_batch 406 | self._evaluate_on_batch = evaluate_on_batch 407 | self._get_grad = get_grad 408 | 409 | def _get_updates(self, loss, loss_snapshot, batch_grad): 410 | model = self.model 411 | snapshot = self._snapshot 412 | class Optimizer(self.model.optimizer.__class__): 413 | def get_gradients(self, *args): 414 | grad = K.gradients(loss, model.trainable_weights) 415 | grad_snapshot = K.gradients( 416 | loss_snapshot, 417 | snapshot.trainable_weights 418 | ) 419 | 420 | return [ 421 | g - gs + bg 422 | for g, gs, bg in zip(grad, grad_snapshot, batch_grad) 423 | ] 424 | 425 | optimizer = Optimizer(**self.model.optimizer.get_config()) 426 | return optimizer, \ 427 | optimizer.get_updates(loss, 
self.model.trainable_weights) 428 | 429 | def evaluate_batch(self, x, y): 430 | outputs = self._evaluate_on_batch(_tolist(x) + [y, 0]) 431 | signal("is.evaluate_batch").send(outputs) 432 | 433 | return np.hstack(outputs[1:]) 434 | 435 | def score_batch(self, x, y): 436 | raise NotImplementedError() 437 | 438 | def train_batch(self, x, y, w): 439 | outputs = self._train_on_batch(_tolist(x) + [y, 1]) 440 | 441 | result = ( 442 | outputs[0], # mean loss 443 | outputs[2:], # metrics 444 | outputs[1] # loss per sample 445 | ) 446 | signal("is.training").send(result) 447 | 448 | return result 449 | 450 | def update_grad(self, sample_generator): 451 | sample_generator = iter(sample_generator) 452 | x, y = next(sample_generator) 453 | N = len(y) 454 | gradient_sum = self._get_grad(_tolist(x) + [y, 1]) 455 | for g_sum in gradient_sum: 456 | g_sum *= N 457 | for x, y in sample_generator: 458 | grads = self._get_grad(_tolist(x) + [y, 1]) 459 | n = len(y) 460 | for g_sum, g in zip(gradient_sum, grads): 461 | g_sum += g*n 462 | N += len(y) 463 | for g_sum in gradient_sum: 464 | g_sum /= N 465 | 466 | K.batch_set_value(zip(self._batch_grad, gradient_sum)) 467 | self._snapshot.set_weights(self.model.get_weights()) 468 | 469 | 470 | class KatyushaWrapper(SVRGWrapper): 471 | """Implement Katyusha training on top of plain SVRG.""" 472 | def __init__(self, model, t1=0.5, t2=0.5): 473 | self.t1 = K.variable(t1, name="tau1") 474 | self.t2 = K.variable(t2, name="tau2") 475 | 476 | super(KatyushaWrapper, self).__init__(model) 477 | 478 | def _get_updates(self, loss, loss_snapshot, batch_grad): 479 | optimizer = self.model.optimizer 480 | t1, t2 = self.t1, self.t2 481 | lr = optimizer.lr 482 | 483 | # create copies and local copies of the parameters 484 | shapes = [K.int_shape(p) for p in self.model.trainable_weights] 485 | x_tilde = [p for p in self._snapshot.trainable_weights] 486 | z = [K.variable(p) for p in self.model.trainable_weights] 487 | y = [K.variable(p) for p in self.model.trainable_weights] 488 | 489 | # Get the gradients 490 | grad = K.gradients(loss, self.model.trainable_weights) 491 | grad_snapshot = K.gradients( 492 | loss_snapshot, 493 | self._snapshot.trainable_weights 494 | ) 495 | 496 | # Collect the updates 497 | p_plus = [ 498 | t1*zi + t2*x_tildei + (1-t1-t2)*yi 499 | for zi, x_tildei, yi in 500 | zip(z, x_tilde, y) 501 | ] 502 | vr_grad = [ 503 | gi + bg - gsi 504 | for gi, bg, gsi in zip(grad, grad_snapshot, batch_grad) 505 | ] 506 | updates = [ 507 | K.update(yi, xi - lr * gi) 508 | for yi, xi, gi in zip(y, p_plus, vr_grad) 509 | ] + [ 510 | K.update(zi, zi - lr * gi / t1) 511 | for zi, xi, gi in zip(z, p_plus, vr_grad) 512 | ] + [ 513 | K.update(p, xi) 514 | for p, xi in zip(self.model.trainable_weights, p_plus) 515 | ] 516 | 517 | return optimizer, updates 518 | -------------------------------------------------------------------------------- /importance_sampling/models.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from keras import backend as K 7 | from keras.layers import \ 8 | Activation, \ 9 | AveragePooling2D, \ 10 | BatchNormalization, \ 11 | Convolution2D, \ 12 | Dense, \ 13 | Dropout, \ 14 | ELU, \ 15 | Embedding, \ 16 | Flatten, \ 17 | GlobalAveragePooling2D, \ 18 | Input, \ 19 | LSTM, CuDNNLSTM, \ 20 | MaxPooling2D, \ 21 | Masking, \ 22 | TimeDistributed, \ 23 | add, \ 24 | concatenate 25 | from keras.models 
import Model, Sequential 26 | from keras.optimizers import SGD 27 | from keras.regularizers import l2 28 | 29 | from .layers import \ 30 | BatchRenormalization, \ 31 | LayerNormalization, \ 32 | TripletLossLayer 33 | from .pretrained import ResNet50, DenseNet121 34 | from .utils.functional import partial 35 | 36 | 37 | def build_small_nn(input_shape, output_size): 38 | model = Sequential([ 39 | Dense(40, activation="tanh", input_shape=input_shape), 40 | Dense(40, activation="tanh"), 41 | Dense(output_size), 42 | Activation("softmax") 43 | ]) 44 | 45 | model.compile( 46 | loss="categorical_crossentropy", 47 | optimizer="adam", 48 | metrics=["accuracy"] 49 | ) 50 | 51 | return model 52 | 53 | 54 | def build_svrg_nn(input_shape, output_size): 55 | model = Sequential([ 56 | Flatten(input_shape=input_shape), 57 | Dense(100, activation="tanh"), 58 | Dense(10), 59 | Activation("softmax") 60 | ]) 61 | 62 | model.compile( 63 | loss="categorical_crossentropy", 64 | optimizer="adam", 65 | metrics=["accuracy"] 66 | ) 67 | 68 | return model 69 | 70 | 71 | def build_cnn(input_shape, output_size): 72 | kwargs = { 73 | "kernel_size": 3, 74 | "activation": "relu", 75 | "padding": "same" 76 | } 77 | model = Sequential([ 78 | # conv1_* 79 | Convolution2D(64, input_shape=input_shape, **kwargs), 80 | BatchRenormalization(), 81 | Convolution2D(64, **kwargs), 82 | BatchRenormalization(), 83 | MaxPooling2D(pool_size=(2, 2)), 84 | Dropout(0.25), 85 | 86 | # conv2_* 87 | Convolution2D(128, **kwargs), 88 | BatchRenormalization(), 89 | Convolution2D(128, **kwargs), 90 | BatchRenormalization(), 91 | MaxPooling2D(pool_size=(2, 2)), 92 | Dropout(0.25), 93 | 94 | # conv3_* 95 | Convolution2D(256, **kwargs), 96 | BatchRenormalization(), 97 | Convolution2D(256, **kwargs), 98 | BatchRenormalization(), 99 | MaxPooling2D(pool_size=(2, 2)), 100 | Dropout(0.25), 101 | 102 | # Fully connected 103 | Flatten(), 104 | Dense(1024), 105 | Activation("relu"), 106 | Dropout(0.5), 107 | Dense(512), 108 | Activation("relu"), 109 | Dropout(0.5), 110 | Dense(output_size), 111 | Activation("softmax") 112 | ]) 113 | 114 | model.compile( 115 | loss="categorical_crossentropy", 116 | optimizer="adam", 117 | metrics=["accuracy"] 118 | ) 119 | 120 | return model 121 | 122 | 123 | def build_small_cnn(input_shape, output_size): 124 | model = Sequential([ 125 | # conv1_* 126 | Convolution2D(32, kernel_size=3, padding="same", 127 | input_shape=input_shape), 128 | Activation("relu"), 129 | Convolution2D(32, kernel_size=3, padding="same"), 130 | Activation("relu"), 131 | MaxPooling2D(pool_size=(2, 2)), 132 | 133 | # conv2_* 134 | Convolution2D(64, kernel_size=3, padding="same"), 135 | Activation("relu"), 136 | Convolution2D(64, kernel_size=3, padding="same"), 137 | Activation("relu"), 138 | MaxPooling2D(pool_size=(2, 2)), 139 | 140 | # Fully connected 141 | Flatten(), 142 | Dense(512), 143 | Activation("relu"), 144 | Dense(512), 145 | Activation("relu"), 146 | Dense(output_size), 147 | Activation("softmax") 148 | ]) 149 | 150 | model.compile( 151 | loss="categorical_crossentropy", 152 | optimizer="adam", 153 | metrics=["accuracy"] 154 | ) 155 | 156 | return model 157 | 158 | 159 | def build_lr(input_shape, output_size): 160 | model = Sequential([ 161 | Flatten(input_shape=input_shape), 162 | Dense(output_size), 163 | Activation("softmax") 164 | ]) 165 | 166 | model.compile( 167 | loss="categorical_crossentropy", 168 | optimizer="adam", 169 | metrics=["accuracy"] 170 | ) 171 | 172 | return model 173 | 174 | 175 | def build_all_conv_nn(input_shape, 
output_size): 176 | """Build a small variation of the best performing network from 177 | 'Springenberg, Jost Tobias, et al. "Striving for simplicity: The all 178 | convolutional net." arXiv preprint arXiv:1412.6806 (2014)' which should 179 | achieve approximately 91% in CIFAR-10. 180 | """ 181 | kwargs = { 182 | "activation": "relu", 183 | "border_mode": "same" 184 | } 185 | model = Sequential([ 186 | # conv1 187 | Convolution2D(96, 3, 3, input_shape=input_shape, **kwargs), 188 | BatchRenormalization(), 189 | Convolution2D(96, 3, 3, **kwargs), 190 | BatchRenormalization(), 191 | Convolution2D(96, 3, 3, subsample=(2, 2), **kwargs), 192 | BatchRenormalization(), 193 | Dropout(0.25), 194 | 195 | # conv2 196 | Convolution2D(192, 3, 3, **kwargs), 197 | BatchRenormalization(), 198 | Convolution2D(192, 3, 3, **kwargs), 199 | BatchRenormalization(), 200 | Convolution2D(192, 3, 3, subsample=(2, 2), **kwargs), 201 | BatchRenormalization(), 202 | Dropout(0.25), 203 | 204 | # conv3 205 | Convolution2D(192, 1, 1, **kwargs), 206 | BatchRenormalization(), 207 | Dropout(0.25), 208 | Convolution2D(output_size, 1, 1, **kwargs), 209 | GlobalAveragePooling2D(), 210 | Activation("softmax") 211 | ]) 212 | 213 | model.compile( 214 | loss="categorical_crossentropy", 215 | optimizer=SGD(momentum=0.9), 216 | metrics=["accuracy"] 217 | ) 218 | 219 | return model 220 | 221 | 222 | def build_elu_cnn(input_shape, output_size): 223 | """Build a variation of the CNN implemented in the ELU paper. 224 | 225 | https://arxiv.org/abs/1511.07289 226 | """ 227 | def layers(n, channels, kernel): 228 | return sum( 229 | ( 230 | [ 231 | Convolution2D( 232 | channels, 233 | kernel_size=kernel, 234 | padding="same" 235 | ), 236 | ELU() 237 | ] 238 | for i in range(n) 239 | ), 240 | [] 241 | ) 242 | 243 | model = Sequential( 244 | [ 245 | Convolution2D(384, kernel_size=3, padding="same", 246 | input_shape=input_shape) 247 | ] + 248 | layers(1, 384, 3) + 249 | [MaxPooling2D(pool_size=(2, 2))] + 250 | layers(1, 384, 1) + layers(1, 384, 2) + layers(2, 640, 2) + 251 | [MaxPooling2D(pool_size=(2, 2))] + 252 | layers(1, 640, 1) + layers(3, 768, 2) + 253 | [MaxPooling2D(pool_size=(2, 2))] + 254 | layers(1, 768, 1) + layers(2, 896, 2) + 255 | [MaxPooling2D(pool_size=(2, 2))] + 256 | layers(1, 896, 3) + layers(2, 1024, 2) + 257 | [ 258 | MaxPooling2D(pool_size=(2, 2)), 259 | Convolution2D(output_size, kernel_size=1, padding="same"), 260 | GlobalAveragePooling2D(), 261 | Activation("softmax") 262 | ] 263 | ) 264 | 265 | model.compile( 266 | optimizer=SGD(momentum=0.9), 267 | loss="categorical_crossentropy", 268 | metrics=["accuracy"] 269 | ) 270 | 271 | return model 272 | 273 | 274 | def build_lstm_lm(input_shape, output_size): 275 | # LM datasets will report the vocab_size as output_size 276 | vocab_size = output_size 277 | 278 | model = Sequential([ 279 | Embedding(vocab_size + 1, 64, mask_zero=True, 280 | input_length=input_shape[0]), 281 | LSTM(256, unroll=True, return_sequences=True), 282 | LSTM(256, unroll=True), 283 | Dense(output_size), 284 | Activation("softmax") 285 | ]) 286 | 287 | model.compile( 288 | optimizer="adam", 289 | loss="sparse_categorical_crossentropy", 290 | metrics=["accuracy"] 291 | ) 292 | 293 | return model 294 | 295 | 296 | def build_lstm_lm2(input_shape, output_size): 297 | # LM datasets will report the vocab_size as output_size 298 | vocab_size = output_size 299 | 300 | model = Sequential([ 301 | Embedding(vocab_size + 1, 128, mask_zero=True, 302 | input_length=input_shape[0]), 303 | LSTM(384, unroll=True, 
return_sequences=True), 304 | Dropout(0.5), 305 | LSTM(384, unroll=True), 306 | Dropout(0.5), 307 | Dense(output_size), 308 | Activation("softmax") 309 | ]) 310 | 311 | model.compile( 312 | optimizer="adam", 313 | loss="sparse_categorical_crossentropy", 314 | metrics=["accuracy"] 315 | ) 316 | 317 | return model 318 | 319 | 320 | def build_lstm_lm3(input_shape, output_size): 321 | # LM datasets will report the vocab_size as output_size 322 | vocab_size = output_size 323 | 324 | model = Sequential([ 325 | Embedding(vocab_size + 1, 128, mask_zero=True, 326 | input_length=input_shape[0]), 327 | LSTM(650, unroll=True, return_sequences=True), 328 | Dropout(0.5), 329 | LSTM(650, unroll=True), 330 | Dropout(0.5), 331 | Dense(output_size), 332 | Activation("softmax") 333 | ]) 334 | 335 | model.compile( 336 | optimizer="adam", 337 | loss="sparse_categorical_crossentropy", 338 | metrics=["accuracy"] 339 | ) 340 | 341 | return model 342 | 343 | 344 | def build_lstm_timit(input_shape, output_size): 345 | """Build a simple LSTM to classify the phonemes in the TIMIT dataset""" 346 | model = Sequential([ 347 | LSTM(256, unroll=True, input_shape=input_shape), 348 | Dense(output_size), 349 | Activation("softmax") 350 | ]) 351 | 352 | model.compile( 353 | optimizer="adam", 354 | loss="sparse_categorical_crossentropy", 355 | metrics=["accuracy"] 356 | ) 357 | 358 | return model 359 | 360 | 361 | def build_lstm_mnist(input_shape, output_size): 362 | """Build a small LSTM to recognize MNIST digits as permuted sequences""" 363 | model = Sequential([ 364 | CuDNNLSTM(128, input_shape=input_shape), 365 | Dense(output_size), 366 | Activation("softmax") 367 | ]) 368 | 369 | model.compile( 370 | optimizer="adam", 371 | loss="categorical_crossentropy", 372 | metrics=["accuracy"] 373 | ) 374 | 375 | return model 376 | 377 | 378 | def build_small_cnn_squared(input_shape, output_size): 379 | def squared_categorical_crossent(y_true, y_pred): 380 | return K.square(K.categorical_crossentropy(y_pred, y_true)) 381 | 382 | model = build_small_cnn(input_shape, output_size) 383 | model.compile( 384 | optimizer=model.optimizer, 385 | loss=squared_categorical_crossent, 386 | metrics=model.metrics 387 | ) 388 | 389 | return model 390 | 391 | 392 | def wide_resnet(L, k, drop_rate=0.0): 393 | """Implement the WRN-L-k from 'Wide Residual Networks' BMVC 2016""" 394 | def wide_resnet_impl(input_shape, output_size): 395 | def conv(channels, strides, 396 | params=dict(padding="same", use_bias=False, 397 | kernel_regularizer=l2(5e-4))): 398 | def inner(x): 399 | x = LayerNormalization()(x) 400 | x = Activation("relu")(x) 401 | x = Convolution2D(channels, 3, strides=strides, **params)(x) 402 | x = Dropout(drop_rate)(x) if drop_rate > 0 else x 403 | x = LayerNormalization()(x) 404 | x = Activation("relu")(x) 405 | x = Convolution2D(channels, 3, **params)(x) 406 | return x 407 | return inner 408 | 409 | def resize(x, shape): 410 | if K.int_shape(x) == shape: 411 | return x 412 | channels = shape[3 if K.image_data_format() == "channels_last" else 1] 413 | strides = K.int_shape(x)[2] // shape[2] 414 | return Convolution2D( 415 | channels, 1, padding="same", use_bias=False, strides=strides 416 | )(x) 417 | 418 | def block(channels, k, n, strides): 419 | def inner(x): 420 | for i in range(n): 421 | x2 = conv(channels*k, strides if i == 0 else 1)(x) 422 | x = add([resize(x, K.int_shape(x2)), x2]) 423 | return x 424 | return inner 425 | 426 | # According to the paper L = 6*n+4 427 | n = int((L-4)/6) 428 | 429 | group0 = Convolution2D(16, 3, 
padding="same", use_bias=False, 430 | kernel_regularizer=l2(5e-4)) 431 | group1 = block(16, k, n, 1) 432 | group2 = block(32, k, n, 2) 433 | group3 = block(64, k, n, 2) 434 | 435 | x_in = x = Input(shape=input_shape) 436 | x = group0(x) 437 | x = group1(x) 438 | x = group2(x) 439 | x = group3(x) 440 | 441 | x = LayerNormalization()(x) 442 | x = Activation("relu")(x) 443 | x = GlobalAveragePooling2D()(x) 444 | x = Dense(output_size, kernel_regularizer=l2(5e-4))(x) 445 | y = Activation("softmax")(x) 446 | 447 | model = Model(inputs=x_in, outputs=y) 448 | model.compile( 449 | loss="categorical_crossentropy", 450 | optimizer="adam", 451 | metrics=["accuracy"] 452 | ) 453 | 454 | return model 455 | return wide_resnet_impl 456 | 457 | 458 | def pretrained(Net, weights="imagenet"): 459 | def pretrained_impl(input_shape, output_size): 460 | net = Net(weights=weights, input_shape=input_shape, 461 | output_size=output_size) 462 | net.compile( 463 | loss="categorical_crossentropy", 464 | optimizer="adam", 465 | metrics=["accuracy"] 466 | ) 467 | 468 | return net 469 | 470 | return pretrained_impl 471 | 472 | 473 | def face(modelf, embedding): 474 | def face_impl(input_shape, output_size): 475 | x = Input(shape=input_shape) 476 | e = modelf(input_shape, embedding)(x) 477 | y = Dense(output_size)(e) 478 | y = Activation("softmax")(y) 479 | 480 | model = Model(x, y) 481 | model.compile( 482 | "adam", 483 | "sparse_categorical_crossentropy", 484 | metrics=["accuracy"] 485 | ) 486 | 487 | return model 488 | 489 | return face_impl 490 | 491 | 492 | def triplet(modelf): 493 | def triplet_loss(y_true, y_pred): 494 | return K.relu(y_true - y_pred) 495 | 496 | def neg_minus_pos(y_true, y_pred): 497 | return y_pred 498 | 499 | def triplet_impl(input_shape, output_size): 500 | inputs = [Input(shape=shape) for shape in input_shape] 501 | net = modelf(input_shape[0], output_size) 502 | output = concatenate(list(map(net, inputs))) 503 | 504 | y = TripletLossLayer()(output) 505 | 506 | model = Model(inputs, y, name="triplet({})".format(net.name)) 507 | model.compile( 508 | optimizer="adam", 509 | loss=triplet_loss, 510 | metrics=[neg_minus_pos] 511 | ) 512 | 513 | return model 514 | 515 | return triplet_impl 516 | 517 | 518 | def get(name): 519 | models = { 520 | "small_nn": build_small_nn, 521 | "svrg_nn": build_svrg_nn, 522 | "small_cnn": build_small_cnn, 523 | "small_cnn_sq": build_small_cnn_squared, 524 | "cnn": build_cnn, 525 | "all_conv": build_all_conv_nn, 526 | "elu_cnn": build_elu_cnn, 527 | "lstm_lm": build_lstm_lm, 528 | "lstm_lm2": build_lstm_lm2, 529 | "lstm_lm3": build_lstm_lm3, 530 | "lstm_timit": build_lstm_timit, 531 | "lstm_mnist": build_lstm_mnist, 532 | "wide_resnet_16_4": wide_resnet(16, 4), 533 | "wide_resnet_16_4_dropout": wide_resnet(16, 4, 0.3), 534 | "wide_resnet_28_2": wide_resnet(28, 2), 535 | "wide_resnet_28_10": wide_resnet(28, 10), 536 | "wide_resnet_28_10_dropout": wide_resnet(28, 10, 0.3), 537 | "pretrained_resnet50": pretrained(partial(ResNet50, softmax=True)), 538 | "triplet_pre_resnet50": triplet(pretrained(ResNet50)), 539 | "pretrained_densenet121": pretrained(DenseNet121), 540 | "triplet_pre_densenet121": triplet(pretrained(DenseNet121)), 541 | "face_pre_resnet50": face(pretrained(ResNet50), 128) 542 | } 543 | return models[name] 544 | -------------------------------------------------------------------------------- /importance_sampling/pretrained.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research 
Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Replace the models provided by the Keras applications module""" 7 | 8 | from keras.layers import \ 9 | Activation, \ 10 | AveragePooling2D, \ 11 | Conv2D, \ 12 | Dense, \ 13 | Flatten, \ 14 | Input, \ 15 | MaxPooling2D, \ 16 | add 17 | from keras.models import Model 18 | from keras.utils.data_utils import get_file 19 | 20 | from .layers import StatsBatchNorm 21 | 22 | 23 | RESNET50_WEIGHTS_PATH = ("https://github.com/fchollet/deep-learning-models/" 24 | "releases/download/v0.2/" 25 | "resnet50_weights_tf_dim_ordering_tf_kernels.h5") 26 | 27 | 28 | def ResNet50(weights="imagenet", input_shape=(224, 224, 3), output_size=1000, 29 | softmax=False, norm_layer=StatsBatchNorm): 30 | def block(x_in, kernel, filters, strides, stage, block, shortcut=False): 31 | conv_name = "res" + str(stage) + block + "_branch" 32 | bn_name = "bn" + str(stage) + block + "_branch" 33 | 34 | x = Conv2D(filters[0], 1, strides=strides, name=conv_name+"2a")(x_in) 35 | x = norm_layer(name=bn_name+"2a")(x) 36 | x = Activation("relu")(x) 37 | x = Conv2D(filters[1], kernel, padding="same", name=conv_name+"2b")(x) 38 | x = norm_layer(name=bn_name+"2b")(x) 39 | x = Activation("relu")(x) 40 | x = Conv2D(filters[2], 1, name=conv_name+"2c")(x) 41 | x = norm_layer(name=bn_name+"2c")(x) 42 | 43 | if shortcut: 44 | s = Conv2D(filters[2], 1, strides=strides, name=conv_name+"1")(x_in) 45 | s = norm_layer(name=bn_name+"1")(s) 46 | else: 47 | s = x_in 48 | 49 | return Activation("relu")(add([x, s])) 50 | 51 | x_in = Input(shape=input_shape) 52 | x = Conv2D(64, 7, strides=2, padding="same", name="conv1")(x_in) 53 | x = norm_layer(name="bn_conv1")(x) 54 | x = Activation("relu")(x) 55 | x = MaxPooling2D((3, 3), strides=(2, 2))(x) 56 | 57 | x = block(x, 3, [64, 64, 256], 1, 2, "a", shortcut=True) 58 | x = block(x, 3, [64, 64, 256], 1, 2, "b") 59 | x = block(x, 3, [64, 64, 256], 1, 2, "c") 60 | 61 | x = block(x, 3, [128, 128, 512], 2, 3, "a", shortcut=True) 62 | x = block(x, 3, [128, 128, 512], 1, 3, "b") 63 | x = block(x, 3, [128, 128, 512], 1, 3, "c") 64 | x = block(x, 3, [128, 128, 512], 1, 3, "d") 65 | 66 | x = block(x, 3, [256, 256, 1024], 2, 4, "a", shortcut=True) 67 | x = block(x, 3, [256, 256, 1024], 1, 4, "b") 68 | x = block(x, 3, [256, 256, 1024], 1, 4, "c") 69 | x = block(x, 3, [256, 256, 1024], 1, 4, "d") 70 | x = block(x, 3, [256, 256, 1024], 1, 4, "e") 71 | x = block(x, 3, [256, 256, 1024], 1, 4, "f") 72 | 73 | x = block(x, 3, [512, 512, 2048], 2, 5, "a", shortcut=True) 74 | x = block(x, 3, [512, 512, 2048], 1, 5, "b") 75 | x = block(x, 3, [512, 512, 2048], 1, 5, "c") 76 | 77 | x = AveragePooling2D((7, 7), name="avg_pool")(x) 78 | x = Flatten()(x) 79 | x = Dense(output_size, name="fc"+str(output_size))(x) 80 | if softmax: 81 | x = Activation("softmax")(x) 82 | 83 | model = Model(x_in, x, name="resnet50") 84 | 85 | if weights == "imagenet": 86 | weights_path = get_file( 87 | "resnet50_weights_tf_dim_ordering_tf_kernels.h5", 88 | RESNET50_WEIGHTS_PATH, 89 | cache_subdir="models", 90 | md5_hash="a7b3fe01876f51b976af0dea6bc144eb" 91 | ) 92 | model.load_weights(weights_path, by_name=True) 93 | 94 | return model 95 | 96 | 97 | def DenseNet121(*args, **kwargs): 98 | raise NotImplementedError() 99 | -------------------------------------------------------------------------------- /importance_sampling/reweighting.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research 
Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from keras import backend as K 7 | from keras.layers import Layer 8 | import numpy as np 9 | 10 | 11 | class ReweightingPolicy(object): 12 | """ReweightingPolicy defines how we weigh the samples given their 13 | importance. 14 | 15 | Each policy should provide 16 | 1. A layer implementation for use with Keras models 17 | 2. A python implementation for use with the samplers 18 | """ 19 | def weight_layer(self): 20 | """Return a layer that accepts the scores and the return value of the 21 | sample_weights() method and produces new sample weights.""" 22 | raise NotImplementedError() 23 | 24 | def sample_weights(self, idxs, scores): 25 | """Given the scores and the chosen indices return whatever is needed by 26 | the weight_layer to produce the sample weights.""" 27 | raise NotImplementedError() 28 | 29 | @property 30 | def weight_size(self): 31 | """Return how many numbers per sample make up the sample weights""" 32 | raise NotImplementedError() 33 | 34 | 35 | class AdjustedBiasedReweightingPolicy(ReweightingPolicy): 36 | """AdjustedBiasedReweightingPolicy adjusts the biased sample weights with 37 | the importance that has just been computed in the forward-backward pass. 38 | 39 | See AdjustedBiasedReweighting for details. 40 | """ 41 | def __init__(self, k=1.0): 42 | self.k = k 43 | 44 | def weight_layer(self): 45 | return AdjustedBiasedReweighting(self.k) 46 | 47 | def sample_weights(self, idxs, scores): 48 | N = len(scores) 49 | S1 = scores[np.setdiff1d(np.arange(N), idxs)].sum() 50 | 51 | return np.tile([float(N), float(S1)], (len(idxs), 1)) 52 | 53 | @property 54 | def weight_size(self): 55 | return 2 56 | 57 | 58 | class BiasedReweightingPolicy(ReweightingPolicy): 59 | """BiasedReweightingPolicy computes the sample weights before the 60 | forward-backward pass based on the sampling probabilities. 
It can introduce 61 | a bias that focuses on the hard examples when combined with the loss as an 62 | importance metric.""" 63 | def __init__(self, k=1.0): 64 | self.k = k 65 | 66 | def weight_layer(self): 67 | return ExternalReweighting() 68 | 69 | def sample_weights(self, idxs, scores): 70 | N = len(scores) 71 | s = scores[idxs] 72 | w = scores.sum() / N / s 73 | w_hat = w**self.k 74 | w_hat *= w.dot(s) / w_hat.dot(s) 75 | 76 | return w_hat[:, np.newaxis] 77 | 78 | @property 79 | def weight_size(self): 80 | return 1 81 | 82 | 83 | class NoReweightingPolicy(ReweightingPolicy): 84 | """Set all sample weights to 1.""" 85 | def weight_layer(self): 86 | return ExternalReweighting() 87 | 88 | def sample_weights(self, idxs, scores): 89 | return np.ones((len(idxs), 1)) 90 | 91 | @property 92 | def weight_size(self): 93 | return 1 94 | 95 | 96 | class CorrectingReweightingPolicy(ReweightingPolicy): 97 | """CorrectingReweightingPolicy aims to scale the sample weights according 98 | to the mistakes of the importance predictor. 99 | 100 | Arguments 101 | --------- 102 | k: float 103 | The bias power used in all other reweighting schemes 104 | """ 105 | def __init__(self, k=1.0): 106 | self.k = k 107 | self._biased_reweighting = BiasedReweightingPolicy(k) 108 | 109 | def weight_layer(self): 110 | return CorrectingReweighting() 111 | 112 | def sample_weights(self, idxs, scores): 113 | w = self._biased_reweighting.sample_weights(idxs, scores) 114 | 115 | return np.hstack([w, scores[idxs][:, np.newaxis]]) 116 | 117 | @property 118 | def weight_size(self): 119 | return 2 120 | 121 | 122 | class AdjustedBiasedReweighting(Layer): 123 | """Implement a Keras layer that recomputes the sample weights using the sum 124 | of the scores and the number of samples. 125 | 126 | The specifics are the following: 127 | 128 | Given B = {i_1, i_2, ..., i_|B|} the mini-batch indices, \\hat{s_i} the 129 | predicted score of sample i and k the max-loss bias constant.
130 | 131 | S_1 = \\sum_{i \\notin B} \\hat{s_i} 132 | S_2 = \\sum_{i \\in B} s_i 133 | a_i^{(1)} = \\frac{1}{N} (\\frac{S1 + S2}{s_i})^k \\forall i \\in B 134 | a_i^{(2)} = \\frac{1}{N} \\frac{S1 + S2}{s_i} \\forall i \\in B 135 | t = \\frac{\\sum_{i} a_i^{(2)} s_i }{\\sum_{i} a_i^{(1)} s_i} 136 | a_i = t a_i^{(1)} \\forall i \\in B 137 | """ 138 | def __init__(self, k=1.0, **kwargs): 139 | self.k = k 140 | 141 | super(AdjustedBiasedReweighting, self).__init__(**kwargs) 142 | 143 | def build(self, input_shape): 144 | assert isinstance(input_shape, list) 145 | assert len(input_shape) == 2 146 | assert input_shape[0][1] == 1 147 | assert input_shape[1][1] == 2 148 | 149 | super(AdjustedBiasedReweighting, self).build(input_shape) 150 | 151 | def compute_output_shape(self, input_shape): 152 | return input_shape[0] 153 | 154 | def call(self, x): 155 | s, s_hat = x 156 | 157 | # Compute the variables defined in the class comment 158 | S2 = K.sum(s) 159 | S1 = s_hat[0, 1] 160 | N = s_hat[0, 0] 161 | 162 | # Compute the unbiased weights 163 | a2 = (S1 + S2) / N / s 164 | 165 | # Compute the biased weights and the scaling factor t 166 | a1 = K.pow(a2, self.k) 167 | sT = K.transpose(s) 168 | t = K.dot(sT, a2) / K.dot(sT, a1) 169 | 170 | return K.stop_gradient([a1 * t])[0] 171 | 172 | 173 | class ExternalReweighting(Layer): 174 | """Use the provided input as sample weights""" 175 | def build(self, input_shape): 176 | super(ExternalReweighting, self).build(input_shape) 177 | 178 | def compute_output_shape(self, input_shape): 179 | return input_shape[0] 180 | 181 | def call(self, x): 182 | return K.stop_gradient(x[1]) 183 | 184 | 185 | class CorrectingReweighting(Layer): 186 | """Use the provided weights and the score to correct sample weights that 187 | were computed with a very wrong predicted score""" 188 | def __init__(self, min_decrease=0, max_increase=2, **kwargs): 189 | self.min_decrease = min_decrease 190 | self.max_increase = max_increase 191 | 192 | super(CorrectingReweighting, self).__init__(**kwargs) 193 | 194 | def build(self, input_shape): 195 | super(CorrectingReweighting, self).build(input_shape) 196 | 197 | def compute_output_shape(self, input_shape): 198 | return input_shape[0] 199 | 200 | def call(self, x): 201 | s, x1 = x 202 | a = x1[:, :1] 203 | s_hat = x1[:, 1:2] 204 | 205 | # Rescale the weights, making sure we mostly scale down 206 | a_hat = a * K.clip(s_hat / s, self.min_decrease, self.max_increase) 207 | 208 | # Scale again so that the reported loss is comparable to the other ones 209 | t = 1 210 | #sT = K.transpose(s) 211 | #t = K.dot(sT, a) / K.dot(sT, a_hat) 212 | 213 | return K.stop_gradient([a_hat * t])[0] 214 | 215 | 216 | UNWEIGHTED = NoReweightingPolicy() 217 | UNBIASED = BiasedReweightingPolicy(k=1.0) 218 | -------------------------------------------------------------------------------- /importance_sampling/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | -------------------------------------------------------------------------------- /importance_sampling/utils/functional.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """This module should provide some useful functions that I wish existed in the 7 | functools standard 
package. But then again standard library developers know 8 | better""" 9 | 10 | 11 | placeholder = object() 12 | x = placeholder 13 | _ = placeholder 14 | ___ = placeholder 15 | 16 | 17 | def partial(func, *args, **kwargs): 18 | """Partially define the argument list of func 19 | 20 | Example: 21 | >>> import functional as f 22 | >>> from operator import add 23 | >>> add2 = f.partial(add, f._, 2) 24 | >>> add2(5) 25 | 7 26 | >>> add2(8) 27 | 10 28 | >>> join = f.compose(' '.join, lambda *args: list(args)) 29 | >>> create_sentence = f.partial(join, "The", f._, f._, "fox") 30 | >>> create_sentence("quick", "brown") 31 | 'The quick brown fox' 32 | >>> import tempfile 33 | >>> fd, file_name = tempfile.mkstemp() 34 | >>> opener = f.partial(open, f._, "wb") 35 | >>> with opener(file_name) as out: 36 | ... out.write("Hello World!") 37 | ... 38 | >>> with opener(file_name) as out: 39 | ... out.write("Hello World!") 40 | ... 41 | >>> import os 42 | >>> os.remove(file_name) 43 | >>> 44 | """ 45 | def inner(*args2, **kwargs2): 46 | # compute the kwargs to pass to the function 47 | fkwargs = kwargs.copy() 48 | fkwargs.update(kwargs2) 49 | 50 | # walk through the fargs list replacing each placeholder with an arg 51 | args2 = list(args2) 52 | fargs = list(args) 53 | for i in range(len(fargs)): 54 | if fargs[i] is placeholder: 55 | fargs[i] = args2.pop(0) 56 | fargs += args2 57 | 58 | # apply fargs and fkwargs to func 59 | return func(*fargs, **fkwargs) 60 | return inner 61 | 62 | 63 | def compose(*functions): 64 | """Compose the functions with the following rules 65 | 66 | If a function returns a tuple the parameters are passed as positional 67 | arguments. 68 | 69 | Example: 70 | >>> import functional as f 71 | >>> add2 = f.compose(lambda x: x+1, lambda x: x+1) 72 | >>> add2(2) 73 | 4 74 | >>> testf = f.compose(lambda x: x**2, lambda x: 2*x) 75 | >>> testf(1) 76 | 4 77 | >>> testf = f.compose(lambda x: x**2, lambda a,b: a+b) 78 | >>> testf(1, 1) 79 | 4 80 | >>> 81 | """ 82 | def inner(*args, **kwargs): 83 | for f in reversed(functions): 84 | args = f(*args, **kwargs) 85 | if isinstance(args, dict): 86 | kwargs = args 87 | args = tuple() 88 | elif not isinstance(args, tuple): 89 | args = (args, ) 90 | kwargs = {} 91 | if not isinstance(args, tuple) or len(args) > 1: 92 | return args 93 | else: 94 | return args[0] 95 | return inner 96 | 97 | 98 | def call(f, *args, **kwargs): 99 | """Call f with args and kwargs""" 100 | return f(*args, **kwargs) 101 | 102 | 103 | def attr(a): 104 | """Return a function that returns the attribute 'a'. 105 | 106 | It would be equal to partial(getattr(___, a)) but I define it separately 107 | for better documentation. 
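    Example (a small sketch in the doctest style of this module; it assumes
    the module is imported as f, like in the examples above):

    >>> import functional as f
    >>> get_real = f.attr("real")
    >>> get_real(3 + 4j)
    3.0
    >>> list(map(f.attr("imag"), [1j, 2 + 2j]))
    [1.0, 2.0]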
108 | """ 109 | def inner(o): 110 | return getattr(o, a) 111 | return inner 112 | 113 | 114 | def identity(x): 115 | return x 116 | -------------------------------------------------------------------------------- /importance_sampling/utils/keras_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from functools import reduce 7 | from math import ceil 8 | 9 | import h5py 10 | from keras import backend as K 11 | from keras.utils.data_utils import Sequence 12 | import numpy as np 13 | 14 | 15 | def weights_from_hdf5(f): 16 | """Extract all the weights from an h5py File or Group""" 17 | if "weight_names" in f.attrs: 18 | for n in f.attrs["weight_names"]: 19 | yield n, f[n] 20 | else: 21 | for k in f.keys(): 22 | for n, w in weights_from_hdf5(f[k]): 23 | yield n, w 24 | 25 | 26 | def possible_weight_names(name, n=10): 27 | name = name.decode() 28 | yield name 29 | parts = name.split("/") 30 | for i in range(1, n+1): 31 | yield str("{}_{}/{}".format(parts[0], i, parts[1])) 32 | 33 | 34 | def load_weights_by_name(f, layers): 35 | """Load the weights by name from the h5py file to the model""" 36 | # If f is not an h5py thing try to open it 37 | if not isinstance(f, (h5py.File, h5py.Group)): 38 | with h5py.File(f, "r") as h5f: 39 | return load_weights_by_name(h5f, layers) 40 | 41 | # Extract all the weights from the layers/model 42 | if not isinstance(layers, list): 43 | layers = layers.layers 44 | weights = dict(reduce( 45 | lambda a, x: a + [(w.name, w) for w in x.weights], 46 | layers, 47 | [] 48 | )) 49 | 50 | # Loop through all the possible layer weights in the file and make a list 51 | # of updates 52 | updates = [] 53 | updated = [] 54 | for name, weight in weights_from_hdf5(f): 55 | for n in possible_weight_names(name): 56 | if n in weights: 57 | updates.append((weights[n], weight)) 58 | updated.append(n) 59 | break 60 | K.batch_set_value(updates) 61 | 62 | return updated 63 | 64 | 65 | class DatasetSequence(Sequence): 66 | """Implement the Keras Sequence interface from a BaseDataset interface.""" 67 | def __init__(self, dataset, train=True, part=slice(None), batch_size=32): 68 | self._data = dataset.train_data if train else dataset.test_data 69 | self._idxs = np.arange(len(self._data))[part] 70 | self._batch_size = batch_size 71 | 72 | def __len__(self): 73 | return int(ceil(float(len(self._idxs)) / self._batch_size)) 74 | 75 | def __getitem__(self, idx): 76 | batch = self._idxs[self._batch_size*idx:self._batch_size*(idx+1)] 77 | return self._data[batch] 78 | -------------------------------------------------------------------------------- /importance_sampling/utils/tf.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from keras import backend as K 7 | 8 | tf = None 9 | if K.backend() == "tensorflow": 10 | import tensorflow as tf 11 | -------------------------------------------------------------------------------- /importance_sampling/utils/tf_config.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """This module aims to configure the Keras tensorflow session based on 7 | environment variables or 
other sources""" 8 | 9 | from multiprocessing import cpu_count 10 | import os 11 | 12 | from keras import backend as K 13 | from .tf import tf 14 | 15 | 16 | if K.backend() == "tensorflow": 17 | TF_THREADS = int(os.environ.get("TF_THREADS", cpu_count())) 18 | 19 | config = tf.ConfigProto( 20 | intra_op_parallelism_threads=TF_THREADS, 21 | inter_op_parallelism_threads=TF_THREADS, 22 | device_count={"CPU": TF_THREADS} 23 | ) 24 | session = tf.Session(config=config) 25 | K.set_session(session) 26 | 27 | 28 | def with_tensorflow(f): 29 | def inner(*args, **kwargs): 30 | if K.backend() == "tensorflow": 31 | return f(tf, *args, **kwargs) 32 | return inner 33 | 34 | 35 | @with_tensorflow 36 | def set_random_seed(tf, seed): 37 | return tf.set_random_seed(seed) 38 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Importance Sampling for Keras 2 | repo_url: https://github.com/idiap/importance-sampling 3 | theme: readthedocs 4 | extra_javascript: 5 | - https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML 6 | extra_css: 7 | - css/extra.css 8 | markdown_extensions: 9 | - mdx_math 10 | nav: 11 | - Home: index.md 12 | - Training: training.md 13 | - Datasets: datasets.md 14 | - Examples: examples.md 15 | - Contributing: contributing.md 16 | -------------------------------------------------------------------------------- /scripts/compute_scores.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | import argparse 8 | 9 | import numpy as np 10 | 11 | from importance_sampling import models 12 | from importance_sampling.datasets import CIFAR10, CIFAR100, MNIST, \ 13 | OntheflyAugmentedImages, PennTreeBank 14 | from importance_sampling.model_wrappers import OracleWrapper 15 | from importance_sampling.utils.functional import compose, partial, ___ 16 | 17 | 18 | 19 | def load_dataset(dataset): 20 | datasets = { 21 | "mnist": MNIST, 22 | "cifar10": CIFAR10, 23 | "cifar100": CIFAR100, 24 | "cifar10-augmented": compose( 25 | partial(OntheflyAugmentedImages, ___, dict( 26 | featurewise_center=False, 27 | samplewise_center=False, 28 | featurewise_std_normalization=False, 29 | samplewise_std_normalization=False, 30 | zca_whitening=False, 31 | rotation_range=0, 32 | width_shift_range=0.1, 33 | height_shift_range=0.1, 34 | horizontal_flip=True, 35 | vertical_flip=False 36 | )), 37 | CIFAR10 38 | ), 39 | "cifar100-augmented": compose( 40 | partial(OntheflyAugmentedImages, ___, dict( 41 | featurewise_center=False, 42 | samplewise_center=False, 43 | featurewise_std_normalization=False, 44 | samplewise_std_normalization=False, 45 | zca_whitening=False, 46 | rotation_range=0, 47 | width_shift_range=0.1, 48 | height_shift_range=0.1, 49 | horizontal_flip=True, 50 | vertical_flip=False 51 | )), 52 | CIFAR100 53 | ), 54 | "ptb": partial(PennTreeBank, 20) 55 | } 56 | 57 | return datasets[dataset]() 58 | 59 | 60 | 61 | def main(argv): 62 | parser = argparse.ArgumentParser( 63 | description="Plot the loss distribution of a model and dataset pair" 64 | ) 65 | 66 | parser.add_argument( 67 | "model", 68 | choices=[ 69 | "small_cnn", "cnn", "lstm_lm", "lstm_lm2", "lstm_lm3", 70 | "small_cnn_sq" 71 | ], 72 | help="Choose the type of the model" 73 | ) 74 | 
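    # A typical invocation (sketch; "model.h5" stands in for a weights file
    # saved during training):
    #
    #   python compute_scores.py small_cnn model.h5 mnist --score loss
    #
    # which prints one score per training sample, one value per line.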
parser.add_argument( 75 | "weights", 76 | help="The file containing the model weights" 77 | ) 78 | parser.add_argument( 79 | "dataset", 80 | choices=[ 81 | "mnist", "cifar10", "cifar100", "cifar10-augmented", 82 | "cifar100-augmented", "ptb" 83 | ], 84 | help="Choose the dataset to compute the loss" 85 | ) 86 | parser.add_argument( 87 | "--score", 88 | choices=["gnorm", "loss"], 89 | default="loss", 90 | help="Choose a score to plot" 91 | ) 92 | parser.add_argument( 93 | "--batch_size", 94 | type=int, 95 | default=128, 96 | help="The batch size for computing the loss" 97 | ) 98 | parser.add_argument( 99 | "--random_seed", 100 | type=int, 101 | default=0, 102 | help="A seed for the PRNG (mainly used for dataset generation)" 103 | ) 104 | 105 | args = parser.parse_args(argv) 106 | 107 | np.random.seed(args.random_seed) 108 | 109 | dataset = load_dataset(args.dataset) 110 | network = models.get(args.model)(dataset.shape, dataset.output_size) 111 | model = OracleWrapper(network, score=args.score) 112 | model.model.load_weights(args.weights) 113 | 114 | for i in range(0, dataset.train_size, args.batch_size): 115 | idxs = slice(i, i+args.batch_size) 116 | for s in model.score_batch(*dataset.train_data(idxs)): 117 | print s 118 | 119 | 120 | if __name__ == "__main__": 121 | import sys 122 | main(sys.argv[1:]) 123 | -------------------------------------------------------------------------------- /scripts/importance_sampling: -------------------------------------------------------------------------------- 1 | ../importance_sampling/ -------------------------------------------------------------------------------- /scripts/lfw_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | """Evaluate the embedding performance in the lfw dataset""" 8 | 9 | import argparse 10 | import os 11 | 12 | from importance_sampling.datasets import LFW 13 | import numpy as np 14 | 15 | 16 | def compute_distances(representations, pairs): 17 | X = representations[pairs[:, 0], :] 18 | Y = representations[pairs[:, 1], :] 19 | 20 | X = X / np.sqrt((X**2).sum(axis=1, keepdims=True)) 21 | Y = Y / np.sqrt((Y**2).sum(axis=1, keepdims=True)) 22 | 23 | return np.sum((X - Y)**2, axis=1, keepdims=True) 24 | 25 | 26 | def compute_threshold(distances, matches): 27 | a, b = 0.0, distances.max() 28 | N = len(distances) 29 | while b-a > 1e-5: 30 | m = (a + b)/2 31 | t1 = (a + m)/2 32 | t2 = (m + b)/2 33 | 34 | n0 = np.sum((distances < m).astype(float) == matches) 35 | n1 = np.sum((distances < t1).astype(float) == matches) 36 | n2 = np.sum((distances < t2).astype(float) == matches) 37 | 38 | if n0 > n1 and n0 > n2: 39 | a, b = t1, t2 40 | elif n1 > n0 and n0 >= n2: 41 | b = m 42 | elif n2 > n0 and n0 >= n1: 43 | a = m 44 | else: 45 | return m 46 | 47 | return (a+b)/2 48 | 49 | 50 | def evaluate(representations, dataset): 51 | pairs_train, matches_train = dataset.train_data[:] 52 | distances = compute_distances(representations, pairs_train) 53 | t = compute_threshold(distances, matches_train) 54 | 55 | pairs_test, matches_test = dataset.test_data[:] 56 | distances = compute_distances(representations, pairs_test) 57 | 58 | return ((distances < t).astype(float) == matches_test).astype(float).mean() 59 | 60 | 61 | def main(argv): 62 | parser = argparse.ArgumentParser( 63 | description="Evaluate a representation on the lfw dataset" 64 | ) 65 | 66 | 
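    # Sketch of the intended workflow (file names below are placeholders and
    # the LFW path is taken from $LFW or --dataset_path):
    #
    #   python lfw_forward_pass.py pretrained_resnet50 weights.h5 reps.f32
    #   python lfw_evaluate.py reps.f32 --embedding 128 --folds 10
    #
    # For each fold a distance threshold is fitted on the training pairs and
    # the verification accuracy is reported on that fold's test pairs.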
parser.add_argument( 67 | "representations", 68 | help="The representations of the faces of LFW" 69 | ) 70 | 71 | parser.add_argument( 72 | "--embedding", 73 | type=int, 74 | default=128, 75 | help="Choose the dimensionality of the representation" 76 | ) 77 | parser.add_argument( 78 | "--folds", 79 | type=int, 80 | default=10, 81 | help="How many folds does the dataset have" 82 | ) 83 | parser.add_argument( 84 | "--dataset_path", 85 | default=os.getenv("LFW", ""), 86 | help="The basepath of the LFW dataset" 87 | ) 88 | 89 | args = parser.parse_args(argv) 90 | 91 | print "Loading representations..." 92 | representations = np.fromfile(args.representations, dtype=np.float32) 93 | representations = representations.reshape(-1, args.embedding) 94 | 95 | for fold in range(1, args.folds+1): 96 | dataset = LFW(args.dataset_path, fold=fold, idxs=True) 97 | print "Fold {}: {}".format(fold, evaluate(representations, dataset)) 98 | 99 | 100 | if __name__ == "__main__": 101 | import sys 102 | main(sys.argv[1:]) 103 | -------------------------------------------------------------------------------- /scripts/lfw_forward_pass.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | """Run the forward pass on the lfw images so that we can then run the 8 | evaluation""" 9 | 10 | import argparse 11 | import os 12 | 13 | from importance_sampling.datasets import LFW 14 | from importance_sampling.models import get as get_model 15 | from importance_sampling.utils import tf_config, keras_utils 16 | 17 | 18 | def make_slice(x): 19 | def int_or_none(x): 20 | try: 21 | return int(x) 22 | except ValueError: 23 | return None 24 | return slice(*list(map(int_or_none, x.split(":")))) 25 | 26 | 27 | def batch_gen(dataset, part, batch_size): 28 | idxs = np.arange(len(dataset.train_data))[part] 29 | # Keras annoyingly requires infinite iterators 30 | while True: 31 | for i in range(0, len(idxs), batch_size): 32 | yield dataset.train_data[idxs[i:i+batch_size]][0] 33 | 34 | 35 | def main(argv): 36 | parser = argparse.ArgumentParser( 37 | description="Compute representations for the lfw dataset" 38 | ) 39 | 40 | parser.add_argument( 41 | "model", 42 | choices=["pretrained_resnet50"], 43 | help="Choose the architecture to load" 44 | ) 45 | parser.add_argument( 46 | "weights", 47 | help="Load those weights" 48 | ) 49 | parser.add_argument( 50 | "output", 51 | help="Save the produced representations to this file" 52 | ) 53 | 54 | parser.add_argument( 55 | "--embedding", 56 | type=int, 57 | default=128, 58 | help="Choose the dimensionality of the representation" 59 | ) 60 | parser.add_argument( 61 | "--slice", 62 | type=make_slice, 63 | default=":", 64 | help="Slice the dataset to get a part to transform" 65 | ) 66 | parser.add_argument( 67 | "--batch_size", 68 | type=int, 69 | default=16, 70 | help="The batch size used for the forward pass" 71 | ) 72 | parser.add_argument( 73 | "--dataset_path", 74 | default=os.getenv("LFW", ""), 75 | help="The basepath of the LFW dataset" 76 | ) 77 | 78 | args = parser.parse_args(argv) 79 | print "Loading dataset..." 80 | dataset = LFW(args.dataset_path, fold=None) 81 | print "Loading model..." 82 | model = get_model(args.model)( 83 | dataset.shape, 84 | args.embedding 85 | ) 86 | keras_utils.load_weights_by_name(args.weights, model) 87 | 88 | print "Transforming..." 
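    # The DatasetSequence below feeds the selected --slice of the LFW images
    # through the network in batches of --batch_size. The result is an
    # (n_images, embedding) matrix of representations, which is dumped with
    # tofile() and later read back by lfw_evaluate.py via np.fromfile.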
89 | representations = model.predict_generator( 90 | keras_utils.DatasetSequence( 91 | dataset, 92 | part=args.slice, 93 | batch_size=args.batch_size 94 | ), 95 | verbose=1 96 | ) 97 | 98 | print "Saving {} representations...".format(len(representations)) 99 | representations.tofile(args.output) 100 | 101 | 102 | if __name__ == "__main__": 103 | import sys 104 | main(sys.argv[1:]) 105 | -------------------------------------------------------------------------------- /scripts/lsexp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | import argparse 8 | from itertools import product 9 | import os 10 | from os import path 11 | 12 | 13 | LINKS = set([".", ".."]) 14 | 15 | 16 | class Experiment(object): 17 | def __init__(self, root, dataset, network, sampler, score, reweighting, 18 | params): 19 | self.root = root 20 | self.dataset = dataset 21 | self.network = network 22 | self.sampler = sampler 23 | self.score = score 24 | self.reweighting = reweighting 25 | self.params = params 26 | 27 | @property 28 | def path(self): 29 | return path.join( 30 | self.root, 31 | self.dataset, 32 | self.network, 33 | self.sampler, 34 | self.score, 35 | self.reweighting, 36 | self.params 37 | ) 38 | 39 | @property 40 | def exists(self): 41 | """Return if this experiment is a valid combination of dataset, 42 | network, ...""" 43 | return path.exists(self.path) 44 | 45 | @property 46 | def started(self): 47 | """Has the experiment started running?""" 48 | return path.exists(path.join(self.path, "stdout")) 49 | 50 | @property 51 | def updated(self): 52 | """Check when was the last time train.txt was updated""" 53 | try: 54 | return path.getmtime(path.join(self.path, "train.txt")) 55 | except OSError: 56 | return path.getmtime(self.path) 57 | 58 | @property 59 | def epochs(self): 60 | """Return the number of epochs that this experiment has run""" 61 | cnt = 0 62 | try: 63 | with open(path.join(self.path, "val_eval.txt")) as f: 64 | for l in f: 65 | cnt += 1 66 | except IOError: 67 | pass 68 | return max(cnt - 1, 0) 69 | 70 | def __repr__(self): 71 | return "Experiment(%r, %r, %r, %r, %r, %r, %r)" % ( 72 | self.root, 73 | self.dataset, 74 | self.network, 75 | self.sampler, 76 | self.score, 77 | self.reweighting, 78 | self.params 79 | ) 80 | 81 | 82 | def _is_existing_dir(x): 83 | return path.exists(x) and path.isdir(x) 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser( 88 | description="List the importance sampling experiments" 89 | ) 90 | parser.add_argument( 91 | "--expdir", 92 | default="." 
93 | ) 94 | parser.add_argument( 95 | "--latest", "-t", 96 | action="store_true", 97 | help="Sort based on time last updated" 98 | ) 99 | parser.add_argument( 100 | "--pending", "-p", 101 | action="store_true", 102 | help="Show only pending (not started) experiments" 103 | ) 104 | 105 | args = parser.parse_args() 106 | 107 | root = args.expdir 108 | datasets = set( 109 | x 110 | for x in os.listdir(root) 111 | if _is_existing_dir(path.join(root, x)) 112 | ) - LINKS 113 | networks = set(sum([ 114 | os.listdir(path.join(root, *x)) 115 | for x in product(datasets) 116 | if _is_existing_dir(path.join(root, *x)) 117 | ], [])) - LINKS 118 | samplers = set(sum([ 119 | os.listdir(path.join(root, *x)) 120 | for x in product(datasets, networks) 121 | if _is_existing_dir(path.join(root, *x)) 122 | ], [])) - LINKS 123 | scores = set(sum([ 124 | os.listdir(path.join(root, *x)) 125 | for x in product(datasets, networks, samplers) 126 | if _is_existing_dir(path.join(root, *x)) 127 | ], [])) - LINKS 128 | reweightings = set(sum([ 129 | os.listdir(path.join(root, *x)) 130 | for x in product(datasets, networks, samplers, scores) 131 | if _is_existing_dir(path.join(root, *x)) 132 | ], [])) - LINKS 133 | params = set(sum([ 134 | os.listdir(path.join(root, *x)) 135 | for x in product(datasets, networks, samplers, scores, reweightings) 136 | if _is_existing_dir(path.join(root, *x)) 137 | ], [])) - LINKS 138 | 139 | experiments = filter( 140 | lambda e: e.exists and e.started != args.pending, 141 | [ 142 | Experiment(root, *x) 143 | for x in product(datasets, networks, samplers, scores, 144 | reweightings, params) 145 | ] 146 | ) 147 | 148 | if args.latest: 149 | experiments.sort(key=lambda e: e.updated) 150 | else: 151 | experiments.sort(key=lambda e: e.path) 152 | 153 | print "Dataset\tNetwork\tSampler\tScore\tReweighting\tParams\tEpochs" 154 | for e in experiments: 155 | print "\t".join([ 156 | e.dataset, 157 | e.network, 158 | e.sampler, 159 | e.score, 160 | e.reweighting, 161 | e.params, 162 | str(e.epochs) 163 | ]) 164 | -------------------------------------------------------------------------------- /scripts/plot_distribution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | import argparse 8 | import time 9 | 10 | from matplotlib.animation import FuncAnimation 11 | import matplotlib.pyplot as plt 12 | import seaborn as sns 13 | 14 | 15 | def lines(f, delim): 16 | while True: 17 | line = f.readline() 18 | if line == "": 19 | break 20 | yield map(float, line.strip().split(delim)) 21 | 22 | 23 | def update(data, ax, xlim, ylim, vl): 24 | ax.clear() 25 | sns.distplot(data, ax=ax) 26 | if xlim: 27 | ax.set_xlim(xlim) 28 | if ylim: 29 | ax.set_ylim(ylim) 30 | 31 | if vl is not None: 32 | ax.plot([vl, vl], ax.get_ylim(), "k--") 33 | 34 | return ax 35 | 36 | 37 | if __name__ == "__main__": 38 | import sys 39 | 40 | parser = argparse.ArgumentParser( 41 | description="Plot the distribution of values in each line" 42 | ) 43 | parser.add_argument( 44 | "--delimiter", "-d", 45 | default=" ", 46 | help="Define the field delimiter" 47 | ) 48 | parser.add_argument( 49 | "--xlim", 50 | type=lambda x: None if not x else map(float, x.split(",")), 51 | default="", 52 | help="Specific limits for the x axis" 53 | ) 54 | parser.add_argument( 55 | "--ylim", 56 | type=lambda x: None if not x else map(float, x.split(",")), 57 | default="", 
58 | help="Specific limits for the y axis" 59 | ) 60 | parser.add_argument( 61 | "--frames", 62 | type=int, 63 | default=100, 64 | help="Number of frames expected" 65 | ) 66 | parser.add_argument( 67 | "--to_file", 68 | help="Save the animation as a video to that file" 69 | ) 70 | parser.add_argument( 71 | "--vline", 72 | type=float, 73 | help="Plot a vertical line" 74 | ) 75 | 76 | args = parser.parse_args(sys.argv[1:]) 77 | 78 | data_gen = lines(sys.stdin, args.delimiter) 79 | 80 | fig, ax = plt.subplots() 81 | anim = FuncAnimation( 82 | fig, 83 | update, 84 | data_gen, 85 | fargs=(ax, args.xlim, args.ylim, args.vline), 86 | interval=100, 87 | save_count=args.frames 88 | ) 89 | if args.to_file: 90 | anim.save(args.to_file) 91 | else: 92 | plt.show() 93 | -------------------------------------------------------------------------------- /scripts/plot_loss_evolution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | import argparse 8 | import sys 9 | 10 | import matplotlib.pyplot as plt 11 | from matplotlib.animation import FuncAnimation, writers as animation_writers 12 | from matplotlib.cm import ScalarMappable 13 | import numpy as np 14 | from sklearn.linear_model import LinearRegression 15 | 16 | 17 | def maybe_int(x): 18 | try: 19 | return int(x) 20 | except: 21 | return None 22 | 23 | 24 | def file_or_stdin(x): 25 | if x == "-": 26 | return sys.stdin 27 | else: 28 | return x 29 | 30 | 31 | def colors(x): 32 | x = np.array(x) 33 | x -= x.min() 34 | if x.max() != 0: 35 | x /= x.max() 36 | x *= 255 37 | x = np.round(x).astype(int) 38 | return [ 39 | plt.cm.viridis.colors[xx] 40 | for xx in x 41 | ] 42 | 43 | 44 | def main(argv): 45 | parser = argparse.ArgumentParser( 46 | description="Plot the loss evolving through time" 47 | ) 48 | 49 | parser.add_argument( 50 | "metrics", 51 | type=file_or_stdin, 52 | help="The file containing the loss" 53 | ) 54 | 55 | parser.add_argument( 56 | "--to_file", 57 | help="Save the animation to a video file" 58 | ) 59 | parser.add_argument( 60 | "--step", 61 | type=int, 62 | default=1000, 63 | help="Change that many datapoints in between frames" 64 | ) 65 | parser.add_argument( 66 | "--n_points", 67 | type=int, 68 | default=10000, 69 | help="That many points in each frame" 70 | ) 71 | parser.add_argument( 72 | "--frames", 73 | type=lambda x: slice(*map(maybe_int, x.split(":"))), 74 | default=":", 75 | help="Choose only those frames" 76 | ) 77 | parser.add_argument( 78 | "--lim", 79 | type=lambda x: map(float, x.split(",")), 80 | help="Define the limits of the axes" 81 | ) 82 | parser.add_argument( 83 | "--no_colorbar", 84 | action="store_false", 85 | dest="colorbar", 86 | help="Do not display a colorbar" 87 | ) 88 | 89 | args = parser.parse_args(argv) 90 | loss = np.loadtxt(args.metrics) 91 | 92 | fig, ax = plt.subplots() 93 | lr = LinearRegression() 94 | sc = ax.scatter(loss[:args.n_points, 0], loss[:args.n_points, 1], c=colors(loss[:args.n_points, 2])) 95 | lims = args.lim if args.lim else [0, loss[:, 0].max()] 96 | ln, = ax.plot(lims, lims, "--", color="black", label="linear fit") 97 | ax.set_xlim(lims) 98 | ax.set_ylim(lims) 99 | ax.set_xlabel("$L(\cdot)$") 100 | ax.set_ylabel("$\hat{L}(\cdot)$") 101 | if args.colorbar: 102 | mappable = ScalarMappable(cmap="viridis") 103 | mappable.set_array(loss[:10000, 2]) 104 | plt.colorbar(mappable) 105 | 106 | STEP = args.step 
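    # Each animation frame i below shows the window of N_POINTS consecutive
    # (L, L_hat) pairs starting at i*STEP; the dashed line is a linear fit
    # that is recomputed for every frame, so it shows how well the predicted
    # loss (second column) tracks the true loss (first column) over that
    # window.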
107 | N_POINTS = args.n_points 108 | def update(i): 109 | s = i*STEP 110 | e = s + N_POINTS 111 | lr.fit(loss[s:e, :1], loss[s:e, 1].ravel()) 112 | ln.set_ydata([ 113 | lr.intercept_.ravel(), 114 | lr.intercept_.ravel() + lims[1]*lr.coef_.ravel() 115 | ]) 116 | ax.set_title("Showing samples %d to %d" % (s, e)) 117 | sc.set_facecolor(colors(loss[s:e, 2])) 118 | sc.set_offsets(loss[s:e, :2]) 119 | return ax, sc, ln 120 | 121 | anim = FuncAnimation( 122 | fig, update, 123 | interval=100, 124 | frames=np.arange(len(loss) / STEP)[args.frames], 125 | blit=False, repeat=False 126 | ) 127 | if args.to_file: 128 | writer = animation_writers["ffmpeg"](fps=15) 129 | anim.save(args.to_file, writer=writer) 130 | else: 131 | plt.show() 132 | 133 | 134 | if __name__ == "__main__": 135 | import sys 136 | main(sys.argv[1:]) 137 | -------------------------------------------------------------------------------- /scripts/variance_reduction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | import argparse 8 | import os 9 | from os import path 10 | 11 | from keras import backend as K 12 | from keras.losses import get as get_loss 13 | from keras.utils.generic_utils import Progbar 14 | import numpy as np 15 | 16 | from importance_sampling import models 17 | from importance_sampling.datasets import CIFAR10, CIFAR100, MNIST, \ 18 | OntheflyAugmentedImages, ImageNetDownsampled, PennTreeBank, ZCAWhitening 19 | from importance_sampling.model_wrappers import OracleWrapper 20 | from importance_sampling.reweighting import BiasedReweightingPolicy 21 | from importance_sampling.utils import tf_config 22 | from importance_sampling.utils.functional import compose, partial, ___ 23 | 24 | 25 | def build_grad(network): 26 | """Return the gradient of the network.""" 27 | x = network.input 28 | y = network.output 29 | target_shape = (None, 1) if "sparse" in network.loss else K.int_shape(y) 30 | y_true = K.placeholder(shape=target_shape) 31 | sample_weights = K.placeholder(shape=(None,)) 32 | 33 | l = K.mean(sample_weights * get_loss(network.loss)(y_true, y)) 34 | grads = network.optimizer.get_gradients(l, network.trainable_weights) 35 | grad = K.concatenate([ 36 | K.reshape(g, (-1,)) 37 | for g in grads 38 | ]) 39 | 40 | return K.function( 41 | [x, y_true, sample_weights], 42 | [grad] 43 | ) 44 | 45 | 46 | def build_grad_batched(network, batch_size): 47 | """Compute the average gradient by splitting the inputs in batches of size 48 | 'batch_size' and averaging.""" 49 | grad = build_grad(network) 50 | def inner(inputs): 51 | X, y, w = inputs 52 | N = len(X) 53 | g = 0 54 | for i in range(0, N, batch_size): 55 | g = g + w[i:i+batch_size].sum() * grad([ 56 | X[i:i+batch_size], 57 | y[i:i+batch_size], 58 | w[i:i+batch_size] 59 | ])[0] 60 | return [g / w.sum()] 61 | 62 | return inner 63 | 64 | 65 | def load_dataset(dataset): 66 | datasets = { 67 | "mnist": MNIST, 68 | "cifar10": CIFAR10, 69 | "cifar100": CIFAR100, 70 | "cifar10-augmented": compose( 71 | partial(OntheflyAugmentedImages, ___, dict( 72 | featurewise_center=False, 73 | samplewise_center=False, 74 | featurewise_std_normalization=False, 75 | samplewise_std_normalization=False, 76 | zca_whitening=False, 77 | rotation_range=0, 78 | width_shift_range=0.1, 79 | height_shift_range=0.1, 80 | horizontal_flip=True, 81 | vertical_flip=False 82 | )), 83 | CIFAR10 84 | ), 85 | 
"cifar10-whitened-augmented": compose( 86 | partial(OntheflyAugmentedImages, ___, dict( 87 | featurewise_center=False, 88 | samplewise_center=False, 89 | featurewise_std_normalization=False, 90 | samplewise_std_normalization=False, 91 | zca_whitening=False, 92 | rotation_range=0, 93 | width_shift_range=0.1, 94 | height_shift_range=0.1, 95 | horizontal_flip=True, 96 | vertical_flip=False 97 | ), N=15*10**5), 98 | ZCAWhitening, 99 | CIFAR10 100 | ), 101 | "cifar100-augmented": compose( 102 | partial(OntheflyAugmentedImages, ___, dict( 103 | featurewise_center=False, 104 | samplewise_center=False, 105 | featurewise_std_normalization=False, 106 | samplewise_std_normalization=False, 107 | zca_whitening=False, 108 | rotation_range=0, 109 | width_shift_range=0.1, 110 | height_shift_range=0.1, 111 | horizontal_flip=True, 112 | vertical_flip=False 113 | )), 114 | CIFAR100 115 | ), 116 | "cifar100-whitened-augmented": compose( 117 | partial(OntheflyAugmentedImages, ___, dict( 118 | featurewise_center=False, 119 | samplewise_center=False, 120 | featurewise_std_normalization=False, 121 | samplewise_std_normalization=False, 122 | zca_whitening=False, 123 | rotation_range=0, 124 | width_shift_range=0.1, 125 | height_shift_range=0.1, 126 | horizontal_flip=True, 127 | vertical_flip=False 128 | ), N=15*10**5), 129 | ZCAWhitening, 130 | CIFAR100 131 | ), 132 | "imagenet-32x32": partial( 133 | ImageNetDownsampled, 134 | os.getenv("IMAGENET"), 135 | size=32 136 | ), 137 | "ptb": partial(PennTreeBank, 20), 138 | } 139 | 140 | return datasets[dataset]() 141 | 142 | 143 | def uniform_score(x, y, batch_size=None): 144 | return np.ones((len(x),)) 145 | 146 | 147 | def main(argv): 148 | parser = argparse.ArgumentParser( 149 | description=("Compute the variance reduction achieved by different " 150 | "importance sampling methods") 151 | ) 152 | 153 | parser.add_argument( 154 | "model", 155 | choices=[ 156 | "small_cnn", "cnn", "wide_resnet_28_2", "lstm_lm" 157 | ], 158 | help="Choose the type of the model" 159 | ) 160 | parser.add_argument( 161 | "weights", 162 | help="The file containing the model weights" 163 | ) 164 | parser.add_argument( 165 | "dataset", 166 | choices=[ 167 | "mnist", "cifar10", "cifar100", "cifar10-augmented", 168 | "cifar100-augmented", "imagenet-32x32", "ptb", 169 | "cifar10-whitened-augmented", "cifar100-whitened-augmented" 170 | ], 171 | help="Choose the dataset to compute the loss" 172 | ) 173 | parser.add_argument( 174 | "--samples", 175 | type=int, 176 | default=10, 177 | help="How many samples to choose" 178 | ) 179 | parser.add_argument( 180 | "--score", 181 | choices=["gnorm", "full_gnorm", "loss", "ones"], 182 | nargs="+", 183 | default="loss", 184 | help="Choose a score to perform sampling with" 185 | ) 186 | parser.add_argument( 187 | "--batch_size", 188 | type=int, 189 | default=128, 190 | help="The batch size for computing the loss" 191 | ) 192 | parser.add_argument( 193 | "--inner_batch_size", 194 | type=int, 195 | default=32, 196 | help=("The batch size to use for gradient computations " 197 | "(to decrease memory usage)") 198 | ) 199 | parser.add_argument( 200 | "--sample_size", 201 | type=int, 202 | default=1024, 203 | help="The sample size to compute the variance reduction" 204 | ) 205 | parser.add_argument( 206 | "--random_seed", 207 | type=int, 208 | default=0, 209 | help="A seed for the PRNG (mainly used for dataset generation)" 210 | ) 211 | parser.add_argument( 212 | "--save_scores", 213 | help="Directory to save the scores in" 214 | ) 215 | 216 | args = 
parser.parse_args(argv) 217 | 218 | np.random.seed(args.random_seed) 219 | 220 | dataset = load_dataset(args.dataset) 221 | network = models.get(args.model)(dataset.shape, dataset.output_size) 222 | network.load_weights(args.weights) 223 | grad = build_grad_batched(network, args.inner_batch_size) 224 | reweighting = BiasedReweightingPolicy() 225 | 226 | # Compute the full gradient 227 | idxs = np.random.choice(len(dataset.train_data), args.sample_size) 228 | x, y = dataset.train_data[idxs] 229 | full_grad = grad([x, y, np.ones(len(x))])[0] 230 | 231 | # Sample and approximate 232 | for score_metric in args.score: 233 | if score_metric != "ones": 234 | model = OracleWrapper(network, reweighting, score=score_metric) 235 | score = model.score 236 | else: 237 | score = uniform_score 238 | gs = np.zeros(shape=(10,) + full_grad.shape, dtype=np.float32) 239 | print "Calculating %s..." % (score_metric,) 240 | scores = score(x, y, batch_size=1) 241 | p = scores/scores.sum() 242 | pb = Progbar(args.samples) 243 | for i in range(args.samples): 244 | pb.update(i) 245 | idxs = np.random.choice(args.sample_size, args.batch_size, p=p) 246 | w = reweighting.sample_weights(idxs, scores).ravel() 247 | gs[i] = grad([x[idxs], y[idxs], w])[0] 248 | pb.update(args.samples) 249 | norms = np.sqrt(((full_grad - gs)**2).sum(axis=1)) 250 | alignment = gs.dot(full_grad[:, np.newaxis]) / np.sqrt(np.sum(full_grad**2)) 251 | alignment /= np.sqrt((gs**2).sum(axis=1, keepdims=True)) 252 | print "Mean of norms of diff", np.mean(norms) 253 | print "Variance of norms of diff", np.var(norms) 254 | print "Mean of alignment", np.mean(alignment) 255 | print "Variance of alignment", np.var(alignment) 256 | if args.save_scores: 257 | np.savetxt( 258 | path.join(args.save_scores, score_metric+".txt"), 259 | scores 260 | ) 261 | 262 | 263 | if __name__ == "__main__": 264 | import sys 265 | main(sys.argv[1:]) 266 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | """Setup importance-sampling""" 8 | 9 | from itertools import dropwhile 10 | from os import path 11 | from setuptools import find_packages, setup 12 | 13 | 14 | def collect_docstring(lines): 15 | """Return document docstring if it exists""" 16 | lines = dropwhile(lambda x: not x.startswith('"""'), lines) 17 | doc = "" 18 | for line in lines: 19 | doc += line 20 | if doc.endswith('"""\n'): 21 | break 22 | 23 | return doc[3:-4].replace("\r", "").replace("\n", " ") 24 | 25 | 26 | def collect_metadata(): 27 | meta = {} 28 | with open(path.join("importance_sampling","__init__.py")) as f: 29 | lines = iter(f) 30 | meta["description"] = collect_docstring(lines) 31 | for line in lines: 32 | if line.startswith("__"): 33 | key, value = map(lambda x: x.strip(), line.split("=")) 34 | meta[key[2:-2]] = value[1:-1] 35 | 36 | return meta 37 | 38 | def setup_package(): 39 | with open("README.rst") as f: 40 | long_description = f.read() 41 | meta = collect_metadata() 42 | setup( 43 | name="keras-importance-sampling", 44 | version=meta["version"], 45 | description=meta["description"], 46 | long_description=long_description, 47 | maintainer=meta["maintainer"], 48 | maintainer_email=meta["email"], 49 | url=meta["url"], 50 | license=meta["license"], 51 | classifiers=[ 52 | "Intended Audience :: Science/Research", 53 | 
"Intended Audience :: Developers", 54 | "License :: OSI Approved :: MIT License", 55 | "Topic :: Scientific/Engineering", 56 | "Programming Language :: Python", 57 | "Programming Language :: Python :: 2", 58 | "Programming Language :: Python :: 2.7", 59 | "Programming Language :: Python :: 3", 60 | "Programming Language :: Python :: 3.4", 61 | "Programming Language :: Python :: 3.5", 62 | ], 63 | packages=find_packages(exclude=["docs", "tests", "scripts", "examples"]), 64 | install_requires=["keras>=2", "blinker", "numpy"] 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | setup_package() 70 | -------------------------------------------------------------------------------- /tests/importance_sampling: -------------------------------------------------------------------------------- 1 | ../importance_sampling/ -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from functools import partial 7 | import os 8 | import unittest 9 | 10 | import numpy as np 11 | 12 | from importance_sampling.datasets import CIFAR10, CIFARSanityCheck, MNIST, \ 13 | CanevetICML2016, OntheflyAugmentedImages, PennTreeBank, GeneratorDataset, \ 14 | InMemoryImageDataset, ZCAWhitening 15 | from importance_sampling.utils.functional import compose 16 | 17 | 18 | class TestDatasets(unittest.TestCase): 19 | def _test_dset(self, dset, n_train, n_test, shape, output_size): 20 | dset = dset() 21 | self.assertEqual(len(dset.train_data), n_train) 22 | self.assertEqual(len(dset.test_data), n_test) 23 | self.assertEqual(dset.shape, shape) 24 | self.assertEqual(dset.train_data[[0]][0][0].shape, shape) 25 | self.assertEqual(dset.output_size, output_size) 26 | 27 | @unittest.skipUnless(os.getenv("TEST_DATASETS", False), 28 | "Datasets need not be tested all the time") 29 | def test_datasets(self): 30 | datasets = [ 31 | (CIFAR10, 50000, 10000, (32, 32, 3), 10), 32 | (CIFARSanityCheck, 40000, 2000, (32, 32, 3), 2), 33 | (MNIST, 60000, 10000, (28, 28, 1), 10), 34 | ( 35 | partial(CanevetICML2016, N=256), int(256**2 - 256**2 * 0.33) + 1, 36 | int(256**2 * 0.33), (2,), 2 37 | ), 38 | ( 39 | compose( 40 | partial(OntheflyAugmentedImages, augmentation_params=dict( 41 | featurewise_center=False, 42 | samplewise_center=False, 43 | featurewise_std_normalization=False, 44 | samplewise_std_normalization=False, 45 | zca_whitening=False, 46 | rotation_range=0, 47 | width_shift_range=0.1, 48 | height_shift_range=0.1, 49 | horizontal_flip=True, 50 | vertical_flip=False 51 | )), CIFAR10), 500000, 10000, (32, 32, 3), 10 52 | ), 53 | (partial(PennTreeBank, 20), 887521, 70390, (20,), 10000) 54 | ] 55 | 56 | for args in datasets: 57 | self._test_dset(*args) 58 | 59 | def test_zca_whitening(self): 60 | dset = ZCAWhitening(InMemoryImageDataset( 61 | np.random.rand(1000, 32, 32, 3), 62 | (np.random.rand(1000, 1)*10).astype(np.int32), 63 | np.random.rand(1000, 32, 32, 3), 64 | (np.random.rand(1000, 1)*10).astype(np.int32) 65 | )) 66 | 67 | self.assertEqual(len(dset.train_data), 1000) 68 | self.assertEqual(len(dset.test_data), 1000) 69 | 70 | idxs = np.random.choice(len(dset.train_data), 100) 71 | x_r, y_r = dset.train_data[idxs] 72 | self.assertEqual(x_r.shape, (100, 32, 32, 3)) 73 | self.assertEqual(y_r.shape, (100, 10)) 74 | for i in range(10): 75 | x, y = dset.train_data[idxs] 76 | 
self.assertTrue(np.all(x_r == x)) 77 | self.assertTrue(np.all(y_r == y)) 78 | 79 | def test_image_augmentation(self): 80 | orig_dset = InMemoryImageDataset( 81 | np.random.rand(1000, 32, 32, 3), 82 | (np.random.rand(1000, 1)*10).astype(np.int32), 83 | np.random.rand(1000, 32, 32, 3), 84 | (np.random.rand(1000, 1)*10).astype(np.int32) 85 | ) 86 | 87 | dset = OntheflyAugmentedImages( 88 | orig_dset, 89 | dict( 90 | featurewise_center=False, 91 | samplewise_center=False, 92 | featurewise_std_normalization=False, 93 | samplewise_std_normalization=False, 94 | zca_whitening=False, 95 | rotation_range=0, 96 | width_shift_range=0.1, 97 | height_shift_range=0.1, 98 | horizontal_flip=True, 99 | vertical_flip=False 100 | ) 101 | ) 102 | 103 | idxs = np.random.choice(len(dset.train_data), 100) 104 | x_r, y_r = dset.train_data[idxs] 105 | for i in range(10): 106 | x, y = dset.train_data[idxs] 107 | self.assertTrue(np.all(x_r == x)) 108 | self.assertTrue(np.all(y_r == y)) 109 | 110 | dset = OntheflyAugmentedImages( 111 | orig_dset, 112 | dict( 113 | featurewise_center=True, 114 | samplewise_center=False, 115 | featurewise_std_normalization=False, 116 | samplewise_std_normalization=False, 117 | zca_whitening=True, 118 | rotation_range=0, 119 | width_shift_range=0.1, 120 | height_shift_range=0.1, 121 | horizontal_flip=True, 122 | vertical_flip=False 123 | ) 124 | ) 125 | 126 | idxs = np.random.choice(len(dset.train_data), 100) 127 | x_r, y_r = dset.train_data[idxs] 128 | for i in range(10): 129 | x, y = dset.train_data[idxs] 130 | self.assertTrue(np.all(x_r == x)) 131 | self.assertTrue(np.all(y_r == y)) 132 | 133 | def test_generator(self): 134 | def data(): 135 | while True: 136 | yield np.random.rand(32, 10), np.random.rand(32, 1) 137 | 138 | # Test with training data only 139 | dset = GeneratorDataset(data()) 140 | self.assertEqual(dset.shape, (10,)) 141 | self.assertEqual(dset.output_size, 1) 142 | with self.assertRaises(RuntimeError): 143 | len(dset.train_data) 144 | x, y = dset.train_data[:10] 145 | self.assertEqual(10, len(x)) 146 | self.assertEqual(10, len(y)) 147 | x, y = dset.train_data[:100] 148 | self.assertEqual(100, len(x)) 149 | self.assertEqual(100, len(y)) 150 | x, y = dset.train_data[[1, 2, 17]] 151 | self.assertEqual(3, len(x)) 152 | self.assertEqual(3, len(y)) 153 | x, y = dset.train_data[99] 154 | self.assertEqual(1, len(x)) 155 | self.assertEqual(1, len(y)) 156 | with self.assertRaises(RuntimeError): 157 | len(dset.test_data) 158 | with self.assertRaises(RuntimeError): 159 | dset.test_data[0] 160 | 161 | # Test with in memory test data 162 | test_data = (np.random.rand(120, 10), np.random.rand(120, 1)) 163 | dset = GeneratorDataset(data(), test_data) 164 | self.assertEqual(dset.shape, (10,)) 165 | self.assertEqual(dset.output_size, 1) 166 | with self.assertRaises(RuntimeError): 167 | len(dset.train_data) 168 | x, y = dset.train_data[:10] 169 | self.assertEqual(10, len(x)) 170 | self.assertEqual(10, len(y)) 171 | x, y = dset.train_data[:100] 172 | self.assertEqual(100, len(x)) 173 | self.assertEqual(100, len(y)) 174 | x, y = dset.train_data[[1, 2, 17]] 175 | self.assertEqual(3, len(x)) 176 | self.assertEqual(3, len(y)) 177 | x, y = dset.train_data[99] 178 | self.assertEqual(1, len(x)) 179 | self.assertEqual(1, len(y)) 180 | self.assertEqual(120, len(dset.test_data)) 181 | self.assertTrue(np.all(test_data[0] == dset.test_data[:][0])) 182 | idxs = [10, 20, 33] 183 | self.assertTrue(np.all(test_data[1][idxs] == dset.test_data[idxs][1])) 184 | 185 | # Test with generator test data 186 | 
dset = GeneratorDataset(data(), data(), 120) 187 | self.assertEqual(120, len(dset.test_data)) 188 | x, y = dset.test_data[:10] 189 | self.assertEqual(10, len(x)) 190 | self.assertEqual(10, len(y)) 191 | x, y = dset.test_data[:100] 192 | self.assertEqual(100, len(x)) 193 | self.assertEqual(100, len(y)) 194 | x, y = dset.test_data[[1, 2, 17]] 195 | self.assertEqual(3, len(x)) 196 | self.assertEqual(3, len(y)) 197 | x, y = dset.test_data[99] 198 | self.assertEqual(1, len(x)) 199 | self.assertEqual(1, len(y)) 200 | 201 | 202 | if __name__ == "__main__": 203 | unittest.main() 204 | -------------------------------------------------------------------------------- /tests/test_finetuning.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | from keras.applications import Xception 9 | from keras.layers import Dense, Activation 10 | from keras.models import Model 11 | import numpy as np 12 | 13 | from importance_sampling.training import ImportanceTraining 14 | 15 | 16 | class TestFinetuning(unittest.TestCase): 17 | def _generate_images(self, batch_size=16): 18 | while True: 19 | xi = np.random.rand(batch_size, 71, 71, 3) 20 | yi = np.zeros((batch_size, 10)) 21 | yi[np.arange(batch_size), np.random.choice(10, batch_size)] = 1.0 22 | yield xi, yi 23 | 24 | def test_simple_cifar(self): 25 | base = Xception( 26 | input_shape=(71, 71, 3), 27 | include_top=False, 28 | pooling="avg" 29 | ) 30 | y = Dense(10)(base.output) 31 | y = Activation("softmax")(y) 32 | model = Model(base.input, y) 33 | model.compile("sgd", "categorical_crossentropy", metrics=["accuracy"]) 34 | 35 | history = ImportanceTraining(model, presample=16).fit_generator( 36 | self._generate_images(batch_size=8), 37 | steps_per_epoch=10, 38 | epochs=1, 39 | batch_size=8 40 | ) 41 | self.assertTrue("loss" in history.history) 42 | self.assertTrue("accuracy" in history.history) 43 | self.assertEqual(len(history.history["loss"]), 1) 44 | self.assertEqual(len(history.history["accuracy"]), 1) 45 | 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /tests/test_keras_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from tempfile import NamedTemporaryFile 7 | import unittest 8 | 9 | from keras.layers import Dense, Input 10 | from keras.models import Model, Sequential 11 | import numpy as np 12 | 13 | from importance_sampling.utils import keras_utils 14 | 15 | 16 | class TestKerasUtils(unittest.TestCase): 17 | def test_simple_save_load(self): 18 | f = NamedTemporaryFile() 19 | m = Sequential([ 20 | Dense(10, input_dim=2), 21 | Dense(2) 22 | ]) 23 | m.save(f.name) 24 | 25 | w = m.get_weights() 26 | updates = keras_utils.load_weights_by_name(f.name, m) 27 | 28 | self.assertEqual(4, len(updates)) 29 | for w1, w2 in zip(w, m.get_weights()): 30 | self.assertTrue(np.allclose(w1, w2)) 31 | 32 | def test_nested_save_load(self): 33 | f = NamedTemporaryFile() 34 | m = Sequential([ 35 | Dense(10, input_dim=2), 36 | Dense(2) 37 | ]) 38 | x1, x2 = Input(shape=(2,)), Input(shape=(2,)) 39 | y1, y2 = m(x1), m(x2) 40 | model = Model([x1, x2], [y1, y2]) 41 | model.save(f.name) 42 | 43 | w = m.get_weights() 44 | updates = 
keras_utils.load_weights_by_name(f.name, m) 45 | 46 | self.assertEqual(4, len(updates)) 47 | for w1, w2 in zip(w, m.get_weights()): 48 | self.assertTrue(np.allclose(w1, w2)) 49 | 50 | 51 | if __name__ == "__main__": 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /tests/test_model_wrappers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from contextlib import contextmanager 7 | import unittest 8 | 9 | from blinker import signal 10 | from keras import backend as K 11 | from keras.layers import Activation, Dense, Input, Lambda, dot 12 | from keras.models import Sequential, Model 13 | import numpy as np 14 | 15 | from importance_sampling.model_wrappers import OracleWrapper 16 | from importance_sampling.reweighting import BiasedReweightingPolicy 17 | 18 | 19 | def log_signal(idx, n, success, fun): 20 | def inner(x): 21 | success[idx] = fun(idx, n) 22 | return inner 23 | 24 | 25 | @contextmanager 26 | def assert_signals(test, signal_name, fun=lambda x, n: True): 27 | if not isinstance(signal_name, (tuple, list)): 28 | signal_name = [signal_name] 29 | success = [False]*len(signal_name) 30 | loggers = [ 31 | log_signal(i, n, success, fun) 32 | for i, n in 33 | enumerate(signal_name) 34 | ] 35 | for i, n in enumerate(signal_name): 36 | signal(n).connect(loggers[i]) 37 | yield 38 | for n, s in zip(signal_name, success): 39 | test.assertTrue(s, msg=n + " signal was not received") 40 | 41 | 42 | class TestModelWrappers(unittest.TestCase): 43 | def _get_model(self): 44 | model = Sequential([ 45 | Dense(10, activation="relu", input_dim=2), 46 | Dense(10, activation="relu"), 47 | Dense(2), 48 | Activation("softmax") 49 | ]) 50 | model.compile(loss="categorical_crossentropy", optimizer="adam") 51 | 52 | wrapped = OracleWrapper(model, BiasedReweightingPolicy(), score="loss") 53 | 54 | x = np.random.rand(16, 2) 55 | y = np.zeros((16, 2)) 56 | y[range(16), np.random.choice(2, 16)] = 1.0 57 | 58 | return model, wrapped, x, y 59 | 60 | def _get_model2(self): 61 | x1 = Input(shape=(10,)) 62 | x2 = Input(shape=(10,)) 63 | y = dot([ 64 | Dense(10)(x1), 65 | Dense(10)(x2) 66 | ], axes=1) 67 | model = Model(inputs=[x1, x2], outputs=y) 68 | model.compile(loss="mse", optimizer="adam") 69 | 70 | wrapped = OracleWrapper(model, BiasedReweightingPolicy(), score="loss") 71 | 72 | x = [np.random.rand(16, 10), np.random.rand(16, 10)] 73 | y = np.random.rand(16, 1) 74 | 75 | return model, wrapped, x, y 76 | 77 | def test_model_methods(self): 78 | model_factories = [self._get_model, self._get_model2] 79 | for get_model in model_factories: 80 | model, wrapped, x, y = get_model() 81 | 82 | wrapped.set_lr(0.001) 83 | 84 | scores = wrapped.score(x, y) 85 | self.assertTrue(np.all(scores == wrapped.score(x, y))) 86 | 87 | wl, _, sc = wrapped.train_batch(x, y, np.ones((16, 1)) / y.shape[1]) 88 | self.assertTrue(np.all(wl*y.shape[1] == sc)) 89 | 90 | l = wrapped.evaluate(x, y) 91 | self.assertEqual(l.size, 1) 92 | l = wrapped.evaluate_batch(x, y) 93 | self.assertEqual(l.size, 16) 94 | 95 | def test_model_learning_phase(self): 96 | def one_zero(x): 97 | return K.in_train_phase( 98 | K.zeros_like(x), 99 | K.ones_like(x) 100 | ) 101 | 102 | model = Sequential([ 103 | Lambda(one_zero, input_shape=(1,)) 104 | ]) 105 | model.compile("sgd", "mse") 106 | 107 | x = np.random.rand(10, 1) 108 | l1 = 
model.test_on_batch(x, np.ones_like(x)) 109 | l2 = model.train_on_batch(x, np.zeros_like(x)) 110 | self.assertEqual(l1, 0) 111 | self.assertEqual(l2, 0) 112 | 113 | model = OracleWrapper(model, BiasedReweightingPolicy(), score="loss") 114 | l1 = model.evaluate_batch(x, np.ones_like(x))[0].sum() 115 | l2 = model.train_batch(x, np.zeros_like(x), np.ones_like(x))[0].sum() 116 | self.assertEqual(l1, 0) 117 | self.assertEqual(l2, 0) 118 | 119 | def test_signals(self): 120 | model, wrapped, x, y = self._get_model() 121 | 122 | with assert_signals(self, ["is.evaluation", "is.evaluate_batch"]): 123 | wrapped.evaluate(x, y) 124 | with assert_signals(self, "is.score"): 125 | wrapped.score(x, y) 126 | with assert_signals(self, "is.training"): 127 | wrapped.train_batch(x, y, np.ones((len(x), 1))) 128 | 129 | @unittest.skip("Not done yet") 130 | def test_metrics(self): 131 | pass 132 | 133 | @unittest.skip("Not done yet") 134 | def test_losses(self): 135 | pass 136 | 137 | 138 | if __name__ == "__main__": 139 | unittest.main() 140 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | from keras.models import Model 9 | import numpy as np 10 | 11 | from importance_sampling.models import wide_resnet 12 | from importance_sampling.pretrained import ResNet50 13 | 14 | 15 | class TestModels(unittest.TestCase): 16 | def _test_model(self, model, input_shape, output_shape, 17 | loss="categorical_crossentropy"): 18 | B = 10 19 | X = np.random.rand(B, *input_shape).astype(np.float32) 20 | y = np.random.rand(B, *output_shape).astype(np.float32) 21 | 22 | # It is indeed a model 23 | self.assertTrue(isinstance(model, Model)) 24 | 25 | # It can predict 26 | y_hat = model.predict_on_batch(X) 27 | self.assertEqual(y_hat.shape[1:], output_shape) 28 | 29 | # It can evaluate 30 | model.compile("sgd", loss) 31 | loss = model.train_on_batch(X, y) 32 | 33 | def test_wide_resnet(self): 34 | self._test_model( 35 | wide_resnet(28, 2)((32, 32, 3), 10), 36 | (32, 32, 3), 37 | (10,) 38 | ) 39 | self._test_model( 40 | wide_resnet(18, 5)((32, 32, 3), 10), 41 | (32, 32, 3), 42 | (10,) 43 | ) 44 | 45 | def test_pretrained_resnet50(self): 46 | self._test_model( 47 | ResNet50(), 48 | (224, 224, 3), 49 | (1000,) 50 | ) 51 | self._test_model( 52 | ResNet50(input_shape=(200, 200, 3), output_size=10), 53 | (200, 200, 3), 54 | (10,) 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /tests/test_reweighting.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | from keras.layers import Input 9 | from keras.models import Model 10 | import numpy as np 11 | 12 | from importance_sampling.reweighting import BiasedReweightingPolicy, \ 13 | NoReweightingPolicy 14 | 15 | 16 | class TestReweighting(unittest.TestCase): 17 | def _test_external_reweighting_layer(self, rw): 18 | s1, s2 = Input(shape=(1,)), Input(shape=(1,)) 19 | w = rw.weight_layer()([s1, s2]) 20 | m = Model(inputs=[s1, s2], outputs=[w]) 21 | m.compile("sgd", "mse") 22 | 23 | r = np.random.rand(100, 1).astype(np.float32) 
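# A minimal sanity check on the external reweighting layer, assuming
# rw.weight_layer() simply forwards its second input: feeding zeros on the
# first input and an arbitrary vector r on the second should reproduce r at
# the output, which is what the comparison against r_hat below asserts.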
24 | r_hat = m.predict([np.zeros((100, 1)), r]) 25 | self.assertTrue(np.all(r == r_hat)) 26 | 27 | def test_biased_reweighting(self): 28 | rw = BiasedReweightingPolicy(k=1.) 29 | s = np.random.rand(100) 30 | i = np.arange(100) 31 | w = rw.sample_weights(i, s).ravel() 32 | 33 | self.assertEqual(rw.weight_size, 1) 34 | self.assertAlmostEqual(w.dot(s), s.sum()) 35 | self._test_external_reweighting_layer(rw) 36 | 37 | # Make sure that it is just a normalized version of the same weights 38 | # raised to k 39 | rw = BiasedReweightingPolicy(k=0.5) 40 | w_hat = rw.sample_weights(i, s).ravel() 41 | scales = w**0.5 / w_hat 42 | scales_hat = np.ones(100)*scales[0] 43 | self.assertTrue(np.allclose(scales, scales_hat)) 44 | 45 | def test_no_reweighting(self): 46 | rw = NoReweightingPolicy() 47 | self.assertTrue( 48 | np.all( 49 | rw.sample_weights(np.arange(100), np.random.rand(100)) == 1.0 50 | ) 51 | ) 52 | self._test_external_reweighting_layer(rw) 53 | 54 | def test_adjusted_biased_reweighting(self): 55 | self.skipTest("Not implemented yet") 56 | 57 | def test_correcting_reweighting_policy(self): 58 | self.skipTest("Not implemented yet") 59 | 60 | 61 | if __name__ == "__main__": 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /tests/test_samplers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | import numpy as np 9 | 10 | from importance_sampling.datasets import InMemoryDataset 11 | from importance_sampling.reweighting import UNWEIGHTED, UNBIASED 12 | from importance_sampling.samplers import AdaptiveAdditiveSmoothingSampler, \ 13 | AdditiveSmoothingSampler, ModelSampler, PowerSmoothingSampler, \ 14 | UniformSampler, ConstantVarianceSampler 15 | 16 | 17 | class MockModel(object): 18 | def __init__(self, positive_score = 1.0): 19 | self._score = positive_score 20 | 21 | def score(self, x, y, batch_size=1): 22 | s = np.ones((len(x),)) 23 | s[y[:, 1] != 0] = self._score 24 | 25 | return s 26 | 27 | 28 | class TestSamplers(unittest.TestCase): 29 | def __init__(self, *args, **kwargs): 30 | # Create a toy 2D circle dataset 31 | X = np.random.rand(1000, 2) 32 | y = (((X - np.array([[0.5, 0.5]]))**2).sum(axis=1) < 0.1).astype(int) 33 | self.dataset = InMemoryDataset( 34 | X[:600], 35 | y[:600], 36 | X[600:], 37 | y[600:] 38 | ) 39 | 40 | # The probability of selecting a point in the circle with 41 | # uniform sampling 42 | self.prior = (y[:600] == 1).sum() / 600. 
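# A rough worked example of the numbers involved, assuming the circle of
# radius sqrt(0.1) centred at (0.5, 0.5) lies inside the unit square: the
# positive prior should land near the circle's area, pi * 0.1 ~= 0.31.
# With a MockModel that scores positives a=4 and negatives b=1 (as in
# test_model_sampler below), _get_prob gives
# p*a / ((1-p)*b + p*a) ~= 0.31*4 / (0.69*1 + 0.31*4) ~= 1.24 / 1.93 ~= 0.64,
# so roughly 6400 positives would be expected out of N = 10000 sampled points.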
43 | 44 | super(TestSamplers, self).__init__(*args, **kwargs) 45 | 46 | def _get_prob(self, a, b=1.0): 47 | """Compute the probability of sampling a positive class given the 48 | relative importances a of positive and b of negative""" 49 | p = self.prior 50 | return (p * a) / ((1-p)*b + p*a) 51 | 52 | def _test_sampler(self, sampler, N, expected_ones, error=0.02): 53 | idxs, xy, w = sampler.sample(100) 54 | 55 | self.assertEqual(len(idxs), 100) 56 | self.assertEqual(len(idxs), len(xy[0])) 57 | self.assertEqual(len(idxs), len(xy[1])) 58 | self.assertTrue(np.all(w == 1.)) 59 | 60 | ones = 0 61 | for i in range(N//100): 62 | _, (x, y), _ = sampler.sample(100) 63 | ones += y[:, 1].sum() 64 | self.assertTrue( 65 | expected_ones - N*error < ones < expected_ones + N*error, 66 | "Got %d and expected %d" % (ones, expected_ones) 67 | ) 68 | 69 | def test_uniform_sampler(self): 70 | N = 10000 71 | expected_ones = self.prior * N 72 | 73 | self._test_sampler( 74 | UniformSampler(self.dataset, UNWEIGHTED), 75 | N, 76 | expected_ones 77 | ) 78 | 79 | def test_model_sampler(self): 80 | importance = 4.0 81 | N = 10000 82 | expected_ones = N * self._get_prob(importance) 83 | 84 | self._test_sampler( 85 | ModelSampler(self.dataset, UNWEIGHTED, MockModel(importance)), 86 | N, 87 | expected_ones 88 | ) 89 | 90 | def test_additive_smoothing_sampler(self): 91 | importance = 4.0 92 | c = 2.0 93 | N = 10000 94 | expected_ones = N * self._get_prob(importance + c, 1.0 + c) 95 | 96 | self._test_sampler( 97 | AdditiveSmoothingSampler( 98 | ModelSampler(self.dataset, UNWEIGHTED, MockModel(importance)), 99 | c=c 100 | ), 101 | N, 102 | expected_ones 103 | ) 104 | 105 | def test_adaptive_additive_smoothing_sampler(self): 106 | importance = 4.0 107 | c = (self.prior * 4.0 + (1.0 - self.prior) * 1.0) / 2. 
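# c mirrors the smoothing constant the adaptive sampler is assumed to use:
# half the mean score under the prior, here (0.31*4 + 0.69*1)/2 ~= 0.97.
# The expected positive rate then follows the same formula as the plain
# additive smoothing test, i.e. _get_prob(importance + c, 1.0 + c).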
108 | N = 10000 109 | expected_ones = N * self._get_prob(importance + c, 1.0 + c) 110 | 111 | self._test_sampler( 112 | AdaptiveAdditiveSmoothingSampler( 113 | ModelSampler(self.dataset, UNWEIGHTED, MockModel(importance)) 114 | ), 115 | N, 116 | expected_ones 117 | ) 118 | 119 | def test_power_smoothing_sampler(self): 120 | importance = 4.0 121 | N = 10000 122 | expected_ones = N * self._get_prob(importance**0.5) 123 | 124 | self._test_sampler( 125 | PowerSmoothingSampler( 126 | ModelSampler(self.dataset, UNWEIGHTED, MockModel(importance)) 127 | ), 128 | N, 129 | expected_ones 130 | ) 131 | 132 | def test_constant_variance_sampler(self): 133 | importance = 100.0 134 | y = np.zeros(1024) 135 | y[np.random.choice(1024, 10)] = 1.0 136 | dataset = InMemoryDataset( 137 | np.random.rand(1024, 2), 138 | y, 139 | np.random.rand(100, 2), 140 | y[:100] 141 | ) 142 | model = MockModel(importance) 143 | sampler = ConstantVarianceSampler(dataset, UNBIASED, model) 144 | 145 | idxs1, xy, w = sampler.sample(100) 146 | sampler.update(idxs1, model.score(xy[0], xy[1])) 147 | for i in range(30): 148 | idxs, xy, _ = sampler.sample(100) 149 | sampler.update(idxs, model.score(xy[0], xy[1])) 150 | idxs2, xy, w = sampler.sample(100) 151 | sampler.update(idxs2, model.score(xy[0], xy[1])) 152 | 153 | self.assertEqual(len(idxs1), 100) 154 | self.assertLess(len(idxs2), 100) 155 | 156 | @unittest.skip("Not implemented yet") 157 | def test_lstm_sampler(self): 158 | pass 159 | 160 | @unittest.skip("Not implemented yet") 161 | def test_per_class_gaussian_sampler(self): 162 | pass 163 | 164 | 165 | if __name__ == "__main__": 166 | unittest.main() 167 | -------------------------------------------------------------------------------- /tests/test_save_load.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from os import path 7 | import shutil 8 | import tempfile 9 | import unittest 10 | 11 | from keras.callbacks import ModelCheckpoint 12 | from keras.layers import Dense 13 | from keras.models import Sequential 14 | import numpy as np 15 | 16 | from importance_sampling.training import ImportanceTraining 17 | 18 | 19 | class TestSaveLoad(unittest.TestCase): 20 | @classmethod 21 | def setUpClass(cls): 22 | cls.tmpdir = tempfile.mkdtemp() 23 | 24 | @classmethod 25 | def tearDownClass(cls): 26 | shutil.rmtree(cls.tmpdir) 27 | 28 | def test_checkpoint(self): 29 | m = Sequential([ 30 | Dense(10, activation="relu", input_shape=(2,)), 31 | Dense(2) 32 | ]) 33 | m.compile("sgd", "mse") 34 | x = np.random.rand(32, 2) 35 | y = np.random.rand(32, 2) 36 | 37 | ImportanceTraining(m).fit( 38 | x, y, 39 | epochs=1, 40 | callbacks=[ModelCheckpoint( 41 | path.join(self.tmpdir, "model.h5") 42 | )] 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /tests/test_seq2seq.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | from keras.layers import Activation, Embedding, LSTM 9 | from keras.models import Sequential 10 | import numpy as np 11 | 12 | from importance_sampling.training import ImportanceTraining 13 | 14 | 15 | class TestSeq2Seq(unittest.TestCase): 16 | def 
test_simple_seq2seq(self): 17 | model = Sequential([ 18 | Embedding(100, 32, mask_zero=True, input_length=10), 19 | LSTM(32, return_sequences=True), 20 | LSTM(10, return_sequences=True), 21 | Activation("softmax") 22 | ]) 23 | model.compile("adam", "categorical_crossentropy") 24 | 25 | x = (np.random.rand(10, 10)*100).astype(np.int32) 26 | y = np.random.rand(10, 10, 10) 27 | y /= y.sum(axis=-1, keepdims=True) 28 | ImportanceTraining(model).fit(x, y, batch_size=10) 29 | 30 | 31 | if __name__ == "__main__": 32 | unittest.main() 33 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | from keras import backend as K 9 | from keras.callbacks import LambdaCallback 10 | from keras.layers import Dense, Input, dot 11 | from keras.models import Model, Sequential 12 | import numpy as np 13 | 14 | from importance_sampling.training import ImportanceTraining, \ 15 | BiasedImportanceTraining, ApproximateImportanceTraining, \ 16 | ConstantVarianceImportanceTraining, ConstantTimeImportanceTraining 17 | from importance_sampling.samplers import BaseSampler 18 | 19 | 20 | class TestTraining(unittest.TestCase): 21 | TRAININGS = [ 22 | ImportanceTraining, BiasedImportanceTraining, 23 | ApproximateImportanceTraining, ConstantVarianceImportanceTraining, 24 | ConstantTimeImportanceTraining 25 | ] 26 | 27 | def __init__(self, *args, **kwargs): 28 | self.model = Sequential([ 29 | Dense(10, activation="relu", input_shape=(2,)), 30 | Dense(10, activation="relu"), 31 | Dense(2) 32 | ]) 33 | self.model.compile("sgd", "mse", metrics=["mae"]) 34 | 35 | x1 = Input(shape=(10,)) 36 | x2 = Input(shape=(10,)) 37 | y = dot([ 38 | Dense(10)(x1), 39 | Dense(10)(x2) 40 | ], axes=1) 41 | self.model2 = Model(inputs=[x1, x2], outputs=y) 42 | self.model2.compile(loss="mse", optimizer="adam") 43 | 44 | super(TestTraining, self).__init__(*args, **kwargs) 45 | 46 | def test_simple_training(self): 47 | for Training in self.TRAININGS: 48 | model = Training(self.model) 49 | x = np.random.rand(128, 2) 50 | y = np.random.rand(128, 2) 51 | 52 | history = model.fit(x, y, epochs=5) 53 | self.assertTrue("loss" in history.history) 54 | self.assertEqual(len(history.history["loss"]), 5) 55 | self.assertFalse(any(np.isnan(history.history["loss"]))) 56 | 57 | def test_generator_training(self): 58 | def gen(): 59 | while True: 60 | yield np.random.rand(16, 2), np.random.rand(16, 2) 61 | 62 | def gen2(): 63 | while True: 64 | yield (np.random.rand(16, 10), np.random.rand(16, 10)), \ 65 | np.random.rand(16, 1) 66 | x_val1, y_val1 = np.random.rand(32, 2), np.random.rand(32, 2) 67 | x_val2, y_val2 = (np.random.rand(32, 10), np.random.rand(32, 10)), \ 68 | np.random.rand(32, 1) 69 | 70 | for Training in [ImportanceTraining]: 71 | model = Training(self.model) 72 | history = model.fit_generator( 73 | gen(), validation_data=(x_val1, y_val1), 74 | steps_per_epoch=8, epochs=5 75 | ) 76 | self.assertTrue("loss" in history.history) 77 | self.assertEqual(len(history.history["loss"]), 5) 78 | 79 | model = Training(self.model2) 80 | history = model.fit_generator( 81 | gen2(), validation_data=(x_val2, y_val2), 82 | steps_per_epoch=8, epochs=5 83 | ) 84 | self.assertTrue("loss" in history.history) 85 | self.assertEqual(len(history.history["loss"]), 5) 86 | 
self.assertFalse(any(np.isnan(history.history["loss"]))) 87 | 88 | with self.assertRaises(NotImplementedError): 89 | ApproximateImportanceTraining(self.model).fit_generator( 90 | gen(), validation_data=(x_val1, y_val1), 91 | steps_per_epoch=8, epochs=5 92 | ) 93 | 94 | def test_multiple_inputs(self): 95 | x1 = np.random.rand(64, 10) 96 | x2 = np.random.rand(64, 10) 97 | y = np.random.rand(64, 1) 98 | 99 | for Training in [ImportanceTraining]: 100 | model = Training(self.model2) 101 | 102 | history = model.fit([x1, x2], y, epochs=5, batch_size=16) 103 | self.assertTrue("loss" in history.history) 104 | self.assertEqual(len(history.history["loss"]), 5) 105 | self.assertFalse(any(np.isnan(history.history["loss"]))) 106 | 107 | def test_regularizers(self): 108 | reg = lambda w: 10 109 | model = Sequential([ 110 | Dense(10, activation="relu", kernel_regularizer=reg, 111 | input_shape=(2,)), 112 | Dense(10, activation="relu", kernel_regularizer=reg), 113 | Dense(2) 114 | ]) 115 | model.compile("sgd", "mse") 116 | model = ImportanceTraining(model) 117 | history = model.fit(np.random.rand(64, 2), np.random.rand(64, 2)) 118 | 119 | self.assertGreater(history.history["loss"][0], 20.) 120 | 121 | def test_on_sample(self): 122 | calls = [0] 123 | def on_sample(sampler, idxs, w, scores): 124 | calls[0] += 1 125 | self.assertTrue(isinstance(sampler, BaseSampler)) 126 | self.assertEqual(len(idxs), len(w)) 127 | self.assertEqual(len(idxs), len(scores)) 128 | 129 | model = Sequential([ 130 | Dense(10, activation="relu", input_shape=(2,)), 131 | Dense(10, activation="relu"), 132 | Dense(2) 133 | ]) 134 | model.compile("sgd", "mse") 135 | 136 | for Training in self.TRAININGS: 137 | Training(model).fit( 138 | np.random.rand(64, 2), np.random.rand(64, 2), 139 | batch_size=16, 140 | epochs=4, 141 | on_sample=on_sample 142 | ) 143 | self.assertEqual(16, calls[0]) 144 | calls[0] = 0 145 | 146 | def test_metric_names(self): 147 | def metric1(y, y_hat): 148 | return K.mean((y - y_hat), axis=-1) 149 | 150 | model = Sequential([ 151 | Dense(10, activation="relu", input_shape=(2,)), 152 | Dense(10, activation="relu"), 153 | Dense(2) 154 | ]) 155 | model.compile("sgd", "mse", metrics=["mse", metric1]) 156 | for Training in self.TRAININGS: 157 | wm = Training(model) 158 | self.assertEqual( 159 | tuple(wm.metrics_names), 160 | ("loss", "mse", "metric1", "score") 161 | ) 162 | 163 | 164 | if __name__ == "__main__": 165 | unittest.main() 166 | --------------------------------------------------------------------------------