├── .gitignore ├── LICENSE ├── README.md ├── fid_score.py ├── kid_score.py └── models ├── inception.py ├── lenet.pth └── lenet.py /.gitignore: -------------------------------------------------------------------------------- 1 | mnist-test.npy 2 | cifar*.npy 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | .dmypy.json 116 | dmypy.json 117 | 118 | # Pyre type checker 119 | .pyre/ 120 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Metrics for Evaluating GANs (PyTorch)
2 | 
3 | The following GAN metrics are implemented:
4 | 
5 | 1. Fréchet Inception Distance (FID)
6 | 2. Kernel Inception Distance (KID)
7 | 
8 | 
9 | ## Usage
10 | 
11 | Requirements:
12 | - python3
13 | - pytorch
14 | - torchvision
15 | - numpy
16 | - scipy
17 | - scikit-learn
18 | - Pillow
19 | 
20 | To compute the FID or KID score between two datasets with features extracted from the Inception network:
21 | 
22 | * Ensure that you have saved both datasets as numpy files (`.npy`) in channels-first format, i.e. `(no. of images, channels, height, width)` (a preparation sketch is included at the end of this README)
23 | 
24 | ```
25 | python fid_score.py --true path/to/real/images.npy --fake path/to/gan/generated.npy --gpu ID
26 | ```
27 | ```
28 | python kid_score.py --true path/to/real/images.npy --fake path/to/gan/generated.npy --gpu ID
29 | ```
30 | 
31 | ### Using different layers for feature maps
32 | 
33 | In contrast to the official implementation, you can choose to use a different feature layer of the Inception network instead of the default `pool3` layer.
34 | As the lower-layer features still have spatial extent, the features are first global average pooled to a vector before estimating the mean and covariance.
35 | 
36 | This might be useful if the datasets you want to compare have fewer than the otherwise required 2048 images.
37 | Note that this changes the magnitude of the FID score, so you cannot compare it against scores calculated at another dimensionality.
38 | The resulting scores might also no longer correlate with visual quality.
39 | 
40 | You can select the dimensionality of features to use with the flag `--dims N`, where N is the dimensionality of features.
41 | The choices are:
42 | - 64: first max pooling features
43 | - 192: second max pooling features
44 | - 768: pre-aux classifier features
45 | - 2048: final average pooling features (this is the default)
46 | 
47 | ### MNIST
48 | 
49 | The repo also contains a LeNet (modified from [activatedgeek/LeNet-5](https://github.com/activatedgeek/LeNet-5)) pretrained on MNIST, which can be used for evaluating MNIST samples. Just set the model to LeNet using `--model lenet`.
50 | 
51 | ```
52 | python fid_score.py --true path/to/real/images.npy --fake path/to/gan/generated.npy --gpu ID --model lenet
53 | ```
54 | ```
55 | python kid_score.py --true path/to/real/images.npy --fake path/to/gan/generated.npy --gpu ID --model lenet
56 | ```
57 | 
58 | ## License for [mseitzer/pytorch-fid](https://github.com/mseitzer/pytorch-fid)
59 | 
60 | This implementation is licensed under the Apache License 2.0.
61 | 
62 | FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500)
63 | 
64 | The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0.
65 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR).
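### Preparing `.npy` inputs

The scripts load `.npy` inputs as-is and rescale them from `[-1, 1]` to `[0, 1]` (see `get_activations`), so the saved arrays should already be channels-first and normalised to `[-1, 1]`. The snippet below is a minimal sketch of one way to produce such a file; the paths and file names are placeholders, not part of this repo.

```
import glob

import numpy as np
from PIL import Image

# Hypothetical input: a folder of same-sized RGB images.
files = sorted(glob.glob('path/to/real/*.png'))
imgs = np.stack([np.asarray(Image.open(f).convert('RGB'), dtype=np.float32)
                 for f in files])            # (N, H, W, C), values in [0, 255]
imgs = imgs / 127.5 - 1.0                    # scale to [-1, 1]
imgs = imgs.transpose(0, 3, 1, 2)            # channels-first: (N, C, H, W)
np.save('path/to/real/images.npy', imgs)
```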
66 | 
--------------------------------------------------------------------------------
/fid_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs
3 | 
4 | The FID metric calculates the distance between two distributions of images.
5 | Typically, we have summary statistics (mean & covariance matrix) of one
6 | of these distributions, while the 2nd distribution is given by a GAN.
7 | 
8 | When run as a stand-alone program, it compares the distribution of a set
9 | of real images with the distributions of one or more sets of generated
10 | images, given as directories of PNG/JPEG files or as .npy arrays.
11 | 
12 | The FID is calculated by assuming that X_1 and X_2 are the activations of
13 | the pool_3 layer of the inception net for generated samples and real world
14 | samples respectively.
15 | 
16 | See --help to see further details.
17 | 
18 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
19 | of TensorFlow
20 | 
21 | Copyright 2018 Institute of Bioinformatics, JKU Linz
22 | 
23 | Licensed under the Apache License, Version 2.0 (the "License");
24 | you may not use this file except in compliance with the License.
25 | You may obtain a copy of the License at
26 | 
27 |    http://www.apache.org/licenses/LICENSE-2.0
28 | 
29 | Unless required by applicable law or agreed to in writing, software
30 | distributed under the License is distributed on an "AS IS" BASIS,
31 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32 | See the License for the specific language governing permissions and
33 | limitations under the License.
34 | """
35 | import os
36 | import pathlib
37 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
38 | 
39 | import numpy as np
40 | import torch
41 | from scipy import linalg
42 | from PIL import Image
43 | from torch.nn.functional import adaptive_avg_pool2d
44 | 
45 | try:
46 |     from tqdm import tqdm
47 | except ImportError:
48 |     # If tqdm is not available, provide a mock version of it
49 |     def tqdm(x): return x
50 | #from models import lenet
51 | from models.inception import InceptionV3
52 | from models.lenet import LeNet5
53 | 
54 | 
55 | def get_activations(files, model, batch_size=50, dims=2048,
56 |                     cuda=False, verbose=False):
57 |     """Calculates the activations of the pool_3 layer for all images.
58 | 
59 |     Params:
60 |     -- files      : List of image file paths
61 |     -- model      : Instance of inception model
62 |     -- batch_size : Batch size of images for the model to process at once.
63 |                     Make sure that the number of samples is a multiple of
64 |                     the batch size, otherwise some samples are ignored. This
65 |                     behavior is retained to match the original FID score
66 |                     implementation.
67 |     -- dims       : Dimensionality of features returned by Inception
68 |     -- cuda       : If set to True, use GPU
69 |     -- verbose    : If set to True, the number of calculated batches is
70 |                     reported.
71 |     Returns:
72 |     -- A numpy array of dimension (num images, dims) that contains the
73 |        activations of the given tensor when feeding inception with the
74 |        query tensor.
75 |     """
76 |     model.eval()
77 | 
78 |     is_numpy = isinstance(files[0], np.ndarray)
79 | 
80 |     if len(files) % batch_size != 0:
81 |         print(('Warning: number of images is not a multiple of the '
82 |                'batch size. Some samples are going to be ignored.'))
83 |     if batch_size > len(files):
84 |         print(('Warning: batch size is bigger than the data size. '
85 |                'Setting batch size to data size'))
86 |         batch_size = len(files)
87 | 
88 |     n_batches = len(files) // batch_size
89 |     n_used_imgs = n_batches * batch_size
90 | 
91 |     pred_arr = np.empty((n_used_imgs, dims))
92 | 
93 |     for i in tqdm(range(n_batches)):
94 |         if verbose:
95 |             print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)
96 |         start = i * batch_size
97 |         end = start + batch_size
98 |         if is_numpy:
99 |             images = np.copy(files[start:end]) + 1
100 |             images /= 2.
101 |         else:
102 |             images = [np.array(Image.open(str(f))) for f in files[start:end]]
103 |             images = np.stack(images).astype(np.float32) / 255.
104 |             # Reshape to (n_images, 3, height, width)
105 |             images = images.transpose((0, 3, 1, 2))
106 | 
107 |         batch = torch.from_numpy(images).type(torch.FloatTensor)
108 |         if cuda:
109 |             batch = batch.cuda()
110 | 
111 |         pred = model(batch)[0]
112 | 
113 |         # If model output is not scalar, apply global spatial average pooling.
114 |         # This happens if you choose a dimensionality not equal to 2048.
115 |         if pred.shape[2] != 1 or pred.shape[3] != 1:
116 |             pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
117 | 
118 |         pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
119 | 
120 |         if verbose:
121 |             print('done', np.min(images))
122 | 
123 |     return pred_arr
124 | 
125 | 
126 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
127 |     """Numpy implementation of the Frechet Distance.
128 |     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
129 |     and X_2 ~ N(mu_2, C_2) is
130 |     d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
131 | 
132 |     Stable version by Dougal J. Sutherland.
133 | 
134 |     Params:
135 |     -- mu1   : The sample mean over activations of a layer of the
136 |                inception net (as returned by calculate_activation_statistics)
137 |                for generated samples.
138 |     -- mu2   : The sample mean over activations, precalculated on a
139 |                representative data set.
140 |     -- sigma1: The covariance matrix over activations for generated samples.
141 |     -- sigma2: The covariance matrix over activations, precalculated on a
142 |                representative data set.
143 | 
144 |     Returns:
145 |     --   : The Frechet Distance.
146 |     """
147 | 
148 |     mu1 = np.atleast_1d(mu1)
149 |     mu2 = np.atleast_1d(mu2)
150 | 
151 |     sigma1 = np.atleast_2d(sigma1)
152 |     sigma2 = np.atleast_2d(sigma2)
153 | 
154 |     assert mu1.shape == mu2.shape, \
155 |         'Training and test mean vectors have different lengths'
156 |     assert sigma1.shape == sigma2.shape, \
157 |         'Training and test covariances have different dimensions'
158 | 
159 |     diff = mu1 - mu2
160 | 
161 |     # Product might be almost singular
162 |     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
163 |     if not np.isfinite(covmean).all():
164 |         msg = ('fid calculation produces singular product; '
165 |                'adding %s to diagonal of cov estimates') % eps
166 |         print(msg)
167 |         offset = np.eye(sigma1.shape[0]) * eps
168 |         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
169 | 
170 |     # Numerical error might give slight imaginary component
171 |     if np.iscomplexobj(covmean):
172 |         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
173 |             m = np.max(np.abs(covmean.imag))
174 |             raise ValueError('Imaginary component {}'.format(m))
175 |         covmean = covmean.real
176 | 
177 |     tr_covmean = np.trace(covmean)
178 | 
179 |     return (diff.dot(diff) + np.trace(sigma1) +
180 |             np.trace(sigma2) - 2 * tr_covmean)
181 | 
182 | 
183 | def calculate_activation_statistics(act):
184 |     """Calculation of the statistics used by the FID.
185 |     Params:
186 |     -- act : Numpy array of activations of shape (num images, dims),
187 |              e.g. as returned by get_activations (Inception features)
188 |              or extract_lenet_features (LeNet features).
189 | 
190 |     The mean vector and covariance matrix are estimated over the
191 |     first (sample) axis, i.e. over all images in the data set; a
192 |     larger number of samples gives a more reliable estimate of the
193 |     covariance (the README notes that 2048 images are otherwise
194 |     required for the default pool_3 features).
195 |     Returns:
196 |     -- mu    : The mean over samples of the activations of the pool_3 layer of
197 |                the inception model.
198 |     -- sigma : The covariance matrix of the activations of the pool_3 layer of
199 |                the inception model.
200 | """ 201 | mu = np.mean(act, axis=0) 202 | sigma = np.cov(act, rowvar=False) 203 | return mu, sigma 204 | 205 | 206 | def extract_lenet_features(imgs, net, cuda): 207 | net.eval() 208 | feats = [] 209 | imgs = imgs.reshape([-1, 100] + list(imgs.shape[1:])) 210 | if imgs[0].min() < -0.001: 211 | imgs = (imgs + 1)/2.0 212 | print(imgs.shape, imgs.min(), imgs.max()) 213 | if cuda: 214 | imgs = torch.from_numpy(imgs).cuda() 215 | else: 216 | imgs = torch.from_numpy(imgs) 217 | for i, images in enumerate(imgs): 218 | feats.append(net.extract_features(images).detach().cpu().numpy()) 219 | feats = np.vstack(feats) 220 | return feats 221 | 222 | 223 | def _compute_activations(path, model, batch_size, dims, cuda, model_type): 224 | if not type(path) == np.ndarray: 225 | import glob 226 | jpg = os.path.join(path, '*.jpg') 227 | png = os.path.join(path, '*.png') 228 | path = glob.glob(jpg) + glob.glob(png) 229 | if len(path) > 25000: 230 | import random 231 | random.shuffle(path) 232 | path = path[:25000] 233 | if model_type == 'inception': 234 | act = get_activations(path, model, batch_size, dims, cuda) 235 | elif model_type == 'lenet': 236 | act = extract_lenet_features(path, model, cuda) 237 | 238 | return act 239 | 240 | 241 | def calculate_fid_given_paths(paths, batch_size, cuda, dims, bootstrap=True, n_bootstraps=10, model_type='inception'): 242 | """Calculates the FID of two paths""" 243 | pths = [] 244 | for p in paths: 245 | if not os.path.exists(p): 246 | raise RuntimeError('Invalid path: %s' % p) 247 | if os.path.isdir(p): 248 | pths.append(p) 249 | elif p.endswith('.npy'): 250 | np_imgs = np.load(p) 251 | if np_imgs.shape[0] > 25000: 252 | np_imgs = np_imgs[:50000] 253 | pths.append(np_imgs) 254 | 255 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 256 | 257 | if model_type == 'inception': 258 | model = InceptionV3([block_idx]) 259 | elif model_type == 'lenet': 260 | model = LeNet5() 261 | model.load_state_dict(torch.load('./models/lenet.pth')) 262 | if cuda: 263 | model.cuda() 264 | 265 | act_true = _compute_activations(pths[0], model, batch_size, dims, cuda, model_type) 266 | n_bootstraps = n_bootstraps if bootstrap else 1 267 | pths = pths[1:] 268 | results = [] 269 | for j, pth in enumerate(pths): 270 | print(paths[j+1]) 271 | actj = _compute_activations(pth, model, batch_size, dims, cuda, model_type) 272 | fid_values = np.zeros((n_bootstraps)) 273 | with tqdm(range(n_bootstraps), desc='FID') as bar: 274 | for i in bar: 275 | act1_bs = act_true[np.random.choice(act_true.shape[0], act_true.shape[0], replace=True)] 276 | act2_bs = actj[np.random.choice(actj.shape[0], actj.shape[0], replace=True)] 277 | m1, s1 = calculate_activation_statistics(act1_bs) 278 | m2, s2 = calculate_activation_statistics(act2_bs) 279 | fid_values[i] = calculate_frechet_distance(m1, s1, m2, s2) 280 | bar.set_postfix({'mean': fid_values[:i+1].mean()}) 281 | results.append((paths[j+1], fid_values.mean(), fid_values.std())) 282 | return results 283 | 284 | 285 | if __name__ == '__main__': 286 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 287 | parser.add_argument('--true', type=str, required=True, 288 | help=('Path to the true images')) 289 | parser.add_argument('--fake', type=str, nargs='+', required=True, 290 | help=('Path to the generated images')) 291 | parser.add_argument('--batch-size', type=int, default=50, 292 | help='Batch size to use') 293 | parser.add_argument('--dims', type=int, default=2048, 294 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 295 | 
help=('Dimensionality of Inception features to use. '
296 |                               'By default, uses pool3 features'))
297 |     parser.add_argument('-c', '--gpu', default='', type=str,
298 |                         help='GPU to use (leave blank for CPU only)')
299 |     parser.add_argument('--model', default='inception', type=str,
300 |                         help='inception or lenet')
301 |     args = parser.parse_args()
302 |     print(args)
303 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
304 |     paths = [args.true] + args.fake
305 | 
306 |     results = calculate_fid_given_paths(paths, args.batch_size, args.gpu != '', args.dims, model_type=args.model)
307 |     for p, m, s in results:
308 |         print('FID (%s): %.2f (%.3f)' % (p, m, s))
309 | 
--------------------------------------------------------------------------------
/kid_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Calculates the Kernel Inception Distance (KID) to evaluate GANs
3 | """
4 | import os
5 | import pathlib
6 | import sys
7 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
8 | 
9 | import numpy as np
10 | import torch
11 | from sklearn.metrics.pairwise import polynomial_kernel
12 | from scipy import linalg
13 | from PIL import Image
14 | from torch.nn.functional import adaptive_avg_pool2d
15 | 
16 | try:
17 |     from tqdm import tqdm
18 | except ImportError:
19 |     # If tqdm is not available, provide a mock version of it
20 |     def tqdm(x): return x
21 | 
22 | from models.inception import InceptionV3
23 | from models.lenet import LeNet5
24 | 
25 | def get_activations(files, model, batch_size=50, dims=2048,
26 |                     cuda=False, verbose=False):
27 |     """Calculates the activations of the pool_3 layer for all images.
28 | 
29 |     Params:
30 |     -- files      : List of image file paths
31 |     -- model      : Instance of inception model
32 |     -- batch_size : Batch size of images for the model to process at once.
33 |                     Make sure that the number of samples is a multiple of
34 |                     the batch size, otherwise some samples are ignored. This
35 |                     behavior is retained to match the original FID score
36 |                     implementation.
37 |     -- dims       : Dimensionality of features returned by Inception
38 |     -- cuda       : If set to True, use GPU
39 |     -- verbose    : If set to True, the number of calculated batches is
40 |                     reported.
41 |     Returns:
42 |     -- A numpy array of dimension (num images, dims) that contains the
43 |        activations of the given tensor when feeding inception with the
44 |        query tensor.
45 |     """
46 |     model.eval()
47 | 
48 |     is_numpy = isinstance(files[0], np.ndarray)
49 | 
50 |     if len(files) % batch_size != 0:
51 |         print(('Warning: number of images is not a multiple of the '
52 |                'batch size. Some samples are going to be ignored.'))
53 |     if batch_size > len(files):
54 |         print(('Warning: batch size is bigger than the data size. '
55 |                'Setting batch size to data size'))
56 |         batch_size = len(files)
57 | 
58 |     n_batches = len(files) // batch_size
59 |     n_used_imgs = n_batches * batch_size
60 | 
61 |     pred_arr = np.empty((n_used_imgs, dims))
62 | 
63 |     for i in tqdm(range(n_batches)):
64 |         if verbose:
65 |             print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)
66 |         start = i * batch_size
67 |         end = start + batch_size
68 |         if is_numpy:
69 |             images = np.copy(files[start:end]) + 1
70 |             images /= 2.
71 |         else:
72 |             images = [np.array(Image.open(str(f))) for f in files[start:end]]
73 |             images = np.stack(images).astype(np.float32) / 255.
74 | # Reshape to (n_images, 3, height, width) 75 | images = images.transpose((0, 3, 1, 2)) 76 | 77 | batch = torch.from_numpy(images).type(torch.FloatTensor) 78 | if cuda: 79 | batch = batch.cuda() 80 | 81 | pred = model(batch)[0] 82 | 83 | # If model output is not scalar, apply global spatial average pooling. 84 | # This happens if you choose a dimensionality not equal 2048. 85 | if pred.shape[2] != 1 or pred.shape[3] != 1: 86 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 87 | 88 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 89 | 90 | if verbose: 91 | print('done', np.min(images)) 92 | 93 | return pred_arr 94 | 95 | 96 | def extract_lenet_features(imgs, net): 97 | net.eval() 98 | feats = [] 99 | imgs = imgs.reshape([-1, 100] + list(imgs.shape[1:])) 100 | if imgs[0].min() < -0.001: 101 | imgs = (imgs + 1)/2.0 102 | print(imgs.shape, imgs.min(), imgs.max()) 103 | imgs = torch.from_numpy(imgs) 104 | for i, images in enumerate(imgs): 105 | feats.append(net.extract_features(images).detach().cpu().numpy()) 106 | feats = np.vstack(feats) 107 | return feats 108 | 109 | 110 | def _compute_activations(path, model, batch_size, dims, cuda, model_type): 111 | if not type(path) == np.ndarray: 112 | import glob 113 | jpg = os.path.join(path, '*.jpg') 114 | png = os.path.join(path, '*.png') 115 | path = glob.glob(jpg) + glob.glob(png) 116 | if len(path) > 50000: 117 | import random 118 | random.shuffle(path) 119 | path = path[:50000] 120 | if model_type == 'inception': 121 | act = get_activations(path, model, batch_size, dims, cuda) 122 | elif model_type == 'lenet': 123 | act = extract_lenet_features(path, model) 124 | return act 125 | 126 | 127 | def calculate_kid_given_paths(paths, batch_size, cuda, dims, model_type='inception'): 128 | """Calculates the KID of two paths""" 129 | pths = [] 130 | for p in paths: 131 | if not os.path.exists(p): 132 | raise RuntimeError('Invalid path: %s' % p) 133 | if os.path.isdir(p): 134 | pths.append(p) 135 | elif p.endswith('.npy'): 136 | np_imgs = np.load(p) 137 | if np_imgs.shape[0] > 50000: np_imgs = np_imgs[np.random.permutation(np.arange(np_imgs.shape[0]))][:50000] 138 | pths.append(np_imgs) 139 | 140 | if model_type == 'inception': 141 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 142 | model = InceptionV3([block_idx]) 143 | elif model_type == 'lenet': 144 | model = LeNet5() 145 | model.load_state_dict(torch.load('./models/lenet.pth')) 146 | if cuda: 147 | model.cuda() 148 | 149 | act_true = _compute_activations(pths[0], model, batch_size, dims, cuda, model_type) 150 | pths = pths[1:] 151 | results = [] 152 | for j, pth in enumerate(pths): 153 | print(paths[j+1]) 154 | actj = _compute_activations(pth, model, batch_size, dims, cuda, model_type) 155 | kid_values = polynomial_mmd_averages(act_true, actj, n_subsets=100) 156 | results.append((paths[j+1], kid_values[0].mean(), kid_values[0].std())) 157 | return results 158 | 159 | def _sqn(arr): 160 | flat = np.ravel(arr) 161 | return flat.dot(flat) 162 | 163 | 164 | def polynomial_mmd_averages(codes_g, codes_r, n_subsets=50, subset_size=1000, 165 | ret_var=True, output=sys.stdout, **kernel_args): 166 | m = min(codes_g.shape[0], codes_r.shape[0]) 167 | mmds = np.zeros(n_subsets) 168 | if ret_var: 169 | vars = np.zeros(n_subsets) 170 | choice = np.random.choice 171 | 172 | with tqdm(range(n_subsets), desc='MMD', file=output) as bar: 173 | for i in bar: 174 | g = codes_g[choice(len(codes_g), subset_size, replace=False)] 175 | r = codes_r[choice(len(codes_r), subset_size, 
replace=False)] 176 | o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var) 177 | if ret_var: 178 | mmds[i], vars[i] = o 179 | else: 180 | mmds[i] = o 181 | bar.set_postfix({'mean': mmds[:i+1].mean()}) 182 | return (mmds, vars) if ret_var else mmds 183 | 184 | 185 | def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1, 186 | var_at_m=None, ret_var=True): 187 | # use k(x, y) = (gamma + coef0)^degree 188 | # default gamma is 1 / dim 189 | X = codes_g 190 | Y = codes_r 191 | 192 | K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0) 193 | K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0) 194 | K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0) 195 | 196 | return _mmd2_and_variance(K_XX, K_XY, K_YY, 197 | var_at_m=var_at_m, ret_var=ret_var) 198 | 199 | def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False, 200 | mmd_est='unbiased', block_size=1024, 201 | var_at_m=None, ret_var=True): 202 | # based on 203 | # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py 204 | # but changed to not compute the full kernel matrix at once 205 | m = K_XX.shape[0] 206 | assert K_XX.shape == (m, m) 207 | assert K_XY.shape == (m, m) 208 | assert K_YY.shape == (m, m) 209 | if var_at_m is None: 210 | var_at_m = m 211 | 212 | # Get the various sums of kernels that we'll use 213 | # Kts drop the diagonal, but we don't need to compute them explicitly 214 | if unit_diagonal: 215 | diag_X = diag_Y = 1 216 | sum_diag_X = sum_diag_Y = m 217 | sum_diag2_X = sum_diag2_Y = m 218 | else: 219 | diag_X = np.diagonal(K_XX) 220 | diag_Y = np.diagonal(K_YY) 221 | 222 | sum_diag_X = diag_X.sum() 223 | sum_diag_Y = diag_Y.sum() 224 | 225 | sum_diag2_X = _sqn(diag_X) 226 | sum_diag2_Y = _sqn(diag_Y) 227 | 228 | Kt_XX_sums = K_XX.sum(axis=1) - diag_X 229 | Kt_YY_sums = K_YY.sum(axis=1) - diag_Y 230 | K_XY_sums_0 = K_XY.sum(axis=0) 231 | K_XY_sums_1 = K_XY.sum(axis=1) 232 | 233 | Kt_XX_sum = Kt_XX_sums.sum() 234 | Kt_YY_sum = Kt_YY_sums.sum() 235 | K_XY_sum = K_XY_sums_0.sum() 236 | 237 | if mmd_est == 'biased': 238 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 239 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 240 | - 2 * K_XY_sum / (m * m)) 241 | else: 242 | assert mmd_est in {'unbiased', 'u-statistic'} 243 | mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1)) 244 | if mmd_est == 'unbiased': 245 | mmd2 -= 2 * K_XY_sum / (m * m) 246 | else: 247 | mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1)) 248 | 249 | if not ret_var: 250 | return mmd2 251 | 252 | Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X 253 | Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y 254 | K_XY_2_sum = _sqn(K_XY) 255 | 256 | dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1) 257 | dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0) 258 | 259 | m1 = m - 1 260 | m2 = m - 2 261 | zeta1_est = ( 262 | 1 / (m * m1 * m2) * ( 263 | _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum) 264 | - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) 265 | + 1 / (m * m * m1) * ( 266 | _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum) 267 | - 2 / m**4 * K_XY_sum**2 268 | - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) 269 | + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum 270 | ) 271 | zeta2_est = ( 272 | 1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum) 273 | - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) 274 | + 2 / (m * m) * K_XY_2_sum 275 | - 2 / m**4 * K_XY_sum**2 276 | - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) 277 | + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum 278 | ) 279 | var_est = (4 * 
(var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est 280 | + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est) 281 | 282 | return mmd2, var_est 283 | 284 | 285 | if __name__ == '__main__': 286 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 287 | parser.add_argument('--true', type=str, required=True, 288 | help=('Path to the true images')) 289 | parser.add_argument('--fake', type=str, nargs='+', required=True, 290 | help=('Path to the generated images')) 291 | parser.add_argument('--batch-size', type=int, default=50, 292 | help='Batch size to use') 293 | parser.add_argument('--dims', type=int, default=2048, 294 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 295 | help=('Dimensionality of Inception features to use. ' 296 | 'By default, uses pool3 features')) 297 | parser.add_argument('-c', '--gpu', default='', type=str, 298 | help='GPU to use (leave blank for CPU only)') 299 | parser.add_argument('--model', default='inception', type=str, 300 | help='inception or lenet') 301 | args = parser.parse_args() 302 | print(args) 303 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 304 | paths = [args.true] + args.fake 305 | 306 | results = calculate_kid_given_paths(paths, args.batch_size, args.gpu != '', args.dims, model_type=args.model) 307 | for p, m, s in results: 308 | print('KID (%s): %.3f (%.3f)' % (p, m, s)) 309 | -------------------------------------------------------------------------------- /models/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling featurs 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | 28 | Parameters 29 | ---------- 30 | output_blocks : list of int 31 | Indices of blocks to return features of. Possible values are: 32 | - 0: corresponds to output of first max pooling 33 | - 1: corresponds to output of second max pooling 34 | - 2: corresponds to output which is fed to aux classifier 35 | - 3: corresponds to output of final average pooling 36 | resize_input : bool 37 | If true, bilinearly resizes input to width and height 299 before 38 | feeding input to model. As the network without fully connected 39 | layers is fully convolutional, it should be able to handle inputs 40 | of arbitrary size, so resizing might not be strictly needed 41 | normalize_input : bool 42 | If true, scales the input from range (0, 1) to the range the 43 | pretrained Inception network expects, namely (-1, 1) 44 | requires_grad : bool 45 | If true, parameters of the model require gradient. 
Possibly useful 46 | for finetuning the network 47 | """ 48 | super(InceptionV3, self).__init__() 49 | 50 | self.resize_input = resize_input 51 | self.normalize_input = normalize_input 52 | self.output_blocks = sorted(output_blocks) 53 | self.last_needed_block = max(output_blocks) 54 | 55 | assert self.last_needed_block <= 3, \ 56 | 'Last possible output block index is 3' 57 | 58 | self.blocks = nn.ModuleList() 59 | 60 | inception = models.inception_v3(pretrained=True) 61 | 62 | # Block 0: input to maxpool1 63 | block0 = [ 64 | inception.Conv2d_1a_3x3, 65 | inception.Conv2d_2a_3x3, 66 | inception.Conv2d_2b_3x3, 67 | nn.MaxPool2d(kernel_size=3, stride=2) 68 | ] 69 | self.blocks.append(nn.Sequential(*block0)) 70 | 71 | # Block 1: maxpool1 to maxpool2 72 | if self.last_needed_block >= 1: 73 | block1 = [ 74 | inception.Conv2d_3b_1x1, 75 | inception.Conv2d_4a_3x3, 76 | nn.MaxPool2d(kernel_size=3, stride=2) 77 | ] 78 | self.blocks.append(nn.Sequential(*block1)) 79 | 80 | # Block 2: maxpool2 to aux classifier 81 | if self.last_needed_block >= 2: 82 | block2 = [ 83 | inception.Mixed_5b, 84 | inception.Mixed_5c, 85 | inception.Mixed_5d, 86 | inception.Mixed_6a, 87 | inception.Mixed_6b, 88 | inception.Mixed_6c, 89 | inception.Mixed_6d, 90 | inception.Mixed_6e, 91 | ] 92 | self.blocks.append(nn.Sequential(*block2)) 93 | 94 | # Block 3: aux classifier to final avgpool 95 | if self.last_needed_block >= 3: 96 | block3 = [ 97 | inception.Mixed_7a, 98 | inception.Mixed_7b, 99 | inception.Mixed_7c, 100 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 101 | ] 102 | self.blocks.append(nn.Sequential(*block3)) 103 | 104 | for param in self.parameters(): 105 | param.requires_grad = requires_grad 106 | 107 | def forward(self, inp): 108 | """Get Inception feature maps 109 | 110 | Parameters 111 | ---------- 112 | inp : torch.autograd.Variable 113 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 114 | range (0.0, 1.0) 115 | 116 | Returns 117 | ------- 118 | List of torch.autograd.Variable, corresponding to the selected output 119 | block, sorted ascending by index 120 | """ 121 | outp = [] 122 | x = inp 123 | 124 | if self.resize_input: 125 | x = F.interpolate(x, 126 | size=(299, 299), 127 | mode='bilinear', 128 | align_corners=False) 129 | 130 | if self.normalize_input: 131 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 132 | 133 | for idx, block in enumerate(self.blocks): 134 | x = block(x) 135 | if idx in self.output_blocks: 136 | outp.append(x) 137 | 138 | if idx == self.last_needed_block: 139 | break 140 | 141 | return outp 142 | -------------------------------------------------------------------------------- /models/lenet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abdulfatir/gan-metrics-pytorch/4537f97a9e0be0ab8bd86f8a9e3ac6274597e0e3/models/lenet.pth -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from collections import OrderedDict 3 | 4 | 5 | class LeNet5(nn.Module): 6 | """ 7 | Input - 1x32x32 8 | C1 - 6@28x28 (5x5 kernel) 9 | tanh 10 | S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling 11 | C3 - 16@10x10 (5x5 kernel, complicated shit) 12 | tanh 13 | S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling 14 | C5 - 120@1x1 (5x5 kernel) 15 | F6 - 84 16 | tanh 17 | F7 - 10 (Output) 18 | """ 19 | def __init__(self): 20 | super(LeNet5, self).__init__() 21 | 22 | self.convnet = nn.Sequential(OrderedDict([ 23 | ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))), 24 | ('tanh1', nn.Tanh()), 25 | ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=1)), 26 | ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))), 27 | ('tanh3', nn.Tanh()), 28 | ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=1)), 29 | ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))), 30 | ('tanh5', nn.Tanh()) 31 | ])) 32 | 33 | self.fc = nn.Sequential(OrderedDict([ 34 | ('f6', nn.Linear(120, 84)), 35 | ('tanh6', nn.Tanh()), 36 | ('f7', nn.Linear(84, 10)), 37 | ('sig7', nn.LogSoftmax(dim=-1)) 38 | ])) 39 | 40 | def forward(self, img): 41 | output = self.convnet(img) 42 | output = output.view(img.size(0), -1) 43 | output = self.fc(output) 44 | return output 45 | 46 | def extract_features(self, img): 47 | output = self.convnet(img.float()) 48 | output = output.view(img.size(0), -1) 49 | output = self.fc[1](self.fc[0](output)) 50 | return output 51 | --------------------------------------------------------------------------------
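A usage note for the LeNet/MNIST path (added for illustration; not a file of this repo): `extract_lenet_features` in `fid_score.py` and `kid_score.py` reshapes its input into chunks of 100 images, so the number of images should be a multiple of 100, and the array is expected in channels-first shape `(N, 1, 28, 28)` with values in `[-1, 1]` or `[0, 1]`. The `LeNet5` docstring above shows the classic 1x32x32 layout, but with the padded max-pool layers used here the convolutional stack yields the 120 features expected by `f6` for 28x28 MNIST inputs. A minimal sketch of preparing such an array, assuming a hypothetical `(N, 28, 28)` uint8 file named `mnist.npy`:

```
import numpy as np

digits = np.load('mnist.npy')                         # hypothetical (N, 28, 28) uint8 MNIST digits
n = (digits.shape[0] // 100) * 100                    # extract_lenet_features batches the array in chunks of 100
imgs = digits[:n].astype(np.float32) / 127.5 - 1.0    # scale to [-1, 1]
imgs = imgs[:, None, :, :]                            # add channel axis -> (n, 1, 28, 28)
np.save('mnist-test.npy', imgs)                       # e.g.: python fid_score.py --true mnist-test.npy ... --model lenet
```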