├── .gitignore ├── LICENSE ├── README.md ├── pytorch_image_generation_metrics ├── __init__.py ├── calc.py ├── core.py ├── districuted.py ├── fid_ref.py ├── inception.py ├── utils.py └── version.py ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── conftest.py ├── test_all_metrics.py ├── test_fid.py ├── test_fid_ref.py └── test_inception_score.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | __pycache__ 4 | .python-version 5 | 6 | build 7 | dist 8 | .tox 9 | *.egg-info 10 | 11 | tests/* 12 | !tests/*.py 13 | 14 | fid_refs 15 | 16 | cheatsheet.md 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Implementation of Common Image Generation Metrics 2 | 3 | ![PyPI](https://img.shields.io/pypi/v/pytorch_image_generation_metrics) 4 | 5 | ## Installation 6 | ``` 7 | pip install pytorch-image-generation-metrics 8 | ``` 9 | 10 | ## Quick Start 11 | ```python 12 | from pytorch_image_generation_metrics import get_inception_score, get_fid 13 | 14 | images = ... 
# [N, 3, H, W] normalized to [0, 1] 15 | IS, IS_std = get_inception_score(images) # Inception Score 16 | FID = get_fid(images, 'path/to/fid_ref.npz') # Frechet Inception Distance 17 | ``` 18 | The file `path/to/fid_ref.npz` is compatible with the [official FID implementation](https://github.com/bioinf-jku/TTUR). 19 | 20 | ## Notes 21 | The FID implementation is inspired by [pytorch-fid](https://github.com/mseitzer/pytorch-fid). 22 | 23 | This repository is developed for personal research. If you find this package useful, please feel free to open issues. 24 | 25 | ## Features 26 | - Currently, this package supports the following metrics: 27 | - [Inception Score](https://github.com/openai/improved-gan) (IS) 28 | - [Fréchet Inception Distance](https://github.com/bioinf-jku/TTUR) (FID) 29 | - The computation procedures for IS and FID are integrated to avoid multiple forward passes. 30 | - Supports reading images on the fly to prevent out-of-memory issues, especially for large-scale images. 31 | - Supports computation on GPU to speed up some CPU operations, such as `np.cov` and `scipy.linalg.sqrtm`. 32 | 33 | ## Reproducing Results of Official Implementations on CIFAR-10 34 | 35 | | |Train IS |Test IS |Train(50k) vs Test(10k)
FID| 36 | |---------------------|:--------:|:--------:|:----------------------------:| 37 | |Official |11.24±0.20|10.98±0.22|3.1508 | 38 | |ours |11.26±0.13|10.97±0.19|3.1525 | 39 | |ours `use_torch=True`|11.26±0.15|10.97±0.20|3.1457 | 40 | 41 | The results differ slightly from the official implementations due to the framework differences between PyTorch and TensorFlow. 42 | 43 | ## Documentation 44 | 45 | ### Prepare Statistical Reference for FID 46 | - [Download](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC?usp=sharing) the pre-calculated reference, or 47 | - Calculate the statistical reference for your custom dataset using the command-line tool: 48 | ```bash 49 | python -m pytorch_image_generation_metrics.fid_ref \ 50 | --path path/to/images \ 51 | --output path/to/fid_ref.npz 52 | ``` 53 | See [fid_ref.py](./pytorch_image_generation_metrics/fid_ref.py) for details. 54 | 55 | ### Inception Features 56 | - When getting IS or FID, the `InceptionV3` model will be loaded into `torch.device('cuda:0')` by default. 57 | - Change the `device` argument in the `get_*` functions to set the torch device. 58 | 59 | ### Using `torch.Tensor` as images 60 | 61 | - Prepare images as `torch.float32` tensors with shape `[N, 3, H, W]`, normalized to `[0,1]`. 62 | ```python 63 | from pytorch_image_generation_metrics import ( 64 | get_inception_score, 65 | get_fid, 66 | get_inception_score_and_fid 67 | ) 68 | 69 | images = ... # [N, 3, H, W] 70 | assert 0 <= images.min() and images.max() <= 1 71 | 72 | # Inception Score 73 | IS, IS_std = get_inception_score( 74 | images) 75 | 76 | # Frechet Inception Distance 77 | FID = get_fid( 78 | images, 'path/to/fid_ref.npz') 79 | 80 | # Inception Score & Frechet Inception Distance 81 | (IS, IS_std), FID = get_inception_score_and_fid( 82 | images, 'path/to/fid_ref.npz') 83 | 84 | ``` 85 | 86 | ### Using PyTorch DataLoader to Provide Images 87 | 88 | 1. Use `pytorch_image_generation_metrics.ImageDataset` to collect images from your storage or use your custom `torch.utils.data.Dataset`. 89 | ```python 90 | from pytorch_image_generation_metrics import ImageDataset 91 | from torch.utils.data import DataLoader 92 | 93 | dataset = ImageDataset(path_to_dir, exts=['png', 'jpg']) 94 | loader = DataLoader(dataset, batch_size=50, num_workers=4) 95 | ``` 96 | 97 | You can wrap a generative model in a dataset to support generating images on the fly. 98 | ```python 99 | class GeneratorDataset(Dataset): 100 | def __init__(self, G, noise_dim): 101 | self.G = G 102 | self.noise_dim = noise_dim 103 | 104 | def __len__(self): 105 | return 50000 106 | 107 | def __getitem__(self, index): 108 | return self.G(torch.randn(1, self.noise_dim)) 109 | 110 | dataset = GeneratorDataset(G, noise_dim=128) 111 | loader = DataLoader(dataset, batch_size=50, num_workers=0) 112 | ``` 113 | 114 | 2. Calculate metrics 115 | ```python 116 | from pytorch_image_generation_metrics import ( 117 | get_inception_score, 118 | get_fid, 119 | get_inception_score_and_fid 120 | ) 121 | 122 | # Inception Score 123 | IS, IS_std = get_inception_score( 124 | loader) 125 | 126 | # Frechet Inception Distance 127 | FID = get_fid( 128 | loader, 'path/to/fid_ref.npz') 129 | 130 | # Inception Score & Frechet Inception Distance 131 | (IS, IS_std), FID = get_inception_score_and_fid( 132 | loader, 'path/to/fid_ref.npz') 133 | ``` 134 | 135 | ### Load Images from a Directory 136 | 137 | - Calculate metrics for images in a directory and its subfolders. 
138 | ```python 139 | from pytorch_image_generation_metrics import ( 140 | get_inception_score_from_directory, 141 | get_fid_from_directory, 142 | get_inception_score_and_fid_from_directory) 143 | 144 | IS, IS_std = get_inception_score_from_directory( 145 | 'path/to/images') 146 | FID = get_fid_from_directory( 147 | 'path/to/images', 'path/to/fid_ref.npz') 148 | (IS, IS_std), FID = get_inception_score_and_fid_from_directory( 149 | 'path/to/images', 'path/to/fid_ref.npz') 150 | ``` 151 | 152 | ### Accelerating Matrix Computation with PyTorch 153 | 154 | - Set `use_torch=True` when calling functions like `get_inception_score`, `get_fid`, etc. 155 | 156 | - **WARNING**: when `use_torch=True` is used, the FID might be `nan` due to the numerically unstable implementation of the matrix square root. 157 | 158 | ## Tested Versions 159 | - `python 3.9 + torch 1.13.1 + CUDA 11.7` 160 | - `python 3.9 + torch 2.3.0 + CUDA 12.1` 161 | 162 | ## License 163 | 164 | This implementation is licensed under the Apache License 2.0. 165 | 166 | This implementation is derived from [pytorch-fid](https://github.com/mseitzer/pytorch-fid), licensed under the Apache License 2.0. 167 | 168 | FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500) 169 | 170 | The original implementation of FID is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. 171 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR). 172 | -------------------------------------------------------------------------------- /pytorch_image_generation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """Export public API.""" 2 | 3 | from .version import __version__ 4 | 5 | from pytorch_image_generation_metrics.utils import ( 6 | ImageDataset, 7 | get_inception_score, 8 | get_inception_score_from_directory, 9 | get_fid, 10 | get_fid_from_directory, 11 | get_inception_score_and_fid, 12 | get_inception_score_and_fid_from_directory) 13 | 14 | # __all__ entries must be strings, not the objects themselves 15 | __all__ = [ 16 | 'ImageDataset', 17 | 'get_inception_score', 18 | 'get_inception_score_from_directory', 19 | 'get_fid', 20 | 'get_fid_from_directory', 21 | 'get_inception_score_and_fid', 22 | 'get_inception_score_and_fid_from_directory', 23 | '__version__', 24 | ]
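The `get_*` helpers exported above forward extra keyword arguments to `get_inception_feature`, so the torch device can be chosen explicitly instead of relying on the `cuda:0` default noted in the README. A minimal sketch; the random batch is a placeholder, not package code:

```python
# Minimal sketch: pass an explicit device through the public API.
import torch
from pytorch_image_generation_metrics import get_inception_score

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
images = torch.rand(64, 3, 32, 32)  # placeholder batch, values in [0, 1]
IS, IS_std = get_inception_score(images, device=device)
```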
-------------------------------------------------------------------------------- /pytorch_image_generation_metrics/calc.py: -------------------------------------------------------------------------------- 1 | """Calculate the FID and Inception Score of images in a directory.""" 2 | 3 | import argparse 4 | import os 5 | import tempfile 6 | 7 | import torch 8 | from torch.utils.data import DataLoader, SequentialSampler 9 | from torch.utils.data.distributed import DistributedSampler 10 | 11 | from .districuted import init, world_size, print0 12 | from .utils import ImageDataset, get_inception_score_and_fid 13 | 14 | 15 | def calc(args): 16 | """Calculate the FID and Inception Score of images in a directory.""" 17 | dataset = ImageDataset(root=args.path, num_images=args.num_images) 18 | if world_size() > 1: 19 | sampler = DistributedSampler(dataset, shuffle=False) 20 | else: 21 | sampler = SequentialSampler(dataset) 22 | loader = DataLoader( 23 | dataset, 24 | batch_size=args.batch_size, 25 | sampler=sampler, 26 | num_workers=args.num_workers) 27 | (IS, IS_std), FID = get_inception_score_and_fid( 28 | loader, 29 | args.fid_ref, 30 | use_torch=args.use_torch, 31 | verbose=True) 32 | print0(IS, IS_std, FID) 33 | 34 | 35 | def calc_init(init_method, world_size, rank, args): 36 | """Initialize the distributed environment and calculate the FID and Inception Score of images in a directory.""" 37 | init(init_method, world_size, rank) 38 | calc(args) 39 | 40 | 41 | def main(): 42 | """Parse command-line arguments and calculate the FID and Inception Score of images in a directory.""" 43 | parser = argparse.ArgumentParser( 44 | description="A command-line tool to calculate the Inception Score and Frechet Inception Distance (FID) of generated images against reference statistics.", 45 | epilog="Example: CUDA_VISIBLE_DEVICES=0,1 python -m pytorch_image_generation_metrics.calc --path cifar10/train --fid_ref cifar10.test.npz --batch_size 64", 46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 47 | parser.add_argument('--path', type=str, required=True, 48 | help='Path to the directory containing generated images.') 49 | parser.add_argument('--fid_ref', type=str, required=True, 50 | help='Path to precalculated reference statistics file.') 51 | parser.add_argument("--batch_size", type=int, default=50, 52 | help="Batch size for processing images.") 53 | parser.add_argument("--num_images", type=int, default=None, 54 | help="Number of images to use for calculating the metrics. If not specified, all images in the directory will be used.") 55 | parser.add_argument('--use_torch', action='store_true', 56 | help='Use PyTorch for matrix operations.') 57 | parser.add_argument("--num_workers", type=int, default=os.cpu_count(), 58 | help="Number of worker processes for data loading.") 59 | args = parser.parse_args() 60 | 61 | world_size = len(os.environ.get('CUDA_VISIBLE_DEVICES', "0").split(',')) 62 | if world_size == 1: 63 | calc(args) 64 | else: 65 | with tempfile.TemporaryDirectory() as temp: 66 | init_method = f'file://{os.path.abspath(os.path.join(temp, ".ddp"))}' 67 | processes = [] 68 | for rank in range(world_size): 69 | p = torch.multiprocessing.Process( 70 | target=calc_init, 71 | args=(init_method, world_size, rank, args)) 72 | p.start() 73 | processes.append(p) 74 | for p in processes: 75 | p.join() 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 |
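`calc()` above only reads attributes off its `args` object, so the same entry point can also be driven programmatically in a single process. A hedged sketch with placeholder paths; `SimpleNamespace` stands in for the parsed CLI arguments:

```python
# Minimal single-process sketch: call calc() without the CLI.
from types import SimpleNamespace

from pytorch_image_generation_metrics.calc import calc

args = SimpleNamespace(
    path='path/to/images',          # directory of generated images
    fid_ref='path/to/fid_ref.npz',  # pre-computed reference statistics
    batch_size=50,
    num_images=None,                # use all images found
    use_torch=False,
    num_workers=4,
)
calc(args)  # rank 0 prints IS, IS_std, FID
```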
-------------------------------------------------------------------------------- /pytorch_image_generation_metrics/core.py: -------------------------------------------------------------------------------- 1 | """The core implementation of Inception Score and FID.""" 2 | 3 | from typing import List, Union, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from scipy import linalg 8 | from tqdm.auto import tqdm 9 | from torch.utils.data import DataLoader, SequentialSampler, TensorDataset 10 | from torch.utils.data.distributed import DistributedSampler 11 | 12 | from . import districuted as dist 13 | from .inception import InceptionV3 14 | 15 | 16 | # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | def get_inception_feature( 20 | images: Union[torch.FloatTensor, DataLoader], 21 | dims: List[int], 22 | batch_size: int = 50, 23 | use_torch: bool = False, 24 | verbose: bool = False, 25 | device: torch.device = None, 26 | ) -> Union[torch.FloatTensor, np.ndarray]: 27 | """Extract InceptionV3 features for Inception Score and FID. 28 | 29 | For each image, only a single forward pass is needed to compute the 30 | features for both FID and Inception Score. 31 | 32 | Args: 33 | images: Tensor or torch.utils.data.DataLoader. The images 34 | must be float tensors in the range [0, 1]. 35 | dims: List of int, see InceptionV3.BLOCK_INDEX_BY_DIM for the 36 | available dimensions. 37 | batch_size: int, the batch size for calculating activations. If 38 | `images` is a torch.utils.data.DataLoader, this argument is 39 | ignored. 40 | use_torch: When True, use torch to calculate features. Otherwise, use numpy. 41 | verbose: When True, show a progress bar while the activations are 42 | calculated; otherwise the progress bar is disabled. 43 | device: The torch device used to compute the Inception features. 44 | Returns: 45 | inception_features: a list of extracted inception features 46 | corresponding to the given dims. 47 | """ 48 | assert all(dim in InceptionV3.BLOCK_INDEX_BY_DIM for dim in dims) 49 | if device is None: 50 | device = dist.device() 51 | 52 | if not isinstance(images, DataLoader): 53 | num_images = len(images) 54 | if dist.world_size() > 1: 55 | sampler = DistributedSampler( 56 | TensorDataset(images), shuffle=False) 57 | else: 58 | sampler = SequentialSampler(TensorDataset(images)) 59 | 60 | loader = DataLoader( 61 | images, 62 | batch_size=batch_size, 63 | sampler=sampler, 64 | drop_last=False, 65 | ) 66 | else: 67 | num_images = len(images.dataset) 68 | loader = images 69 | 70 | block_idxs = [InceptionV3.BLOCK_INDEX_BY_DIM[dim] for dim in dims] 71 | # Initialize InceptionV3 model 72 | if dist.rank() == 0: 73 | # Only rank 0 downloads the pretrained model 74 | model = InceptionV3(block_idxs).to(device) 75 | dist.barrier() 76 | else: 77 | # Other ranks wait until rank 0 finishes downloading the model 78 | dist.barrier() 79 | model = InceptionV3(block_idxs).to(device) 80 | model.eval() 81 | 82 | if dist.rank() == 0: 83 | if use_torch: 84 | features = [ 85 | torch.empty((num_images, dim)).to(device) 86 | for dim in dims] 87 | else: 88 | features = [ 89 | np.empty((num_images, dim)) 90 | for dim in dims] 91 | 92 | pbar = tqdm( 93 | total=num_images, dynamic_ncols=True, leave=False, 94 | disable=(dist.rank() != 0 or not verbose), 95 | desc="get_inception_feature") 96 | start = 0 97 | for batch_images in loader: 98 | batch_images = batch_images.to(device) 99 | # calculate inception feature 100 | end = min(start + len(batch_images) * dist.world_size(), num_images) 101 | with torch.no_grad(): 102 | outputs = model(batch_images) 103 | if end == num_images: 104 | # This is the last batch, so we need to remove the padding. 105 | if num_images % dist.world_size() != 0: 106 | is_padded = dist.rank() >= num_images % dist.world_size() 107 | else: 108 | is_padded = False 109 | if is_padded: 110 | # Remove the padding 111 | outputs = [output[: -1] for output in outputs] 112 | 113 | # Gather the outputs from all ranks. 114 | outputs = [dist.gather(output) for output in outputs] 115 | if dist.rank() == 0: 116 | for feature, output, dim in zip(features, outputs, dims): 117 | if use_torch: 118 | feature[start: end] = output.view(-1, dim) 119 | else: 120 | feature[start: end] = output.view(-1, dim).cpu().numpy() 121 | dist.barrier() 122 | else: 123 | dist.barrier() 124 | pbar.update(end - start) 125 | start = end 126 | pbar.close() 127 | assert start == num_images 128 | if dist.rank() == 0: 129 | return features 130 | else: 131 | return [None for _ in range(len(dims))] 132 | 133 |
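`torch_cov` below is meant to match `np.cov` (with its default `ddof=1`); a small sanity check, assuming only NumPy and PyTorch are available:

```python
# Minimal sketch: torch_cov should agree with np.cov on the same data.
import numpy as np
import torch

from pytorch_image_generation_metrics.core import torch_cov

x = torch.randn(100, 8).double()        # 100 observations of 8 variables
cov_torch = torch_cov(x, rowvar=False)  # 8 x 8 covariance matrix
cov_numpy = np.cov(x.numpy(), rowvar=False)
assert np.allclose(cov_torch.numpy(), cov_numpy)
```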
134 | def torch_cov(m, rowvar=False): 135 | """Estimate a covariance matrix given data. 136 | 137 | Covariance indicates the level to which two variables vary together. 138 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 139 | then the covariance matrix element `C_{ij}` is the covariance of 140 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 141 | 142 | Args: 143 | m: A 1-D or 2-D array containing multiple variables and observations. 144 | Each row of `m` represents a variable, and each column a single 145 | observation of all those variables. 146 | rowvar: If `rowvar` is True, then each row represents a 147 | variable, with observations in the columns. Otherwise, the 148 | relationship is transposed: each column represents a variable, 149 | while the rows contain observations. 150 | 151 | Returns: 152 | The covariance matrix of the variables. 153 | """ 154 | if m.dim() > 2: 155 | raise ValueError('m has more than 2 dimensions') 156 | if m.dim() < 2: 157 | m = m.view(1, -1) 158 | if not rowvar and m.size(0) != 1: 159 | m = m.t() 160 | # m = m.type(torch.double) # uncomment this line if desired 161 | fact = 1.0 / (m.size(1) - 1) 162 | m = m - torch.mean(m, dim=1, keepdim=True) # avoid mutating the caller's tensor in place 163 | mt = m.t() # if complex: mt = m.t().conj() 164 | return fact * m.matmul(mt).squeeze() 165 | 166 | 167 | # PyTorch implementation of matrix sqrt, from Tsung-Yu Lin and Subhransu Maji 168 | # https://github.com/msubhransu/matrix-sqrt 169 | def sqrt_newton_schulz(A, numIters, dtype=None): # noqa 170 | with torch.no_grad(): 171 | if dtype is None: 172 | dtype = A.type() 173 | batchSize = A.shape[0] 174 | dim = A.shape[1] 175 | normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt() 176 | Y = A.div(normA.view(batchSize, 1, 1).expand_as(A)) 177 | K = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1) 178 | Z = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1) 179 | K = K.type(dtype) 180 | Z = Z.type(dtype) 181 | for i in range(numIters): 182 | T = 0.5 * (3.0 * K - Z.bmm(Y)) 183 | Y = Y.bmm(T) 184 | Z = T.bmm(Z) 185 | sA = Y * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A) 186 | return sA 187 | 188 | 189 | def calculate_frechet_inception_distance( 190 | acts: Union[torch.FloatTensor, np.ndarray], 191 | mu: np.ndarray, 192 | sigma: np.ndarray, 193 | use_torch: bool = False, 194 | eps: float = 1e-6, 195 | device: torch.device = torch.device('cuda:0'), 196 | ) -> float: # noqa 197 | if use_torch: 198 | m1 = torch.mean(acts, axis=0) 199 | s1 = torch_cov(acts, rowvar=False) 200 | mu = torch.tensor(mu).to(m1.dtype).to(device) 201 | sigma = torch.tensor(sigma).to(s1.dtype).to(device) 202 | else: 203 | m1 = np.mean(acts, axis=0) 204 | s1 = np.cov(acts, rowvar=False) 205 | return calculate_frechet_distance(m1, s1, mu, sigma, use_torch, eps) 206 | 207 |
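For reference, both branches of `calculate_frechet_distance` below evaluate the closed-form Fréchet distance between two Gaussians, d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2}). A minimal NumPy sketch on toy statistics (not package code):

```python
# Minimal sketch: closed-form Frechet distance between two Gaussians.
import numpy as np
from scipy import linalg

mu1, sigma1 = np.zeros(4), np.eye(4)
mu2, sigma2 = np.ones(4), 2 * np.eye(4)
diff = mu1 - mu2
covmean = linalg.sqrtm(sigma1.dot(sigma2))  # here sqrt(2) * I, exactly real
d2 = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
print(d2)  # 4 + 4 + 8 - 8 * sqrt(2) ~= 4.686
```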
208 | def calculate_frechet_distance( 209 | mu1: Union[torch.FloatTensor, np.ndarray], 210 | sigma1: Union[torch.FloatTensor, np.ndarray], 211 | mu2: Union[torch.FloatTensor, np.ndarray], 212 | sigma2: Union[torch.FloatTensor, np.ndarray], 213 | use_torch: bool = False, 214 | eps: float = 1e-6, 215 | ) -> float: 216 | """Calculate Frechet Distance. 217 | 218 | Args: 219 | mu1: The sample mean over activations for a set of samples. 220 | sigma1: The covariance matrix over activations for a set of samples. 221 | mu2: The sample mean over activations for another set of samples. 222 | sigma2: The covariance matrix over activations for another set of 223 | samples. 224 | use_torch: When True, use torch to calculate FID. Otherwise, use numpy. 225 | eps: Small value added to the covariance diagonals to keep covmean from being singular. 226 | 227 | Returns: 228 | The Frechet Distance. 229 | """ 230 | if use_torch: 231 | assert mu1.shape == mu2.shape, \ 232 | 'Training and test mean vectors have different lengths' 233 | assert sigma1.shape == sigma2.shape, \ 234 | 'Training and test covariances have different dimensions' 235 | 236 | diff = mu1 - mu2 237 | # Run 50 iterations of Newton-Schulz to get the matrix square root of 238 | # sigma1 dot sigma2 239 | covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50) 240 | if torch.any(torch.isnan(covmean)): 241 | return float('nan') 242 | covmean = covmean.squeeze() 243 | out = (diff.dot(diff) + # noqa: W504 244 | torch.trace(sigma1) + # noqa: W504 245 | torch.trace(sigma2) - # noqa: W504 246 | 2 * torch.trace(covmean)).cpu().item() 247 | else: 248 | mu1 = np.atleast_1d(mu1) 249 | mu2 = np.atleast_1d(mu2) 250 | 251 | sigma1 = np.atleast_2d(sigma1) 252 | sigma2 = np.atleast_2d(sigma2) 253 | 254 | assert mu1.shape == mu2.shape, \ 255 | 'Training and test mean vectors have different lengths' 256 | assert sigma1.shape == sigma2.shape, \ 257 | 'Training and test covariances have different dimensions' 258 | 259 | diff = mu1 - mu2 260 | 261 | # Product might be almost singular 262 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 263 | if not np.isfinite(covmean).all(): 264 | msg = ('fid calculation produces singular product; ' 265 | 'adding %s to diagonal of cov estimates') % eps 266 | print(msg) 267 | offset = np.eye(sigma1.shape[0]) * eps 268 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 269 | 270 | # Numerical error might give slight imaginary component 271 | if np.iscomplexobj(covmean): 272 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 273 | m = np.max(np.abs(covmean.imag)) 274 | raise ValueError('Imaginary component {}'.format(m)) 275 | covmean = covmean.real 276 | 277 | tr_covmean = np.trace(covmean) 278 | 279 | out = (diff.dot(diff) + # noqa: W504 280 | np.trace(sigma1) + # noqa: W504 281 | np.trace(sigma2) - # noqa: W504 282 | 2 * tr_covmean).item() 283 | return out 284 | 285 | 286 | def calculate_inception_score( 287 | probs: Union[torch.FloatTensor, np.ndarray], 288 | splits: int = 10, 289 | use_torch: bool = False, 290 | ) -> Tuple[float, float]: # noqa 291 | # Inception Score 292 | scores = [] 293 | for i in range(splits): 294 | part = probs[ 295 | (i * probs.shape[0] // splits): 296 | ((i + 1) * probs.shape[0] // splits), :] 297 | if use_torch: 298 | kl = part * ( 299 | torch.log(part) - # noqa: W504 300 | torch.log(torch.unsqueeze(torch.mean(part, 0), 0))) 301 | kl = torch.mean(torch.sum(kl, 1)) 302 | scores.append(torch.exp(kl)) 303 | else: 304 | kl = part * ( 305 | np.log(part) - # noqa: W504 306 | np.log(np.expand_dims(np.mean(part, 0), 0))) 307 | kl = np.mean(np.sum(kl, 1)) 308 | scores.append(np.exp(kl)) 309 | if use_torch: 310 | scores = torch.stack(scores) 311 | inception_score = torch.mean(scores).cpu().item() 312 | std = torch.std(scores).cpu().item() 313 | else: 314 | inception_score, std = (np.mean(scores).item(), np.std(scores).item()) 315 | del probs, scores 316 | return inception_score, std 317 |
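`calculate_inception_score` above implements IS = exp(E_x[KL(p(y|x) || p(y))]), averaged over `splits` folds of the probability matrix. A minimal NumPy sketch of a single fold on toy probabilities (not package code):

```python
# Minimal sketch: Inception Score of one split of class probabilities.
import numpy as np

probs = np.random.dirichlet(np.ones(10), size=500)  # each row sums to 1
marginal = probs.mean(axis=0, keepdims=True)        # estimate of p(y)
kl = (probs * (np.log(probs) - np.log(marginal))).sum(axis=1)
print(np.exp(kl.mean()))                            # IS of this split
```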
-------------------------------------------------------------------------------- /pytorch_image_generation_metrics/districuted.py: -------------------------------------------------------------------------------- 1 | """Utilities for distributed processing in PyTorch.""" 2 | 3 | import torch 4 | import torch.distributed 5 | 6 | 7 | def init(init_method, world_size, rank): 8 | """Initialize the distributed process group for multi-GPU processing. 9 | 10 | Args: 11 | init_method (str): URL specifying how to initialize the process group. 12 | world_size (int): Number of processes participating in the job. 13 | rank (int): Rank of the current process. 14 | 15 | Initializes the NCCL backend for distributed GPU communication and sets the current CUDA device to the given rank. 16 | """ 17 | torch.distributed.init_process_group('nccl', init_method, world_size=world_size, rank=rank) 18 | torch.cuda.set_device(rank) 19 | torch.cuda.empty_cache() 20 | 21 | 22 | def rank(): 23 | """Return the rank of the current process in the distributed process group. 24 | 25 | Returns: 26 | int: Rank of the current process. Returns 0 if the process group is not initialized. 27 | """ 28 | if torch.distributed.is_initialized(): 29 | return torch.distributed.get_rank() 30 | else: 31 | return 0 32 | 33 | 34 | def world_size(): 35 | """Return the number of processes in the distributed process group. 36 | 37 | Returns: 38 | int: Number of processes. Returns 1 if the process group is not initialized. 39 | """ 40 | if torch.distributed.is_initialized(): 41 | return torch.distributed.get_world_size() 42 | else: 43 | return 1 44 | 45 | 46 | def barrier(): 47 | """Synchronize all processes in the distributed process group. 48 | 49 | Blocks until all processes have reached this function call. 50 | """ 51 | if torch.distributed.is_initialized(): 52 | torch.distributed.barrier() 53 | 54 | 55 | def device(): 56 | """Return the current CUDA device for the process based on its rank. 57 | 58 | Returns: 59 | torch.device: CUDA device object for the current process. 60 | """ 61 | return torch.device(f'cuda:{rank()}') 62 | 63 | 64 | def gather_shape(x: torch.Tensor, dim: int = 0): 65 | """Gather the shapes of tensors along a specific dimension from all processes in the distributed process group. 66 | 67 | Args: 68 | x (torch.Tensor): The tensor whose shape is gathered. 69 | dim (int): The dimension along which to gather the shapes. Default is 0. 70 | 71 | Returns: 72 | list of torch.Size: A list of shapes from all processes. 73 | """ 74 | if world_size() > 1: 75 | sizes_at_dim = [torch.tensor(0).to(x.device) for _ in range(world_size())] 76 | torch.distributed.all_gather(sizes_at_dim, torch.tensor(x.shape[dim], device=x.device)) 77 | shapes = [] 78 | for size in sizes_at_dim: 79 | shape = list(x.shape) 80 | shape[dim] = size.item() 81 | shapes.append(torch.Size(shape)) 82 | return shapes 83 | else: 84 | return [x.shape] 85 | 86 | 87 | def gather(x: torch.Tensor, cat_dim: int = 0): 88 | """Gather tensors from all processes and concatenate them along a specified dimension. 89 | 90 | Args: 91 | x (torch.Tensor): The tensor to gather. 92 | cat_dim (int): The dimension along which to concatenate the tensors. Default is 0. 93 | 94 | Returns: 95 | torch.Tensor: The concatenated tensor from all processes.
96 | """ 97 | if world_size() > 1: 98 | shapes = gather_shape(x, cat_dim) 99 | xs = [torch.zeros(shape, device=x.device) for shape in shapes] 100 | torch.distributed.all_gather(xs, x) 101 | return torch.cat(xs, dim=cat_dim) 102 | else: 103 | return x 104 | 105 | 106 | def print0(*args, **kwargs): 107 | """Print messages only from the process with rank 0.""" 108 | if rank() == 0: 109 | print(*args, **kwargs) 110 | -------------------------------------------------------------------------------- /pytorch_image_generation_metrics/fid_ref.py: -------------------------------------------------------------------------------- 1 | """Calculate statistics for FID and save them to a file.""" 2 | 3 | import argparse 4 | import os 5 | import tempfile 6 | 7 | import torch 8 | 9 | from .districuted import init 10 | from .utils import calc_fid_ref 11 | 12 | 13 | def calc(args): 14 | """Calculate statistics for FID and save them to a file.""" 15 | calc_fid_ref( 16 | args.path, 17 | args.output, 18 | args.batch_size, 19 | args.img_size, 20 | args.use_torch, 21 | args.num_workers) 22 | 23 | 24 | def calc_init(init_method, world_size, rank, args): 25 | """Initialize the distributed environment and calculate statistics for FID and save them to a file.""" 26 | init(init_method, world_size, rank) 27 | calc(args) 28 | 29 | 30 | def main(): 31 | """Parse command-line arguments and calculate statistics for FID and save them to a file.""" 32 | parser = argparse.ArgumentParser( 33 | description="A command-line tool to compute Frechet Inception Distance (FID) statistics.", 34 | epilog="Example: CUDA_VISIBLE_DEVICES=0,1 python -m pytorch_image_generation_metrics.fid_ref --path cifar10/train --output cifar10.test.npz --batch_size 64", 35 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 36 | parser.add_argument("--path", type=str, required=True, 37 | help='Path to the directory containing images (including subfolders).') 38 | parser.add_argument("--output", type=str, required=True, 39 | help="Output file path for saving the computed statistics.") 40 | parser.add_argument("--batch_size", type=int, default=50, 41 | help="Batch size for processing images.") 42 | parser.add_argument("--img_size", type=int, default=None, 43 | help="Resize images to this specified size (if provided).") 44 | parser.add_argument('--use_torch', action='store_true', 45 | help='Use PyTorch for matrix operations.') 46 | parser.add_argument("--num_workers", type=int, default=os.cpu_count(), 47 | help="Number of worker processes for data loading.") 48 | args = parser.parse_args() 49 | 50 | world_size = len(os.environ.get('CUDA_VISIBLE_DEVICES', "0").split(',')) 51 | if world_size == 1: 52 | calc(args) 53 | else: 54 | with tempfile.TemporaryDirectory() as temp: 55 | init_method = f'file://{os.path.abspath(os.path.join(temp, ".ddp"))}' 56 | processes = [] 57 | for rank in range(world_size): 58 | p = torch.multiprocessing.Process( 59 | target=calc_init, 60 | args=(init_method, world_size, rank, args)) 61 | p.start() 62 | processes.append(p) 63 | for p in processes: 64 | p.join() 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /pytorch_image_generation_metrics/inception.py: -------------------------------------------------------------------------------- 1 | """Inception Model v3 for FID computation.""" 2 | 3 | from packaging import version 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | from 
torchvision import models 10 | from torch.hub import load_state_dict_from_url 11 | 12 | # Inception weights ported to PyTorch from 13 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 14 | FID_WEIGHTS_URL = ('https://github.com/w86763777/pytorch-image-generation-metrics/releases/' 15 | 'download/v0.1.0/pt_inception-2015-12-05-6726825d.pth') 16 | TORCHVISION_VERSION = version.parse(torchvision.__version__) 17 | 18 | 19 | class InceptionV3(nn.Module): 20 | """Pretrained InceptionV3 network returning feature maps.""" 21 | 22 | # Index of default block of inception to return, 23 | # corresponds to output of final average pooling 24 | DEFAULT_BLOCK_INDEX = 3 25 | 26 | # Maps feature dimensionality to their output blocks indices 27 | BLOCK_INDEX_BY_DIM = { 28 | 64: 0, # First max pooling features 29 | 192: 1, # Second max pooling features 30 | 768: 2, # Pre-aux classifier features 31 | 2048: 3, # Final average pooling features 32 | 1008: 4, # softmax layer 33 | } 34 | 35 | def __init__(self, 36 | output_blocks=[DEFAULT_BLOCK_INDEX], 37 | resize_input=True, 38 | normalize_input=True, 39 | requires_grad=False, 40 | use_fid_inception=True): 41 | """Build pretrained InceptionV3. 42 | 43 | Args: 44 | output_blocks : List of int 45 | Indices of blocks to return features of. Possible values are: 46 | - 0: corresponds to output of first max pooling 47 | - 1: corresponds to output of second max pooling 48 | - 2: corresponds to output which is fed to aux classifier 49 | - 3: corresponds to output of final average pooling 50 | - 4: corresponds to output of softmax 51 | resize_input : bool 52 | If true, bilinearly resizes input to width and height 299 53 | before feeding input to model. As the network without fully 54 | connected layers is fully convolutional, it should be able to 55 | handle inputs of arbitrary size, so resizing might not be 56 | strictly needed 57 | normalize_input : bool 58 | If true, scales the input from range (0, 1) to the range the 59 | pretrained Inception network expects, namely (-1, 1) 60 | requires_grad : bool 61 | If true, parameters of the model require gradients. Possibly 62 | useful for finetuning the network 63 | use_fid_inception : bool 64 | If true, uses the pretrained Inception model used in 65 | Tensorflow's FID implementation. If false, uses the pretrained 66 | Inception model available in torchvision. The FID Inception 67 | model has different weights and a slightly different structure 68 | from torchvision's Inception model. If you want to compute FID 69 | scores, you are strongly advised to set this parameter to true 70 | to get comparable results. 
71 | """ 72 | super(InceptionV3, self).__init__() 73 | 74 | self.resize_input = resize_input 75 | self.normalize_input = normalize_input 76 | self.output_blocks = output_blocks 77 | self.last_needed_block = max(output_blocks) 78 | 79 | # assert self.last_needed_block <= 3, \ 80 | # 'Last possible output block index is 3' 81 | 82 | self.blocks = nn.ModuleList() 83 | 84 | if use_fid_inception: 85 | inception = fid_inception_v3() 86 | else: 87 | if TORCHVISION_VERSION < version.parse("0.13.0"): 88 | inception = models.inception_v3( 89 | pretrained=True, 90 | init_weights=False) 91 | else: 92 | inception = models.inception_v3( 93 | weights=models.Inception_V3_Weights.IMAGENET1K_V1, 94 | ) 95 | 96 | # Block 0: input to maxpool1 97 | block0 = [ 98 | inception.Conv2d_1a_3x3, 99 | inception.Conv2d_2a_3x3, 100 | inception.Conv2d_2b_3x3, 101 | nn.MaxPool2d(kernel_size=3, stride=2) 102 | ] 103 | self.blocks.append(nn.Sequential(*block0)) 104 | 105 | # Block 1: maxpool1 to maxpool2 106 | if self.last_needed_block >= 1: 107 | block1 = [ 108 | inception.Conv2d_3b_1x1, 109 | inception.Conv2d_4a_3x3, 110 | nn.MaxPool2d(kernel_size=3, stride=2) 111 | ] 112 | self.blocks.append(nn.Sequential(*block1)) 113 | 114 | # Block 2: maxpool2 to aux classifier 115 | if self.last_needed_block >= 2: 116 | block2 = [ 117 | inception.Mixed_5b, 118 | inception.Mixed_5c, 119 | inception.Mixed_5d, 120 | inception.Mixed_6a, 121 | inception.Mixed_6b, 122 | inception.Mixed_6c, 123 | inception.Mixed_6d, 124 | inception.Mixed_6e, 125 | ] 126 | self.blocks.append(nn.Sequential(*block2)) 127 | 128 | # Block 3: aux classifier to final avgpool 129 | if self.last_needed_block >= 3: 130 | block3 = [ 131 | inception.Mixed_7a, 132 | inception.Mixed_7b, 133 | inception.Mixed_7c, 134 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 135 | ] 136 | self.blocks.append(nn.Sequential(*block3)) 137 | 138 | if self.last_needed_block >= 4: 139 | inception.fc.bias = None 140 | self.blocks.append(inception.fc) 141 | 142 | for param in self.parameters(): 143 | param.requires_grad = requires_grad 144 | 145 | def forward(self, x): 146 | """Get Inception feature maps. 147 | 148 | Args: 149 | x : torch.FloatTensor,Input tensor of shape [B x 3 x H x W]. If 150 | `normalize_input` is True, values are expected to be in range 151 | [0, 1]; Otherwise, values are expected to be in range [-1, 1]. 152 | 153 | Returns: 154 | List of torch.FloatTensor, corresponding to the selected output 155 | block, sorted ascending by index 156 | """ 157 | outputs = [None for _ in range(len(self.output_blocks))] 158 | 159 | if self.resize_input: 160 | x = F.interpolate(x, 161 | size=(299, 299), 162 | mode='bilinear', 163 | align_corners=False) 164 | 165 | if self.normalize_input: 166 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 167 | 168 | for idx, block in enumerate(self.blocks): 169 | if idx < 4: 170 | x = block(x) 171 | else: 172 | x = F.dropout(x, training=self.training) # N x 2048 x 1 x 1 173 | x = torch.flatten(x, start_dim=1) # N x 2048 174 | x = block(x) # N x 1000 175 | x = F.softmax(x, dim=1) 176 | 177 | if idx in self.output_blocks: 178 | order = self.output_blocks.index(idx) 179 | outputs[order] = x 180 | 181 | if idx == self.last_needed_block: 182 | break 183 | 184 | return outputs 185 | 186 | 187 | def fid_inception_v3(): 188 | """Build pretrained Inception model for FID computation. 189 | 190 | The Inception model for FID computation uses a different set of weights 191 | and has a slightly different structure than torchvision's Inception. 
192 | 193 | This method first constructs torchvision's Inception and then patches the 194 | necessary parts that are different in the FID Inception model. 195 | """ 196 | if TORCHVISION_VERSION < version.parse("0.13.0"): 197 | inception = models.inception_v3( 198 | pretrained=False, 199 | aux_logits=False, 200 | num_classes=1008, 201 | init_weights=False, 202 | ) 203 | else: 204 | inception = models.inception_v3( 205 | weights=None, 206 | aux_logits=False, 207 | num_classes=1008, 208 | init_weights=False, 209 | ) 210 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 211 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 212 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 213 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 214 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 215 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 216 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 217 | inception.Mixed_7b = FIDInceptionE_1(1280) 218 | inception.Mixed_7c = FIDInceptionE_2(2048) 219 | 220 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 221 | inception.load_state_dict(state_dict) 222 | return inception 223 | 224 | 225 | class FIDInceptionA(models.inception.InceptionA): 226 | """InceptionA block patched for FID computation.""" 227 | 228 | def __init__(self, in_channels, pool_features): # noqa 229 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 230 | 231 | def forward(self, x): # noqa 232 | branch1x1 = self.branch1x1(x) 233 | 234 | branch5x5 = self.branch5x5_1(x) 235 | branch5x5 = self.branch5x5_2(branch5x5) 236 | 237 | branch3x3dbl = self.branch3x3dbl_1(x) 238 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 239 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 240 | 241 | # Patch: Tensorflow's average pool does not use the padded zeros in 242 | # its average calculation 243 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 244 | count_include_pad=False) 245 | branch_pool = self.branch_pool(branch_pool) 246 | 247 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 248 | return torch.cat(outputs, 1) 249 | 250 | 251 | class FIDInceptionC(models.inception.InceptionC): 252 | """InceptionC block patched for FID computation.""" 253 | 254 | def __init__(self, in_channels, channels_7x7): # noqa 255 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 256 | 257 | def forward(self, x): # noqa 258 | branch1x1 = self.branch1x1(x) 259 | 260 | branch7x7 = self.branch7x7_1(x) 261 | branch7x7 = self.branch7x7_2(branch7x7) 262 | branch7x7 = self.branch7x7_3(branch7x7) 263 | 264 | branch7x7dbl = self.branch7x7dbl_1(x) 265 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 266 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 267 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 268 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 269 | 270 | # Patch: Tensorflow's average pool does not use the padded zeros in 271 | # its average calculation 272 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 273 | count_include_pad=False) 274 | branch_pool = self.branch_pool(branch_pool) 275 | 276 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 277 | return torch.cat(outputs, 1) 278 | 279 | 280 | class FIDInceptionE_1(models.inception.InceptionE): 281 | """First InceptionE block patched for FID computation.""" 282 | 283 | def __init__(self, in_channels): # noqa 284 | super(FIDInceptionE_1, self).__init__(in_channels) 285 |
286 | def forward(self, x): # noqa 287 | branch1x1 = self.branch1x1(x) 288 | 289 | branch3x3 = self.branch3x3_1(x) 290 | branch3x3 = [ 291 | self.branch3x3_2a(branch3x3), 292 | self.branch3x3_2b(branch3x3), 293 | ] 294 | branch3x3 = torch.cat(branch3x3, 1) 295 | 296 | branch3x3dbl = self.branch3x3dbl_1(x) 297 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 298 | branch3x3dbl = [ 299 | self.branch3x3dbl_3a(branch3x3dbl), 300 | self.branch3x3dbl_3b(branch3x3dbl), 301 | ] 302 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 303 | 304 | # Patch: Tensorflow's average pool does not use the padded zeros in 305 | # its average calculation 306 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 307 | count_include_pad=False) 308 | branch_pool = self.branch_pool(branch_pool) 309 | 310 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 311 | return torch.cat(outputs, 1) 312 | 313 | 314 | class FIDInceptionE_2(models.inception.InceptionE): 315 | """Second InceptionE block patched for FID computation.""" 316 | 317 | def __init__(self, in_channels): # noqa 318 | super(FIDInceptionE_2, self).__init__(in_channels) 319 | 320 | def forward(self, x): # noqa 321 | branch1x1 = self.branch1x1(x) 322 | 323 | branch3x3 = self.branch3x3_1(x) 324 | branch3x3 = [ 325 | self.branch3x3_2a(branch3x3), 326 | self.branch3x3_2b(branch3x3), 327 | ] 328 | branch3x3 = torch.cat(branch3x3, 1) 329 | 330 | branch3x3dbl = self.branch3x3dbl_1(x) 331 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 332 | branch3x3dbl = [ 333 | self.branch3x3dbl_3a(branch3x3dbl), 334 | self.branch3x3dbl_3b(branch3x3dbl), 335 | ] 336 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 337 | 338 | # Patch: The FID Inception model uses max pooling instead of average 339 | # pooling. This is likely an error in this specific Inception 340 | # implementation, as other Inception models use average pooling here 341 | # (which matches the description in the paper). 342 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 343 | branch_pool = self.branch_pool(branch_pool) 344 | 345 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 346 | return torch.cat(outputs, 1) 347 |
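The wrapper above can also be used directly to pull out the 2048-dimensional pool features that FID is computed from; a minimal sketch with a random placeholder batch:

```python
# Minimal sketch: extract final-avgpool features with InceptionV3.
import torch

from pytorch_image_generation_metrics.inception import InceptionV3

model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).eval()
x = torch.rand(2, 3, 299, 299)  # values in [0, 1]
with torch.no_grad():
    feats, = model(x)           # shape [2, 2048, 1, 1] after final avgpool
print(feats.flatten(start_dim=1).shape)  # torch.Size([2, 2048])
```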
-------------------------------------------------------------------------------- /pytorch_image_generation_metrics/utils.py: -------------------------------------------------------------------------------- 1 | """The public API of pytorch_image_generation_metrics.""" 2 | 3 | import os 4 | from typing import List, Union, Tuple, Optional 5 | from glob import glob 6 | 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from torch.utils.data import Dataset, DataLoader, SequentialSampler 11 | from torch.utils.data.distributed import DistributedSampler 12 | from torchvision.transforms import Compose, Resize, ToTensor 13 | from torchvision.transforms.functional import to_tensor 14 | 15 | from .districuted import rank, world_size 16 | from .core import ( 17 | get_inception_feature, 18 | calculate_inception_score, 19 | calculate_frechet_inception_distance, 20 | torch_cov) 21 | 22 | 23 | class ImageDataset(Dataset): 24 | """A simple image dataset for calculating inception score and FID.""" 25 | 26 | def __init__(self, root, exts=['png', 'jpg', 'JPEG'], transform=None, 27 | num_images=None): 28 | """Construct an image dataset. 29 | 30 | Args: 31 | root: Path to the image directory. This directory will be 32 | recursively searched. 33 | exts: List of extensions to search for. 34 | transform: A torchvision transform to apply to the images. If 35 | None, the images will be converted to tensors. 36 | num_images: The number of images to load. If None, all images 37 | will be loaded. 38 | """ 39 | self.paths = [] 40 | self.transform = transform 41 | for ext in exts: 42 | self.paths.extend( 43 | list(glob( 44 | os.path.join(root, '**/*.%s' % ext), recursive=True))) 45 | self.paths = self.paths[:num_images] 46 | 47 | def __len__(self): # noqa 48 | return len(self.paths) 49 | 50 | def __getitem__(self, idx): # noqa 51 | image = Image.open(self.paths[idx]) 52 | image = image.convert('RGB') # fix ImageNet grayscale images 53 | if self.transform is not None: 54 | image = self.transform(image) 55 | else: 56 | image = to_tensor(image) 57 | return image 58 | 59 | 60 | def get_inception_score_and_fid( 61 | images: Union[torch.FloatTensor, DataLoader], 62 | fid_ref: str, 63 | splits: int = 10, 64 | use_torch: bool = False, 65 | **kwargs, 66 | ) -> Tuple[Tuple[float, float], float]: 67 | """Calculate Inception Score and FID. 68 | 69 | For each image, only a single forward pass is needed to 70 | compute the features for FID and Inception Score. 71 | 72 | Args: 73 | images: Tensor or torch.utils.data.DataLoader. The returned images 74 | must be float tensors in the range [0, 1]. 75 | fid_ref: Path to pre-calculated statistic. 76 | splits: The number of bins of Inception Score. 77 | use_torch: When True, use torch to calculate FID. Otherwise, use numpy. 78 | **kwargs: The arguments passed to 79 | `pytorch_image_generation_metrics.core.get_inception_feature`. 80 | Returns: 81 | inception_score: float tuple, (mean, std) 82 | fid: float 83 | """ 84 | acts, probs = get_inception_feature( 85 | images, dims=[2048, 1008], use_torch=use_torch, **kwargs) 86 | 87 | if rank() != 0: 88 | return (None, None), None 89 | 90 | # Inception Score 91 | inception_score, std = calculate_inception_score(probs, splits, use_torch) 92 | 93 | # Frechet Inception Distance 94 | f = np.load(fid_ref, allow_pickle=True) 95 | if isinstance(f, np.ndarray): 96 | mu, sigma = f.item()['mu'][:], f.item()['sigma'][:] 97 | else: 98 | mu, sigma = f['mu'][:], f['sigma'][:] 99 | f.close() # only the NpzFile branch is closed; a plain ndarray has no close() 100 | fid = calculate_frechet_inception_distance(acts, mu, sigma, use_torch) 101 | 102 | return (inception_score, std), fid 103 |
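As the loader in `get_inception_score_and_fid` shows, a `fid_ref` file is simply an `.npz` archive with `mu` and `sigma` arrays. A minimal sketch for inspecting one (the path is a placeholder):

```python
# Minimal sketch: inspect the statistics stored in a fid_ref file.
import numpy as np

f = np.load('path/to/fid_ref.npz')
mu, sigma = f['mu'], f['sigma']
print(mu.shape, sigma.shape)  # typically (2048,) and (2048, 2048)
f.close()
```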
130 |     """
131 |     return get_inception_score_and_fid(
132 |         images=DataLoader(ImageDataset(path, exts), batch_size=batch_size),
133 |         fid_ref=fid_ref,
134 |         splits=splits,
135 |         use_torch=use_torch, **kwargs)
136 |
137 |
138 | def get_fid(
139 |     images: Union[torch.FloatTensor, DataLoader],
140 |     fid_ref: str,
141 |     use_torch: bool = False,
142 |     **kwargs,
143 | ) -> float:
144 |     """Calculate Frechet Inception Distance.
145 |
146 |     Args:
147 |         images: A tensor of images or a torch.utils.data.DataLoader. The
148 |             images must be float tensors in the range [0, 1].
149 |         fid_ref: Path to pre-calculated statistics.
150 |         use_torch: When True, use torch to calculate FID. Otherwise, use numpy.
151 |         **kwargs: The arguments passed to
152 |             `pytorch_image_generation_metrics.core.get_inception_feature`.
153 |
154 |     Returns:
155 |         FID
156 |     """
157 |     acts, = get_inception_feature(
158 |         images, dims=[2048], use_torch=use_torch, **kwargs)
159 |
160 |     if rank() != 0:
161 |         return None
162 |
163 |     # Frechet Inception Distance
164 |     f = np.load(fid_ref, allow_pickle=True)
165 |     if isinstance(f, np.ndarray):
166 |         mu, sigma = f.item()['mu'][:], f.item()['sigma'][:]
167 |     else:
168 |         mu, sigma = f['mu'][:], f['sigma'][:]
169 |         f.close() # only NpzFile objects need closing
170 |     fid = calculate_frechet_inception_distance(acts, mu, sigma, use_torch)
171 |
172 |     return fid
173 |
174 |
175 | def get_fid_from_directory(
176 |     path: str,
177 |     fid_ref: str,
178 |     exts: List[str] = ['png', 'jpg'],
179 |     batch_size: int = 50,
180 |     use_torch: bool = False,
181 |     **kwargs
182 | ) -> float:
183 |     """Calculate Frechet Inception Distance of images in a directory.
184 |
185 |     Args:
186 |         path: Path to the image directory. This function will recursively find
187 |             images in all subfolders.
188 |         fid_ref: Path to pre-calculated statistics.
189 |         exts: List of extensions to search for.
190 |         use_torch: When True, use torch to calculate FID. Otherwise, use numpy.
191 |         **kwargs: The arguments passed to
192 |             `pytorch_image_generation_metrics.core.get_inception_feature`.
193 |
194 |     Returns:
195 |         FID
196 |     """
197 |     return get_fid(
198 |         images=DataLoader(ImageDataset(path, exts), batch_size=batch_size),
199 |         fid_ref=fid_ref,
200 |         use_torch=use_torch,
201 |         **kwargs)
202 |
203 |
204 | def get_inception_score(
205 |     images: Union[torch.FloatTensor, DataLoader],
206 |     splits: int = 10,
207 |     use_torch: bool = False,
208 |     **kwargs,
209 | ) -> Tuple[float, float]:
210 |     """Calculate Inception Score.
211 |
212 |     Args:
213 |         images: A tensor of images or a torch.utils.data.DataLoader. The
214 |             images must be float tensors in the range [0, 1].
215 |         splits: The number of splits used to compute the Inception Score.
216 |         use_torch: When True, use torch to calculate the Inception Score. Otherwise, use numpy.
217 |         **kwargs: The arguments passed to
218 |             `pytorch_image_generation_metrics.core.get_inception_feature`.
219 |
220 |     Returns:
221 |         Inception Score
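
    Example:
        A minimal sketch; `images` is assumed to be a float tensor of shape
        (N, 3, H, W) with values in [0, 1]:

            IS, IS_std = get_inception_score(images, splits=10)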
222 |     """
223 |     probs, = get_inception_feature(
224 |         images, dims=[1008], use_torch=use_torch, **kwargs)
225 |     if rank() != 0:
226 |         return (None, None)
227 |     inception_score, std = calculate_inception_score(probs, splits, use_torch)
228 |     return (inception_score, std)
229 |
230 |
231 | def get_inception_score_from_directory(
232 |     path: str,
233 |     splits: int = 10,
234 |     exts: List[str] = ['png', 'jpg'],
235 |     batch_size: int = 50,
236 |     use_torch: bool = False,
237 |     **kwargs
238 | ) -> Tuple[float, float]:
239 |     """Calculate Inception Score of images in a directory.
240 |
241 |     Args:
242 |         path: Path to the image directory. This function will recursively find
243 |             images in all subfolders.
244 |         splits: The number of splits used to compute the Inception Score.
245 |         exts: List of extensions to search for.
246 |         batch_size: Batch size of DataLoader.
247 |         use_torch: When True, use torch to calculate the Inception Score. Otherwise, use numpy.
248 |         **kwargs: The arguments passed to
249 |             `pytorch_image_generation_metrics.core.get_inception_feature`.
250 |
251 |
252 |     Returns:
253 |         Inception Score: float tuple, (mean, std)
254 |     """
255 |     return get_inception_score(
256 |         images=DataLoader(ImageDataset(path, exts), batch_size=batch_size),
257 |         splits=splits,
258 |         use_torch=use_torch,
259 |         **kwargs)
260 |
261 |
262 | def calc_fid_ref(
263 |     input_path: str,
264 |     output_path: Optional[str] = None,
265 |     batch_size: int = 50,
266 |     img_size: Optional[int] = None,
267 |     use_torch: bool = False,
268 |     num_workers: int = os.cpu_count(),
269 |     verbose: bool = True,
270 | ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
271 |     """Calculate the FID statistics and save them to a file.
272 |
273 |     Args:
274 |         input_path (str): Path to the image directory. This function will
275 |             recursively find images in all subfolders.
276 |         output_path (str): Path to the output file. Use None to disable saving.
277 |         batch_size (int): Batch size. Defaults to 50.
278 |         img_size (int): Image size. If None, use the original image size.
279 |         num_workers (int): Number of dataloader workers. Default:
280 |             os.cpu_count().
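
    Example:
        A minimal sketch; the paths below are hypothetical placeholders. The
        saved file is a compressed NumPy archive holding the feature mean
        `mu` of shape (2048,) and covariance `sigma` of shape (2048, 2048):

            calc_fid_ref('path/to/images', 'fid_refs/stats.npz', img_size=32)

            import numpy as np
            with np.load('fid_refs/stats.npz') as f:
                mu, sigma = f['mu'], f['sigma']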
281 |     """
282 |     if img_size is not None:
283 |         transform = Compose([Resize([img_size, img_size]), ToTensor()])
284 |     else:
285 |         transform = ToTensor()
286 |
287 |     dataset = ImageDataset(root=input_path, transform=transform)
288 |     if world_size() > 1:
289 |         sampler = DistributedSampler(dataset, shuffle=False)
290 |     else:
291 |         sampler = SequentialSampler(dataset)
292 |     loader = DataLoader(
293 |         dataset,
294 |         batch_size=batch_size,
295 |         sampler=sampler,
296 |         num_workers=num_workers)
297 |     acts, = get_inception_feature(
298 |         loader, dims=[2048], use_torch=use_torch, verbose=verbose)
299 |
300 |     if rank() != 0:
301 |         return
302 |
303 |     if use_torch:
304 |         mu = torch.mean(acts, dim=0).cpu().numpy()
305 |         sigma = torch_cov(acts, rowvar=False).cpu().numpy()
306 |     else:
307 |         mu = np.mean(acts, axis=0)
308 |         sigma = np.cov(acts, rowvar=False)
309 |
310 |     if output_path is not None:
311 |         if os.path.dirname(output_path) != "":
312 |             os.makedirs(os.path.dirname(output_path), exist_ok=True)
313 |         np.savez_compressed(output_path, mu=mu, sigma=sigma)
314 |
315 |     return mu, sigma
316 |
--------------------------------------------------------------------------------
/pytorch_image_generation_metrics/version.py:
--------------------------------------------------------------------------------
1 | """Version information for pytorch_image_generation_metrics."""
2 |
3 | __version__ = '0.6.1'
4 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | packaging
2 | tqdm
3 | scipy
4 | torch>=1.8.2
5 | torchvision>=0.9.2
6 |
7 | pydocstyle
8 | pytest
9 | pytest-order
10 | tox
11 | twine
12 | wheel
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | packaging
2 | tqdm
3 | scipy
4 | torch>=1.8.1
5 | torchvision>=0.9.1
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Install the package."""
2 |
3 | import os
4 | import setuptools
5 |
6 |
7 | def read(rel_path):
8 |     """Read a file."""
9 |     base_path = os.path.abspath(os.path.dirname(__file__))
10 |     with open(os.path.join(base_path, rel_path), 'r') as f:
11 |         return f.read()
12 |
13 |
14 | if __name__ == '__main__':
15 |     # get __version__
16 |     with open('./pytorch_image_generation_metrics/version.py') as f:
17 |         exec(f.read())
18 |
19 |     setuptools.setup(
20 |         name='pytorch_image_generation_metrics',
21 |         version=__version__, # noqa: F821
22 |         author='Yi-Lun Wu',
23 |         author_email='w86763777@gmail.com',
24 |         description=('Package for calculating image generation metrics using PyTorch'),
25 |         long_description=read('README.md'),
26 |         long_description_content_type='text/markdown',
27 |         url='https://github.com/w86763777/pytorch_image_generation_metrics',
28 |         packages=setuptools.find_packages(include=['pytorch_image_generation_metrics']),
29 |         keywords=[
30 |             'PyTorch',
31 |             'Inception Score',
32 |             'IS',
33 |             'Frechet Inception Distance',
34 |             'FID'],
35 |         classifiers=[
36 |             'Programming Language :: Python :: 3',
37 |             'License :: OSI Approved :: Apache Software License',
38 |             'Operating System :: OS Independent',
39 |         ],
40 |         python_requires='>=3.6',
41 |         install_requires=[
42 |             'packaging', 'tqdm', 'scipy', 'torch>=1.8.1', 'torchvision>=0.9.1',
43 |         ],
44 |     )
45 |
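# A minimal install-and-check sketch for the setup above (assumes pip is
# available in the active environment):
#
#     pip install .
#     python -c "from pytorch_image_generation_metrics.version import __version__; print(__version__)"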
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/w86763777/pytorch-image-generation-metrics/a3f5353125b24e11d32a593841ab740cc5fc131e/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 | from packaging import version
5 |
6 | import pytest
7 | import torch
8 | from torchvision.datasets import CIFAR10
9 |
10 | from pytorch_image_generation_metrics.utils import ImageDataset
11 |
12 |
13 | TORCH_VERSION = version.parse(torch.__version__).base_version
14 | TEST_ROOT = os.environ.get('TEST_ROOT', './tests')
15 | TEST_NAME = TORCH_VERSION
16 | PATH_CIFAR10 = "/tmp/cifar10"
17 | PATH_CIFAR10_TRAIN = f"{PATH_CIFAR10}/train"
18 | PATH_CIFAR10_TEST = f"{PATH_CIFAR10}/test"
19 | PATH_CIFAR10_TRAIN_FID_REF_NP = f'{TEST_ROOT}/{TEST_NAME}/cifar10.train.npz'
20 | PATH_CIFAR10_TEST_FID_REF_NP = f'{TEST_ROOT}/{TEST_NAME}/cifar10.test.npz'
21 | PATH_CIFAR10_TRAIN_FID_REF_PT = f'{TEST_ROOT}/{TEST_NAME}/cifar10.train.pt.npz'
22 | PATH_CIFAR10_TEST_FID_REF_PT = f'{TEST_ROOT}/{TEST_NAME}/cifar10.test.pt.npz'
23 | NUM_WORKERS = int(os.environ.get('NUM_WORKERS', min(torch.get_num_threads(), 4)))
24 |
25 |
26 | def save_dataset(dataset, root):
27 |     os.makedirs(root, exist_ok=True)
28 |     for i, (x, _) in enumerate(dataset):
29 |         x.save(os.path.join(root, f'{i + 1}.png'))
30 |
31 |
32 | @pytest.fixture
33 | def cifar10_test():
34 |     dataset = CIFAR10(PATH_CIFAR10, train=False, download=True)
35 |     if len(glob.glob(os.path.join(PATH_CIFAR10_TEST, '*.png'))) != len(dataset):
36 |         save_dataset(dataset, root=PATH_CIFAR10_TEST)
37 |     return ImageDataset(PATH_CIFAR10_TEST)
38 |
39 |
40 | @pytest.fixture
41 | def cifar10_train():
42 |     dataset = CIFAR10(PATH_CIFAR10, train=True, download=True)
43 |     if len(glob.glob(os.path.join(PATH_CIFAR10_TRAIN, '*.png'))) != len(dataset):
44 |         save_dataset(dataset, root=PATH_CIFAR10_TRAIN)
45 |     return ImageDataset(PATH_CIFAR10_TRAIN)
46 |
47 |
48 | @pytest.fixture(autouse=True)
49 | def set_caplog(caplog):
50 |     caplog.set_level(logging.INFO)
51 |     return caplog
52 |
53 |
54 | def format_relative_error(name, value, expected):
55 |     return f'{name}: {value:.9f}, expected: {expected:.9f}, relative error: {abs(value - expected) / expected: .5f}'
--------------------------------------------------------------------------------
/tests/test_all_metrics.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytest
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | from pytorch_image_generation_metrics.utils import (
8 |     get_inception_score_and_fid,
9 |     get_inception_score_and_fid_from_directory,
10 | )
11 | from .conftest import (
12 |     PATH_CIFAR10_TEST,
13 |     PATH_CIFAR10_TRAIN_FID_REF_NP,
14 |     PATH_CIFAR10_TRAIN_FID_REF_PT,
15 |     NUM_WORKERS,
16 |     format_relative_error,
17 | )
18 | from .test_fid import NP_FID, PT_FID
19 | from .test_inception_score import NP_IS, NP_IS_STD, PT_IS, PT_IS_STD
20 |
21 |
22 | @pytest.mark.fid
23 | @pytest.mark.inception_score
24 | @pytest.mark.order(1)
25 | class TestAllMetrics:
26 |     @pytest.mark.parametrize("batch_size, fid_ref, use_torch, expected_is, expected_std, expected_fid", [
27 |         (50, PATH_CIFAR10_TRAIN_FID_REF_NP, False, NP_IS, NP_IS_STD, NP_FID),
28 |         (50, PATH_CIFAR10_TRAIN_FID_REF_PT, True, PT_IS, PT_IS_STD, PT_FID),
29 |     ])
30 |     def test_inception_score_and_fid_dataloader(
31 |         self,
32 |         cifar10_test,
33 |         batch_size,
34 |         fid_ref,
35 |         use_torch,
36 |         expected_is,
37 |         expected_std,
38 |         expected_fid
39 |     ):
40 |         loader = DataLoader(
41 |             cifar10_test, batch_size=batch_size, num_workers=NUM_WORKERS)
42 |         (IS, IS_std), FID = get_inception_score_and_fid(
43 |             loader, fid_ref, use_torch=use_torch)
44 |         logging.info(format_relative_error("IS", IS, expected_is))
45 |         logging.info(format_relative_error("IS_STD", IS_std, expected_std))
46 |         logging.info(format_relative_error("FID", FID, expected_fid))
47 |         assert torch.allclose(torch.tensor(IS), torch.tensor(expected_is), rtol=1e-2)
48 |         assert torch.allclose(torch.tensor(IS_std), torch.tensor(expected_std), rtol=1e-2)
49 |         assert torch.allclose(torch.tensor(FID), torch.tensor(expected_fid), rtol=1e-2)
50 |
51 |     @pytest.mark.parametrize("batch_size, fid_ref, use_torch, expected_is, expected_std, expected_fid", [
52 |         (50, PATH_CIFAR10_TRAIN_FID_REF_NP, False, NP_IS, NP_IS_STD, NP_FID),
53 |         (50, PATH_CIFAR10_TRAIN_FID_REF_PT, True, PT_IS, PT_IS_STD, PT_FID),
54 |     ])
55 |     def test_inception_score_and_fid_tensor(
56 |         self,
57 |         cifar10_test,
58 |         batch_size,
59 |         fid_ref,
60 |         use_torch,
61 |         expected_is,
62 |         expected_std,
63 |         expected_fid
64 |     ):
65 |         loader = DataLoader(
66 |             cifar10_test, batch_size=batch_size, num_workers=NUM_WORKERS)
67 |         images = torch.cat([batch_images for batch_images in loader], dim=0)
68 |         (IS, IS_std), FID = get_inception_score_and_fid(
69 |             images, fid_ref, use_torch=use_torch)
70 |         logging.info(format_relative_error("IS", IS, expected_is))
71 |         logging.info(format_relative_error("IS_STD", IS_std, expected_std))
72 |         logging.info(format_relative_error("FID", FID, expected_fid))
73 |         assert torch.allclose(torch.tensor(IS), torch.tensor(expected_is), rtol=1e-2)
74 |         assert torch.allclose(torch.tensor(IS_std), torch.tensor(expected_std), rtol=1e-2)
75 |         assert torch.allclose(torch.tensor(FID), torch.tensor(expected_fid), rtol=1e-2)
76 |
77 |     @pytest.mark.parametrize("batch_size, fid_ref, use_torch, expected_is, expected_std, expected_fid", [
78 |         (50, PATH_CIFAR10_TRAIN_FID_REF_NP, False, NP_IS, NP_IS_STD, NP_FID),
79 |         (50, PATH_CIFAR10_TRAIN_FID_REF_PT, True, PT_IS, PT_IS_STD, PT_FID),
80 |     ])
81 |     def test_inception_score_and_fid_from_directory(
82 |         self,
83 |         batch_size,
84 |         fid_ref,
85 |         use_torch,
86 |         expected_is,
87 |         expected_std,
88 |         expected_fid
89 |     ):
90 |         (IS, IS_std), FID = get_inception_score_and_fid_from_directory(
91 |             PATH_CIFAR10_TEST, fid_ref, batch_size=batch_size,
92 |             use_torch=use_torch)
93 |         logging.info(format_relative_error("IS", IS, expected_is))
94 |         logging.info(format_relative_error("IS_STD", IS_std, expected_std))
95 |         logging.info(format_relative_error("FID", FID, expected_fid))
96 |         assert torch.allclose(torch.tensor(IS), torch.tensor(expected_is), rtol=1e-2)
97 |         assert torch.allclose(torch.tensor(IS_std), torch.tensor(expected_std), rtol=1e-2)
98 |         assert torch.allclose(torch.tensor(FID), torch.tensor(expected_fid), rtol=1e-2)
99 |
--------------------------------------------------------------------------------
/tests/test_fid.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytest
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | from pytorch_image_generation_metrics.utils import (
9 |     get_fid,
10 |     get_fid_from_directory
11 | )
12 | from .conftest import (
13 |     PATH_CIFAR10_TEST,
14 |     PATH_CIFAR10_TRAIN_FID_REF_NP,
15 |     PATH_CIFAR10_TRAIN_FID_REF_PT,
16 |     NUM_WORKERS,
17 |     format_relative_error,
18 | )
19 |
20 |
21 | NP_FID = 3.1525318697637204
22 | PT_FID = 3.145660400390625 # torch==2.3.0
23 |
24 |
25 | @pytest.mark.fid
26 | @pytest.mark.order(1)
27 | class TestFID:
28 |     @pytest.mark.parametrize("batch_size, fid_ref, use_torch, expected_fid", [
29 |         (50, PATH_CIFAR10_TRAIN_FID_REF_NP, False, NP_FID),
30 |         (50, PATH_CIFAR10_TRAIN_FID_REF_PT, True, PT_FID),
31 |     ])
32 |     def test_fid_dataloader(
33 |         self,
34 |         cifar10_test,
35 |         batch_size,
36 |         fid_ref,
37 |         use_torch,
38 |         expected_fid
39 |     ):
40 |         loader = DataLoader(
41 |             cifar10_test, batch_size=batch_size, num_workers=NUM_WORKERS)
42 |         FID = get_fid(loader, fid_ref, use_torch=use_torch)
43 |         logging.info(format_relative_error("FID", FID, expected_fid))
44 |         assert torch.allclose(torch.tensor(FID), torch.tensor(expected_fid), rtol=1e-2)
45 |
46 |     @pytest.mark.parametrize("batch_size, fid_ref, use_torch, expected_fid", [
47 |         (50, PATH_CIFAR10_TRAIN_FID_REF_NP, False, NP_FID),
48 |         (50, PATH_CIFAR10_TRAIN_FID_REF_PT, True, PT_FID),
49 |     ])
50 |     def test_fid_tensor(
51 |         self,
52 |         cifar10_test,
53 |         batch_size,
54 |         fid_ref,
55 |         use_torch,
56 |         expected_fid
57 |     ):
58 |         loader = DataLoader(
59 |             cifar10_test, batch_size=batch_size, num_workers=NUM_WORKERS)
60 |         images = torch.cat([batch_images for batch_images in loader], dim=0)
61 |         FID = get_fid(images, fid_ref, use_torch=use_torch)
62 |         logging.info(format_relative_error("FID", FID, expected_fid))
63 |         assert torch.allclose(torch.tensor(FID), torch.tensor(expected_fid), rtol=1e-2)
64 |
65 |     @pytest.mark.parametrize("batch_size, fid_ref, use_torch, expected_fid", [
66 |         (50, PATH_CIFAR10_TRAIN_FID_REF_NP, False, NP_FID),
67 |         (50, PATH_CIFAR10_TRAIN_FID_REF_PT, True, PT_FID),
68 |     ])
69 |     def test_fid_from_directory(
70 |         self,
71 |         batch_size,
72 |         fid_ref,
73 |         use_torch,
74 |         expected_fid
75 |     ):
76 |         FID = get_fid_from_directory(
77 |             PATH_CIFAR10_TEST, fid_ref, batch_size=batch_size,
78 |             use_torch=use_torch)
79 |         logging.info(format_relative_error("FID", FID, expected_fid))
80 |         assert torch.allclose(torch.tensor(FID), torch.tensor(expected_fid), rtol=1e-2)
81 |
--------------------------------------------------------------------------------
/tests/test_fid_ref.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from pytorch_image_generation_metrics.fid_ref import calc_fid_ref
6 | from .conftest import (
7 |     PATH_CIFAR10_TRAIN,
8 |     PATH_CIFAR10_TEST,
9 |     PATH_CIFAR10_TRAIN_FID_REF_NP,
10 |     PATH_CIFAR10_TEST_FID_REF_NP,
11 |     PATH_CIFAR10_TRAIN_FID_REF_PT,
12 |     PATH_CIFAR10_TEST_FID_REF_PT,
13 |     NUM_WORKERS,
14 | )
15 |
16 |
17 | @pytest.mark.fid
18 | @pytest.mark.order(0)
19 | class TestFidRef:
20 |     @pytest.mark.parametrize("output_path,use_torch", [
21 |         (PATH_CIFAR10_TEST_FID_REF_NP, False),
22 |         (PATH_CIFAR10_TEST_FID_REF_PT, True),
23 |     ])
24 |     def test_cifar10_test_fid_ref(self, output_path, use_torch):
25 |         if not os.path.exists(output_path):
26 |             calc_fid_ref(
27 |                 PATH_CIFAR10_TEST, output_path, use_torch=use_torch,
28 |                 num_workers=NUM_WORKERS,
29 |                 verbose=False)
30 |
31 |     @pytest.mark.parametrize("output_path,use_torch", [
32 |         (PATH_CIFAR10_TRAIN_FID_REF_NP, False),
33 |         (PATH_CIFAR10_TRAIN_FID_REF_PT, True),
34 |     ])
35 |     def test_cifar10_train_fid_ref(self, output_path, use_torch):
36 |         if not os.path.exists(output_path):
37 |             calc_fid_ref(
38 |                 PATH_CIFAR10_TRAIN, output_path, use_torch=use_torch,
39 |                 num_workers=NUM_WORKERS,
40 |                 verbose=False)
41 |
--------------------------------------------------------------------------------
/tests/test_inception_score.py:
--------------------------------------------------------------------------------
1 | # from packaging import version
2 | import logging
3 |
4 | import pytest
5 | import torch
6 | from torch.utils.data import DataLoader
7 |
8 | from pytorch_image_generation_metrics.utils import (
9 |     get_inception_score,
10 |     get_inception_score_from_directory,
11 | )
12 | from .conftest import (
13 |     PATH_CIFAR10_TEST,
14 |     NUM_WORKERS,
15 |     format_relative_error,
16 | )
17 |
18 | NP_IS = 10.968601098
19 | NP_IS_STD = 0.193806868
20 | PT_IS = 10.968603134 # torch==2.3.0
21 | PT_IS_STD = 0.204290837 # torch==2.3.0
22 |
23 |
24 | @pytest.mark.inception_score
25 | @pytest.mark.order(1)
26 | class TestInceptionScore:
27 |     @pytest.mark.parametrize("batch_size, use_torch, expected_is, expected_std", [
28 |         (50, False, NP_IS, NP_IS_STD),
29 |         (50, True, PT_IS, PT_IS_STD),
30 |     ])
31 |     def test_inception_score_dataloader(
32 |         self,
33 |         cifar10_test,
34 |         batch_size,
35 |         use_torch,
36 |         expected_is,
37 |         expected_std
38 |     ):
39 |         loader = DataLoader(
40 |             cifar10_test, batch_size=batch_size, num_workers=NUM_WORKERS)
41 |         IS, IS_std = get_inception_score(loader, use_torch=use_torch)
42 |         logging.info(format_relative_error("IS", IS, expected_is))
43 |         logging.info(format_relative_error("IS_STD", IS_std, expected_std))
44 |         assert torch.allclose(torch.tensor(IS), torch.tensor(expected_is), rtol=1e-2)
45 |         assert torch.allclose(torch.tensor(IS_std), torch.tensor(expected_std), rtol=1e-2)
46 |
47 |     @pytest.mark.parametrize("batch_size, use_torch, expected_is, expected_std", [
48 |         (50, False, NP_IS, NP_IS_STD),
49 |         (50, True, PT_IS, PT_IS_STD),
50 |     ])
51 |     def test_inception_score_tensor(
52 |         self,
53 |         cifar10_test,
54 |         batch_size,
55 |         use_torch,
56 |         expected_is,
57 |         expected_std
58 |     ):
59 |         loader = DataLoader(
60 |             cifar10_test, batch_size=batch_size, num_workers=NUM_WORKERS)
61 |         images = torch.cat([batch_images for batch_images in loader], dim=0)
62 |         IS, IS_std = get_inception_score(images, use_torch=use_torch)
63 |         logging.info(format_relative_error("IS", IS, expected_is))
64 |         logging.info(format_relative_error("IS_STD", IS_std, expected_std))
65 |         assert torch.allclose(torch.tensor(IS), torch.tensor(expected_is), rtol=1e-2)
66 |         assert torch.allclose(torch.tensor(IS_std), torch.tensor(expected_std), rtol=1e-2)
67 |
68 |     @pytest.mark.parametrize("batch_size, use_torch, expected_is, expected_std", [
69 |         (50, False, NP_IS, NP_IS_STD),
70 |         (50, True, PT_IS, PT_IS_STD),
71 |     ])
72 |     def test_inception_score_from_directory(
73 |         self,
74 |         batch_size,
75 |         use_torch,
76 |         expected_is,
77 |         expected_std
78 |     ):
79 |         IS, IS_std = get_inception_score_from_directory(
80 |             PATH_CIFAR10_TEST, batch_size=batch_size, use_torch=use_torch)
81 |         logging.info(format_relative_error("IS", IS, expected_is))
82 |         logging.info(format_relative_error("IS_STD", IS_std, expected_std))
83 |         assert torch.allclose(torch.tensor(IS), torch.tensor(expected_is), rtol=1e-2)
84 |         assert torch.allclose(torch.tensor(IS_std), torch.tensor(expected_std), rtol=1e-2)
85 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | minversion = 6.0
3 | log_cli = True
4 | log_cli_level = INFO
5 | log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)
6 | markers =
7 |     inception_score
8 |     fid
9 |
10 | [tox]
11 | toxworkdir=/tmp/.tox
12 | envlist = flake8,pydocstyle,py39-pt{1131,230}
13 |
14 | [testenv]
15 | deps =
16 |     pytest==8.2.1
17 |     pytest-order==1.2.1
18 | setenv =
19 |     TEST_ROOT = ./tests
20 |     NUM_WORKERS = 4
21 | commands = pytest tests {posargs}
22 |
23 | [testenv:flake8]
24 | deps = flake8
25 | commands = flake8 --ignore=E501 pytorch_image_generation_metrics tests
26 |
27 | [testenv:pydocstyle]
28 | deps = pydocstyle
29 | commands = pydocstyle pytorch_image_generation_metrics
30 |
31 | [testenv:py39-pt1131]
32 | passenv = *
33 | install_command =
34 |     pip3 install {opts} {packages} --extra-index-url https://download.pytorch.org/whl/cu117
35 | deps =
36 |     -rrequirements.txt
37 |     {[testenv]deps}
38 |     torch==1.13.1+cu117
39 |     torchvision==0.14.1+cu117
40 |
41 | [testenv:py39-pt230]
42 | passenv = *
43 | install_command =
44 |     pip install {opts} {packages} --extra-index-url https://download.pytorch.org/whl/cu121
45 | deps =
46 |     -rrequirements.txt
47 |     {[testenv]deps}
48 |     torch==2.3.0
49 |     torchvision==0.18.0
50 |
--------------------------------------------------------------------------------
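A typical invocation given the envlist above (a sketch): `tox -e py39-pt230` runs
the test suite against torch 2.3.0, and the `{posargs}` hook lets a single marker
be selected, e.g. `tox -e py39-pt230 -- -m fid`.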