├── .gitignore
├── LICENSE
├── README.md
├── pytorch_fid_wrapper
│   ├── __init__.py
│   ├── fid_score.py
│   ├── inception.py
│   └── params.py
└── setup.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# pfw
.vscode

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-fid-wrapper
A simple wrapper around [@mseitzer](https://github.com/mseitzer)'s great [**pytorch-fid**](https://github.com/mseitzer/pytorch-fid) work.

The goal is to compute the Fréchet Inception Distance between two sets of images *in-memory* using PyTorch.

## Installation

[![PyPI](https://img.shields.io/pypi/v/pytorch-fid-wrapper.svg)](https://pypi.org/project/pytorch-fid-wrapper/)

```
pip install pytorch-fid-wrapper
```

Requires (and will install) the same dependencies as `pytorch-fid`:
* Python >= 3.5
* Pillow
* Numpy
* Scipy
* Torch
* Torchvision

## Usage

```python
import pytorch_fid_wrapper as pfw

# ---------------------------
# ----- Initial Setup -----
# ---------------------------

# Optional: set pfw's configuration with your parameters once and for all
pfw.set_config(batch_size=BATCH_SIZE, dims=DIMS, device=DEVICE)

# Optional: compute real_m and real_s only once, they will not change during training
real_m, real_s = pfw.get_stats(real_images)

...

# -------------------------------------
# ----- Computing the FID Score -----
# -------------------------------------

val_fid = pfw.fid(fake_images, real_m=real_m, real_s=real_s) # (1)

# OR

val_fid = pfw.fid(fake_images, real_images=new_real_images) # (2)
```

All `_images` variables in the example above are `torch.Tensor` instances with shape `N x C x H x W`. They will be sent to the appropriate device depending on what you ask for (see [Config](#config)).

To compute the FID score between your fake images and some real dataset, you can **either** re-use pre-computed stats `real_m`, `real_s` at each validation stage `(1)`, **or** provide another dataset for which the stats will be computed (in addition to your fake images', which are computed in both scenarios) `(2)`. The score is computed in `pfw.fid_score.calculate_frechet_distance(...)`, following [`pytorch-fid`](https://github.com/mseitzer/pytorch-fid)'s implementation.

Please refer to [**pytorch-fid**](https://github.com/mseitzer/pytorch-fid) for any documentation on the InceptionV3 implementation or FID calculations.
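
For reference, here is a minimal end-to-end sketch. The random tensors below are placeholders standing in for your data: any `N x C x H x W` float tensors with values in `(0, 1)`, with enough samples to fill at least one full inception batch, behave the same way:

```python
import torch
import pytorch_fid_wrapper as pfw

# Placeholder data, for illustration only: 100 RGB images in (0, 1)
real_images = torch.rand(100, 3, 64, 64)
fake_images = torch.rand(100, 3, 64, 64)

pfw.set_config(batch_size=50, device="cpu")  # see Config below

real_m, real_s = pfw.get_stats(real_images)  # compute once, reuse at each validation
val_fid = pfw.fid(fake_images, real_m=real_m, real_s=real_s)
print(val_fid)
```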
## Config

`pfw.get_stats(...)` and `pfw.fid(...)` need to know which block of the InceptionV3 model to use (`dims`), on which device to run inference (`device`) and with what batch size (`batch_size`).

Default values are in `pfw.params`: `batch_size = 50`, `dims = 2048` and `device = "cpu"`. If you want to override those, you have two options:

1/ Override any of these parameters in the function calls. For instance:
```python
pfw.fid(fake_images, new_real_data, device="cuda:0")
```
2/ Override the params globally with `pfw.set_config` and set them for all future calls without passing parameters again. For instance:
```python
pfw.set_config(batch_size=100, dims=768, device="cuda:0")
...
pfw.fid(fake_images, new_real_data)
```

## Recognition

Remember to cite their work if using [`pytorch-fid-wrapper`](https://github.com/vict0rsch/pytorch-fid-wrapper) or [`pytorch-fid`](https://github.com/mseitzer/pytorch-fid):

```
@misc{Seitzer2020FID,
  author={Maximilian Seitzer},
  title={{pytorch-fid: FID Score for PyTorch}},
  month={August},
  year={2020},
  note={Version 0.1.1},
  howpublished={\url{https://github.com/mseitzer/pytorch-fid}},
}
```

## License

This implementation is licensed under the Apache License 2.0.

FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium"; see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500).

The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0.
See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR).

--------------------------------------------------------------------------------
/pytorch_fid_wrapper/__init__.py:
--------------------------------------------------------------------------------
__version__ = "0.0.4"
from importlib import import_module
from pathlib import Path
from pytorch_fid_wrapper import params
from pytorch_fid_wrapper.fid_score import fid, get_stats
from pytorch_fid_wrapper.inception import InceptionV3


def set_config(batch_size=None, dims=None, device=None):
    """
    Sets pfw's global configuration so parameters can be omitted from function
    calls when they don't change over the course of training.

    Any one of them can be set independently.

    Args:
        batch_size (int, optional): batch_size for inception inference.
            Defaults to None.
        dims (int, optional): which inception block to select.
            See InceptionV3.BLOCK_INDEX_BY_DIM. Defaults to None.
        device (str | torch.device, optional): PyTorch device, as a string or
            device instance. Defaults to None.
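
    Example (illustrative values; "cuda:0" assumes a GPU is available):
        >>> import pytorch_fid_wrapper as pfw
        >>> pfw.set_config(batch_size=100, dims=2048, device="cuda:0")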
23 | """ 24 | if batch_size is not None: 25 | assert isinstance(batch_size, int) 26 | assert batch_size > 0 27 | params.batch_size = batch_size 28 | if dims is not None: 29 | assert isinstance(dims, int) 30 | assert dims in InceptionV3.BLOCK_INDEX_BY_DIM 31 | params.dims = dims 32 | if device is not None: 33 | params.device = device 34 | 35 | 36 | __all__ = [ 37 | import_module(f".{f.stem}", __package__) 38 | for f in Path(__file__).parent.glob("*.py") 39 | if "__" not in f.stem 40 | ] 41 | 42 | del import_module, Path 43 | -------------------------------------------------------------------------------- /pytorch_fid_wrapper/fid_score.py: -------------------------------------------------------------------------------- 1 | """ 2 | # ---------------------------- 3 | # ----- pfw docstrings ----- 4 | # ---------------------------- 5 | 6 | Adapted from: 7 | https://github.com/mseitzer/pytorch-fid/blob/4d7695b39764ba1d54ab6639e0695e5c4e6f346a/pytorch_fid/fid_score.py 8 | 9 | Modifications are: 10 | * modify calculate_activation_statistics ot handle in-memory N x C x H x W tensors 11 | instead of file lists with a dataloader 12 | * add fid() and get_stats() 13 | 14 | # --------------------------------------------- 15 | # ----- pytorch-fid original docstrings ----- 16 | # --------------------------------------------- 17 | 18 | Calculates the Frechet Inception Distance (FID) to evaluate GANs 19 | The FID metric calculates the distance between two distributions of images. 20 | Typically, we have summary statistics (mean & covariance matrix) of one 21 | of these distributions, while the 2nd distribution is given by a GAN. 22 | When run as a stand-alone program, it compares the distribution of 23 | images that are stored as PNG/JPEG at a specified location with a 24 | distribution given by summary statistics (in pickle format). 25 | The FID is calculated by assuming that X_1 and X_2 are the activations of 26 | the pool_3 layer of the inception net for generated samples and real world 27 | samples respectively. 28 | See --help to see further details. 29 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 30 | of Tensorflow 31 | Copyright 2018 Institute of Bioinformatics, JKU Linz 32 | Licensed under the Apache License, Version 2.0 (the "License"); 33 | you may not use this file except in compliance with the License. 34 | You may obtain a copy of the License at 35 | http://www.apache.org/licenses/LICENSE-2.0 36 | Unless required by applicable law or agreed to in writing, software 37 | distributed under the License is distributed on an "AS IS" BASIS, 38 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 39 | See the License for the specific language governing permissions and 40 | limitations under the License. 41 | """ 42 | import numpy as np 43 | import torch 44 | from scipy import linalg 45 | from torch.nn.functional import adaptive_avg_pool2d 46 | 47 | 48 | from pytorch_fid_wrapper.inception import InceptionV3 49 | from pytorch_fid_wrapper import params as pfw_params 50 | 51 | 52 | def get_activations(images, model, batch_size=50, dims=2048, device="cpu"): 53 | """ 54 | Calculates the activations of the pool_3 layer for all images. 55 | 56 | Args: 57 | images ([type]): Tensor of images N x C x H x W 58 | model ([type]): Instance of inception model 59 | batch_size (int, optional): Batch size of images for the model to process at 60 | once. Make sure that the number of samples is a multiple of 61 | the batch size, otherwise some samples are ignored. 
            This behavior is retained to match the original FID score
            implementation. Defaults to 50.
        dims (int, optional): Dimensionality of features returned by Inception.
            Defaults to 2048.
        device (str | torch.device, optional): Device to run calculations.
            Defaults to "cpu".

    Returns:
        np.ndarray: A numpy array of dimension (num images, dims) that contains the
            activations of the given tensor when feeding inception with the query
            tensor.
    """

    model.eval()

    n_batches = len(images) // batch_size

    assert n_batches > 0, (
        "Not enough images to make at least 1 full batch. "
        + "Provide more images or decrease batch_size"
    )

    pred_arr = np.empty((len(images), dims))

    start_idx = 0

    for b in range(n_batches):

        batch = images[b * batch_size : (b + 1) * batch_size].to(device)

        if batch.nelement() == 0:
            continue

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        pred_arr[start_idx : start_idx + pred.shape[0]] = pred

        start_idx = start_idx + pred.shape[0]

    return pred_arr


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """
    Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
    d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Args:
        mu1 (np.ndarray): Numpy array containing the activations of a layer of the
            inception net (like returned by the function 'get_activations')
            for generated samples.
        sigma1 (np.ndarray): The covariance matrix over activations for generated
            samples.
        mu2 (np.ndarray): The sample mean over activations, precalculated on a
            representative data set.
        sigma2 (np.ndarray): The covariance matrix over activations, precalculated
            on a representative data set.
        eps (float, optional): Fallback in case of infinite covariance.
            Defaults to 1e-6.

    Returns:
        float: The Frechet Distance.
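
    Example (illustrative 2-d sanity check): with C_1 = C_2 = I the trace term
    vanishes and d^2 reduces to ||mu_1 - mu_2||^2:
        >>> mu, sigma = np.zeros(2), np.eye(2)
        >>> round(float(calculate_frechet_distance(mu, sigma, mu + 1.0, sigma)), 4)
        2.0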
134 | """ 135 | 136 | mu1 = np.atleast_1d(mu1) 137 | mu2 = np.atleast_1d(mu2) 138 | 139 | sigma1 = np.atleast_2d(sigma1) 140 | sigma2 = np.atleast_2d(sigma2) 141 | 142 | assert ( 143 | mu1.shape == mu2.shape 144 | ), "Training and test mean vectors have different lengths" 145 | assert ( 146 | sigma1.shape == sigma2.shape 147 | ), "Training and test covariances have different dimensions" 148 | 149 | diff = mu1 - mu2 150 | 151 | # Product might be almost singular 152 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 153 | if not np.isfinite(covmean).all(): 154 | msg = ( 155 | "fid calculation produces singular product; " 156 | "adding %s to diagonal of cov estimates" 157 | ) % eps 158 | print(msg) 159 | offset = np.eye(sigma1.shape[0]) * eps 160 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 161 | 162 | # Numerical error might give slight imaginary component 163 | if np.iscomplexobj(covmean): 164 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 165 | m = np.max(np.abs(covmean.imag)) 166 | raise ValueError("Imaginary component {}".format(m)) 167 | covmean = covmean.real 168 | 169 | tr_covmean = np.trace(covmean) 170 | 171 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 172 | 173 | 174 | def calculate_activation_statistics( 175 | images, model, batch_size=50, dims=2048, device="cpu" 176 | ): 177 | """ 178 | Calculation of the statistics used by the FID. 179 | 180 | Args: 181 | images (torch.Tensor): Tensor of images N x C x H x W 182 | model (torch.nn.Module): Instance of inception model 183 | batch_size (int, optional): The images tensor is split into batches with 184 | batch size batch_size. A reasonable batch size depends on the hardware. 185 | Defaults to 50. 186 | dims (int, optional): Dimensionality of features returned by Inception. 187 | Defaults to 2048. 188 | device (str | torch.device, optional): Device to run calculations. 189 | Defaults to "cpu". 190 | 191 | Returns: 192 | tuple(np.ndarray, np.ndarray): (mu, sigma) 193 | mu => The mean over samples of the activations of the pool_3 layer of 194 | the inception model. 195 | sigma => The covariance matrix of the activations of the pool_3 layer of 196 | the inception model. 197 | """ 198 | act = get_activations(images, model, batch_size, dims, device) 199 | mu = np.mean(act, axis=0) 200 | sigma = np.cov(act, rowvar=False) 201 | return mu, sigma 202 | 203 | 204 | def get_stats(images, model=None, batch_size=None, dims=None, device=None): 205 | """ 206 | Get the InceptionV3 activation statistics (mu, sigma) for a batch of `images`. 207 | 208 | If `model` (InceptionV3) is not provided, it will be instanciated according 209 | to `dims`. 210 | 211 | Other arguments are optional and will be inherited from `pfw.params` if not 212 | provided. Use `pfw.set_config` to change those params globally for future calls 213 | 214 | 215 | Args: 216 | images (torch.Tensor): The images to compute the statistics for. Must be 217 | N x C x H x W 218 | model (torch.nn.Module, optional): InceptionV3 model. Defaults to None. 219 | batch_size (int, optional): Inception inference batch size. 220 | Will use `pfw.params.batch_size` if not provided. Defaults to None. 221 | dims (int, optional): which inception block to select. See 222 | InceptionV3.BLOCK_INDEX_BY_DIM. Will use pfw.params.dims if not provided. 223 | Defaults to None. 224 | device (str | torch.device, optional): PyTorch device for inception inference. 225 | Will use pfw.params.device if not provided. Defaults to None. 

    Returns:
        tuple(np.ndarray, np.ndarray): (mu, sigma)
            mu => The mean over samples of the activations of the pool_3 layer of
                the inception model.
            sigma => The covariance matrix of the activations of the pool_3 layer of
                the inception model.
    """
    if batch_size is None:
        batch_size = pfw_params.batch_size
    if dims is None:
        dims = pfw_params.dims
    if device is None:
        device = pfw_params.device

    if model is None:
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        model = InceptionV3([block_idx]).to(device)
    else:
        assert isinstance(model, InceptionV3)
    return calculate_activation_statistics(images, model, batch_size, dims, device)


def fid(
    fake_images,
    real_images=None,
    real_m=None,
    real_s=None,
    batch_size=None,
    dims=None,
    device=None,
):
    """
    Computes the FID score of `fake_images` w.r.t. either precomputed stats on real
    data, or another batch of images (typically real ones).

    If `real_images` is `None`, you must provide both `real_m` **and** `real_s`,
    with dimensions matching the selected inception `dims`.

    If `real_images` is not `None`, it will prevail over `real_m` and `real_s`,
    which will be ignored.

    Other arguments are optional and will be inherited from `pfw.params` if not
    provided. Use `pfw.set_config` to change those params globally for future calls.

    Args:
        fake_images (torch.Tensor): N x C x H x W tensor.
        real_images (torch.Tensor, optional): N x C x H x W tensor. If provided,
            stats will be computed from it, ignoring real_s and real_m.
            Defaults to None.
        real_m (np.ndarray, optional): Mean of a previous activation stats
            computation, typically on real data. Defaults to None.
        real_s (np.ndarray, optional): Covariance of a previous activation stats
            computation, typically on real data. Defaults to None.
        batch_size (int, optional): Inception inference batch_size.
            Will use pfw.params.batch_size if not provided. Defaults to None.
        dims (int, optional): which inception block to select.
            See InceptionV3.BLOCK_INDEX_BY_DIM. Will use pfw.params.dims
            if not provided. Defaults to None.
        device (str | torch.device, optional): PyTorch device for inception inference.
            Will use pfw.params.device if not provided. Defaults to None.
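
    Example (illustrative; in practice the tensors would be real and generated
    images rather than random noise):
        >>> real_m, real_s = get_stats(torch.rand(100, 3, 64, 64))
        >>> score = fid(torch.rand(100, 3, 64, 64), real_m=real_m, real_s=real_s)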

    Returns:
        float: Frechet Inception Distance between `fake_images` and either
            `real_images` or `(real_m, real_s)`.
    """

    assert real_images is not None or (real_m is not None and real_s is not None)

    if batch_size is None:
        batch_size = pfw_params.batch_size
    if dims is None:
        dims = pfw_params.dims
    if device is None:
        device = pfw_params.device

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx]).to(device)

    if real_images is not None:
        real_m, real_s = get_stats(real_images, model, batch_size, dims, device)

    fake_m, fake_s = get_stats(fake_images, model, batch_size, dims, device)

    fid_value = calculate_frechet_distance(real_m, real_s, fake_m, fake_s)
    return fid_value

--------------------------------------------------------------------------------
/pytorch_fid_wrapper/inception.py:
--------------------------------------------------------------------------------
"""
# ------------------------------
# ----- InceptionV3 Code -----
# ------------------------------

Copied from:
https://github.com/mseitzer/pytorch-fid/blob/04740914ff6b9c0ba2e68b079e5477534db56c90/pytorch_fid/inception.py

Only modifications are black formatting and some typos.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = (
    "https://github.com/mseitzer/pytorch-fid/releases/download/"
    + "fid_weights/pt_inception-2015-12-05-6726825d.pth"
)


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,  # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3,  # Final average pooling features
    }

    def __init__(
        self,
        output_blocks=[DEFAULT_BLOCK_INDEX],
        resize_input=True,
        normalize_input=True,
        requires_grad=False,
        use_fid_inception=True,
    ):
        """Build pretrained InceptionV3
        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
            - 0: corresponds to output of first max pooling
            - 1: corresponds to output of second max pooling
            - 2: corresponds to output which is fed to aux classifier
            - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model.
            As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super().__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, "Last possible output block index is 3"

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps
        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)
        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp
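

# Example (illustrative only, not part of the upstream file): requesting the
# 2048-d pool3 block and running a forward pass. `model(x)` returns one tensor
# per requested block; for block index 3 each tensor is N x 2048 x 1 x 1.
#
#     block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
#     model = InceptionV3([block_idx]).eval()
#     with torch.no_grad():
#         features = model(torch.rand(8, 3, 64, 64))[0]  # 8 x 2048 x 1 x 1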


def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`
    Skips default weight initialization if supported by torchvision version.
    See https://github.com/mseitzer/pytorch-fid/issues/28.
    """
    try:
        version = tuple(map(int, torchvision.__version__.split(".")[:2]))
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    if version >= (0, 6):
        kwargs["init_weights"] = False

    return torchvision.models.inception_v3(*args, **kwargs)


def fid_inception_v3():
    """Build pretrained Inception model for FID computation
    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.
    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = _inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception


class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation"""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

--------------------------------------------------------------------------------
/pytorch_fid_wrapper/params.py:
--------------------------------------------------------------------------------
"""
# ---------------------------------
# ----- pfw's global params -----
# ---------------------------------

batch_size: batch_size for inception inference.
    Defaults to 50 as in pytorch-fid.
dims: which inception block to select from InceptionV3.BLOCK_INDEX_BY_DIM.
    Defaults to 2048 as in pytorch-fid.
device: PyTorch device for inception inference.
    Defaults to cpu as in pytorch-fid.
"""
batch_size = 50
dims = 2048
device = "cpu"

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools
import pytorch_fid_wrapper

with open("README.md", "r") as fh:
    long_description = fh.read()


setuptools.setup(
    name="pytorch-fid-wrapper",
    version=pytorch_fid_wrapper.__version__,
    author="Victor Schmidt",
    author_email="not.an.address@yes.com",
    description=(
        "Wrapper around the pytorch-fid package to compute Frechet Inception "
        + "Distance (FID) using PyTorch in-memory given tensors of images."
    ),
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/vict0rsch/pytorch-fid-wrapper",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
    ],
    python_requires=">=3.5",
    install_requires=["numpy", "pillow", "scipy", "torch", "torchvision"],
)

--------------------------------------------------------------------------------