├── .gitignore ├── LICENSE ├── README.md └── fid.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 
91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | .DS_Store 127 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jussi Leinonen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# 22 |
# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# 1 | ## Frechet Inception Distance for Keras GANs
# 2 |
# 3 | This module contains an implementation of the Frechet Inception Distance (FID)
# 4 | metric for Keras-based generative adversarial network (GAN) generators.
# 5 |
# 6 | The FID is defined here: https://arxiv.org/abs/1706.08500
# 7 |
# 8 | ### Usage
# 9 |
# 10 | A basic example:
# 11 | ```python
# 12 | import fid
# 13 |
# 14 | generator = ... # Your code for creating the GAN generator
# 15 | noise = ... # Your generator inputs (usually random noise)
# 16 | real_images = ... # Your training dataset
# 17 |
# 18 | # change (0,1) to the range of values in your dataset
# 19 | fd = fid.FrechetInceptionDistance(generator, (0,1))
# 20 | gan_fid = fd(real_images, noise)
# 21 | ```
# 22 |
# 23 | If you already have the means and covariances:
# 24 | ```python
# 25 | gan_fid = fid.frechet_distance(mean1, cov1, mean2, cov2)
# 26 | ```
# 27 |
# 28 | ### More information
# 29 |
# 30 | See the docstrings for the `fid.FrechetInceptionDistance` class and
# 31 | the `fid.frechet_distance` function for more detailed documentation.
# 32 |
# --------------------------------------------------------------------------------
# /fid.py:
# --------------------------------------------------------------------------------
import warnings

from keras.applications.inception_v3 import InceptionV3
from keras import backend as K
import numpy as np


def update_mean_cov(mean, cov, N, batch):
    """Fold a batch of samples into a running mean and covariance.

    This is a batched Welford-style update, so earlier samples never need
    to be kept in memory.

    Arguments:
        mean: Current mean vector, shape (d,).
        cov: Current covariance matrix, shape (d, d), normalized by N
            (i.e. the population / maximum-likelihood covariance).
        N: Number of samples already accounted for in mean and cov.
        batch: New samples, shape (batch_N, d).

    Returns:
        The tuple (mean, cov, N) updated to include the batch.
    """
    batch_N = batch.shape[0]
    if batch_N == 0:
        # Nothing to add; also avoids 0/0 when no samples were seen yet.
        return (mean, cov, N)

    N += batch_N
    x_norm_old = batch - mean
    mean = mean + x_norm_old.sum(axis=0) / N
    x_norm_new = batch - mean
    # sum_i (x_i - mean_old)(x_i - mean_new)^T equals the batch scatter
    # matrix plus the between-group correction term, so this is the exact
    # combined covariance (normalized by the new N).
    cov = ((N - batch_N) / N) * cov + x_norm_old.T.dot(x_norm_new) / N

    return (mean, cov, N)


def frechet_distance(mean1, cov1, mean2, cov2):
    """Frechet distance between two multivariate Gaussians.

    Computes
        ||mean1 - mean2||^2
        + Tr(cov1) + Tr(cov2) - 2 Tr((cov1^(1/2) cov2 cov1^(1/2))^(1/2)),
    taking the trace of the matrix square root from the eigenvalues of the
    symmetric matrix cov1^(1/2) cov2 cov1^(1/2).

    Arguments:
        mean1, cov1, mean2, cov2: The means and covariances of the two
            multivariate Gaussians (NumPy arrays or array-likes).

    Returns:
        The Frechet distance between the two distributions.
    """

    def check_nonpositive_eigvals(l):
        # Round-off can make eigenvalues of a (near-)singular PSD matrix
        # slightly negative; warn and clip them to zero in place so the
        # square roots below stay real.
        nonpos = (l < 0)
        if nonpos.any():
            warnings.warn('Rank deficient covariance matrix, '
                'Frechet distance will not be accurate.', Warning)
            l[nonpos] = 0

    mean1 = np.asarray(mean1)
    mean2 = np.asarray(mean2)

    (l1, v1) = np.linalg.eigh(cov1)
    check_nonpositive_eigvals(l1)
    # Symmetric square root of cov1 from its eigendecomposition.
    cov1_sqrt = (v1 * np.sqrt(l1)).dot(v1.T)
    cov_prod = cov1_sqrt.dot(cov2).dot(cov1_sqrt)
    lp = np.linalg.eigvalsh(cov_prod)
    check_nonpositive_eigvals(lp)

    # l1.sum() is Tr(cov1) (after clipping); sqrt(lp).sum() is the trace of
    # the square root of cov_prod.
    trace = l1.sum() + np.trace(cov2) - 2 * np.sqrt(lp).sum()
    diff_mean = mean1 - mean2
    fd = diff_mean.dot(diff_mean) + trace

    return fd


class InputIterator(object):
    """Endless iterator yielding batches drawn from array inputs.

    Arguments:
        inputs: A NumPy array, or a list of NumPy arrays that share the
            same first (sample) dimension; for a list, every yielded batch
            is a list of arrays indexed with the same sample indices.
        batch_size: Number of samples per batch.
        shuffle: If True, each pass over the data uses a new random
            permutation of the samples; otherwise samples are taken in order.
            Batches may straddle two passes.
        seed: Seed for the random number generator used for shuffling.
    """

    def __init__(self, inputs, batch_size=64, shuffle=True, seed=None):
        self._inputs = inputs
        self._inputs_list = isinstance(inputs, list)
        self._N = self._inputs[0].shape[0] if self._inputs_list else \
            self._inputs.shape[0]
        self.batch_size = batch_size
        self._shuffle = shuffle
        self._prng = np.random.RandomState(seed=seed)
        # Queue of sample indices not yet handed out.
        self._next_indices = np.array([], dtype=np.uint)

    def __iter__(self):
        return self

    def __next__(self):
        if self._N == 0:
            # Without this guard the refill loop below would never end.
            raise ValueError('Cannot draw batches from inputs with no samples.')

        # Top up the index queue with whole passes over the data.
        while len(self._next_indices) < self.batch_size:
            next_ind = np.arange(self._N, dtype=np.uint)
            if self._shuffle:
                self._prng.shuffle(next_ind)
            self._next_indices = np.concatenate((
                self._next_indices, next_ind))

        ind = self._next_indices[:self.batch_size]
        self._next_indices = self._next_indices[self.batch_size:]

        if self._inputs_list:
            batch = [inp[ind, ...] for inp in self._inputs]
        else:
            batch = self._inputs[ind, ...]

        return batch


class FrechetInceptionDistance(object):
    """Frechet Inception Distance.

    Class for evaluating Keras-based GAN generators using the Frechet
    Inception Distance (Heusel et al. 2017,
    https://arxiv.org/abs/1706.08500).

    Arguments to constructor:
        generator: a Keras model trained as a GAN generator
        image_range: A tuple giving the range of values in the images output
            by the generator. This is used to rescale to the (-1,1) range
            expected by the Inception V3 network. Defaults to (-1,1), i.e.
            no rescaling.
        generator_postprocessing: A function, preserving the shape of the
            output, to be applied to all generator outputs for further
            postprocessing. If None (default), no postprocessing will be
            done.

    Attributes: The arguments above all have a corresponding attribute
        with the same name that can be safely changed after initialization.

    Arguments to call:
        real_images: A 4D NumPy array of images from the training dataset,
            or a Python generator outputting training batches. The number of
            channels must be either 3 or 1 (in the latter case, the single
            channel is distributed to each of the 3 channels expected by the
            Inception network). Integer-valued images (e.g. uint8) are
            converted to floating point before rescaling.
        generator_inputs: One of the following:
            1. A NumPy array with generator inputs, or
            2. A list of NumPy arrays (if the generator has multiple inputs)
            3. A Python generator outputting batches of generator inputs
               (either a single array or a list of arrays)
        batch_size: The size of the batches in which the data is processed.
            No effect on batching if Python generators are passed as
            real_images or generator_inputs.
        num_batches_real: Number of batches to use to evaluate the mean and
            the covariance of the real samples.
        num_batches_gen: Number of batches to use to evaluate the mean and
            the covariance of the generated samples. If None (default), set
            equal to num_batches_real.
        shuffle: If True (default), samples are randomly selected from the
            input arrays. No effect if real_images or generator_inputs is
            a Python generator.
        seed: A random seed for shuffle (to provide reproducible results)

    Returns (call):
        The Frechet Inception Distance between the real and generated data.
    """

    def __init__(self, generator, image_range=(-1,1),
        generator_postprocessing=None):

        # Built lazily on first call (see _setup_inception_network).
        self._inception_v3 = None
        self.generator = generator
        self.generator_postprocessing = generator_postprocessing
        self.image_range = image_range
        self._channels_axis = \
            -1 if K.image_data_format()=="channels_last" else -3

    def _setup_inception_network(self):
        """Create the Inception V3 feature extractor (global average pool)."""
        self._inception_v3 = InceptionV3(
            include_top=False, pooling='avg')
        self._pool_size = self._inception_v3.output_shape[-1]

    def _preprocess(self, images):
        """Rescale images to (-1,1) and expand 1-channel images to 3."""
        (low, high) = self.image_range
        if (low, high) != (-1, 1):
            images = np.asarray(images)
            if not np.issubdtype(images.dtype, np.floating):
                # Integer images (e.g. uint8) would wrap on subtraction and
                # cannot be divided in place; convert to float first.
                images = images.astype(np.float32)
            # Out-of-place arithmetic: never mutates the caller's array.
            images = (images - low) / ((high - low) / 2.0) - 1.0
        if images.shape[self._channels_axis] == 1:
            images = np.concatenate([images]*3, axis=self._channels_axis)
        return images

    def _stats(self, inputs, input_type="real", postprocessing=None,
        batch_size=64, num_batches=128, shuffle=True, seed=None):
        """Estimate mean and covariance of Inception pool features.

        inputs is either an iterator yielding batches, or an array / list of
        arrays that is wrapped in an InputIterator. For input_type
        "generated" each batch is first passed through the generator (and
        the optional postprocessing function).
        """

        mean = np.zeros(self._pool_size)
        cov = np.zeros((self._pool_size, self._pool_size))
        N = 0

        if not hasattr(inputs, '__next__'):
            # An array or a list of arrays was passed instead of a Python
            # generator; draw batches from it with an InputIterator. Deciding
            # up front (rather than catching TypeError from next()) keeps
            # genuine TypeErrors raised inside a user generator visible.
            inputs = InputIterator(inputs,
                batch_size=batch_size, shuffle=shuffle, seed=seed)

        for _ in range(num_batches):
            batch = next(inputs)

            if input_type=="generated":
                batch = self.generator.predict(batch)
                if postprocessing is not None:
                    batch = postprocessing(batch)
            batch = self._preprocess(batch)
            pool = self._inception_v3.predict(batch, batch_size=batch_size)

            (mean, cov, N) = update_mean_cov(mean, cov, N, pool)

        return (mean, cov)

    def __call__(self,
            real_images,
            generator_inputs,
            batch_size=64,
            num_batches_real=128,
            num_batches_gen=None,
            shuffle=True,
            seed=None
        ):

        if self._inception_v3 is None:
            self._setup_inception_network()

        (real_mean, real_cov) = self._stats(real_images,
            "real", batch_size=batch_size, num_batches=num_batches_real,
            shuffle=shuffle, seed=seed)
        if num_batches_gen is None:
            num_batches_gen = num_batches_real
        (gen_mean, gen_cov) = self._stats(generator_inputs,
            "generated", batch_size=batch_size, num_batches=num_batches_gen,
            postprocessing=self.generator_postprocessing,
            shuffle=shuffle, seed=seed)

        return frechet_distance(real_mean, real_cov, gen_mean, gen_cov)

# --------------------------------------------------------------------------------