├── .gitignore ├── GANModels.py ├── JumpingGAN_Train.py ├── LICENSE ├── LoadRealRunningJumping.py ├── LoadSyntheticRunningJumping.py ├── README.md ├── Running&JumpingVisualization.ipynb ├── RunningGAN_Train.py ├── adamw.py ├── cfg.py ├── dataLoader.py ├── functions.py ├── images ├── PositionalEncoding.pdf ├── PositionalEncoding.png ├── TTS-GAN.pdf └── TTS-GAN.png ├── pre-trained-models ├── JumpingGAN_checkpoint └── RunningGAN_checkpoint ├── train_GAN.py ├── utils ├── __init__.py ├── cal_fid_stat.py ├── fid_score.py ├── inception.py ├── inception_model.py ├── inception_score.py ├── torch_fid_score.py └── utils.py └── visualizationMetrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /GANModels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | import math 6 | import numpy as np 7 | 8 | from torchvision.transforms import Compose, Resize, ToTensor 9 | from einops import rearrange, reduce, repeat 10 | from einops.layers.torch import Rearrange, Reduce 11 | from torchsummary import summary 12 | 13 | 14 | class Generator(nn.Module): 15 | def __init__(self, seq_len=150, patch_size=15, channels=3, num_classes=9, latent_dim=100, embed_dim=10, depth=3, 16 | num_heads=5, forward_drop_rate=0.5, attn_drop_rate=0.5): 17 | super(Generator, self).__init__() 18 | self.channels = channels 19 | self.latent_dim = latent_dim 20 | self.seq_len = seq_len 21 | self.embed_dim = embed_dim 22 | self.patch_size = patch_size 23 | self.depth = depth 24 | self.attn_drop_rate = attn_drop_rate 25 | self.forward_drop_rate = forward_drop_rate 26 | 27 | self.l1 = nn.Linear(self.latent_dim, self.seq_len * self.embed_dim) 28 | self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, self.embed_dim)) 29 | self.blocks = Gen_TransformerEncoder( 30 | depth=self.depth, 31 | emb_size = self.embed_dim, 32 | drop_p = self.attn_drop_rate, 33 | forward_drop_p=self.forward_drop_rate 34 | ) 35 | 36 | self.deconv = nn.Sequential( 37 | nn.Conv2d(self.embed_dim, self.channels, 1, 1, 0) 38 | ) 39 | 40 | def forward(self, z): 41 | x = self.l1(z).view(-1, self.seq_len, self.embed_dim) 42 | x = x + self.pos_embed 43 | H, W = 1, self.seq_len 44 | x = self.blocks(x) 45 | x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2]) 46 | output = self.deconv(x.permute(0, 3, 1, 2)) 47 | output = output.view(-1, self.channels, H, W) 48 | return output 49 | 50 | 51 | class Gen_TransformerEncoderBlock(nn.Sequential): 52 | def __init__(self, 53 | emb_size, 54 | num_heads=5, 55 | drop_p=0.5, 56 | forward_expansion=4, 57 | forward_drop_p=0.5): 58 | super().__init__( 59 | ResidualAdd(nn.Sequential( 60 | nn.LayerNorm(emb_size), 61 | MultiHeadAttention(emb_size, num_heads, drop_p), 62 | nn.Dropout(drop_p) 63 | )), 64 | ResidualAdd(nn.Sequential( 65 | nn.LayerNorm(emb_size), 66 | FeedForwardBlock( 67 | emb_size, expansion=forward_expansion, drop_p=forward_drop_p), 68 | nn.Dropout(drop_p) 69 | ) 70 | )) 71 | 72 | 73 | class Gen_TransformerEncoder(nn.Sequential): 74 | def __init__(self, depth=8, **kwargs): 75 | super().__init__(*[Gen_TransformerEncoderBlock(**kwargs) for _ in range(depth)]) 76 | 77 | 78 | class MultiHeadAttention(nn.Module): 79 | def __init__(self, emb_size, num_heads, dropout): 80 | super().__init__() 81 | self.emb_size = emb_size 82 | self.num_heads = num_heads 83 | self.keys = nn.Linear(emb_size, emb_size) 84 | self.queries = nn.Linear(emb_size, emb_size) 85 | self.values = nn.Linear(emb_size, emb_size) 86 | self.att_drop 
= nn.Dropout(dropout) 87 | self.projection = nn.Linear(emb_size, emb_size) 88 | 89 | def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: 90 | queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads) 91 | keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads) 92 | values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) 93 | energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len 94 | if mask is not None: 95 | fill_value = torch.finfo(torch.float32).min 96 | energy.mask_fill(~mask, fill_value) 97 | 98 | scaling = self.emb_size ** (1 / 2) 99 | att = F.softmax(energy / scaling, dim=-1) 100 | att = self.att_drop(att) 101 | out = torch.einsum('bhal, bhlv -> bhav ', att, values) 102 | out = rearrange(out, "b h n d -> b n (h d)") 103 | out = self.projection(out) 104 | return out 105 | 106 | 107 | class ResidualAdd(nn.Module): 108 | def __init__(self, fn): 109 | super().__init__() 110 | self.fn = fn 111 | 112 | def forward(self, x, **kwargs): 113 | res = x 114 | x = self.fn(x, **kwargs) 115 | x += res 116 | return x 117 | 118 | 119 | class FeedForwardBlock(nn.Sequential): 120 | def __init__(self, emb_size, expansion, drop_p): 121 | super().__init__( 122 | nn.Linear(emb_size, expansion * emb_size), 123 | nn.GELU(), 124 | nn.Dropout(drop_p), 125 | nn.Linear(expansion * emb_size, emb_size), 126 | ) 127 | 128 | 129 | 130 | class Dis_TransformerEncoderBlock(nn.Sequential): 131 | def __init__(self, 132 | emb_size=100, 133 | num_heads=5, 134 | drop_p=0., 135 | forward_expansion=4, 136 | forward_drop_p=0.): 137 | super().__init__( 138 | ResidualAdd(nn.Sequential( 139 | nn.LayerNorm(emb_size), 140 | MultiHeadAttention(emb_size, num_heads, drop_p), 141 | nn.Dropout(drop_p) 142 | )), 143 | ResidualAdd(nn.Sequential( 144 | nn.LayerNorm(emb_size), 145 | FeedForwardBlock( 146 | emb_size, expansion=forward_expansion, drop_p=forward_drop_p), 147 | nn.Dropout(drop_p) 148 | ) 149 | )) 150 | 151 | 152 | class Dis_TransformerEncoder(nn.Sequential): 153 | def __init__(self, depth=8, **kwargs): 154 | super().__init__(*[Dis_TransformerEncoderBlock(**kwargs) for _ in range(depth)]) 155 | 156 | 157 | class ClassificationHead(nn.Sequential): 158 | def __init__(self, emb_size=100, n_classes=2): 159 | super().__init__() 160 | self.clshead = nn.Sequential( 161 | Reduce('b n e -> b e', reduction='mean'), 162 | nn.LayerNorm(emb_size), 163 | nn.Linear(emb_size, n_classes) 164 | ) 165 | 166 | def forward(self, x): 167 | out = self.clshead(x) 168 | return out 169 | 170 | 171 | class PatchEmbedding_Linear(nn.Module): 172 | #what are the proper parameters set here? 
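    # Shape walk-through, assuming the defaults that the Discriminator below passes in
    # (in_channels=3, patch_size=15, emb_size=50, seq_length=150):
    #   x: (B, 3, 1, 150) --Rearrange--> (B, 150/15, 15*3) = (B, 10, 45)
    #      --nn.Linear(45, 50)--> (B, 10, 50)
    #      --prepend cls_token--> (B, 11, 50), then the learned positions (11, 50) are added.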
173 | def __init__(self, in_channels = 21, patch_size = 16, emb_size = 100, seq_length = 1024): 174 | # self.patch_size = patch_size 175 | super().__init__() 176 | #change the conv2d parameters here 177 | self.projection = nn.Sequential( 178 | Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)',s1 = 1, s2 = patch_size), 179 | nn.Linear(patch_size*in_channels, emb_size) 180 | ) 181 | self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size)) 182 | self.positions = nn.Parameter(torch.randn((seq_length // patch_size) + 1, emb_size)) 183 | 184 | def forward(self, x: Tensor) -> Tensor: 185 | b, _, _, _ = x.shape 186 | x = self.projection(x) 187 | cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b) 188 | #prepend the cls token to the input 189 | x = torch.cat([cls_tokens, x], dim=1) 190 | # position 191 | x += self.positions 192 | return x 193 | 194 | 195 | class Discriminator(nn.Sequential): 196 | def __init__(self, 197 | in_channels=3, 198 | patch_size=15, 199 | emb_size=50, 200 | seq_length = 150, 201 | depth=3, 202 | n_classes=1, 203 | **kwargs): 204 | super().__init__( 205 | PatchEmbedding_Linear(in_channels, patch_size, emb_size, seq_length), 206 | Dis_TransformerEncoder(depth, emb_size=emb_size, drop_p=0.5, forward_drop_p=0.5, **kwargs), 207 | ClassificationHead(emb_size, n_classes) 208 | ) 209 | 210 | -------------------------------------------------------------------------------- /JumpingGAN_Train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | import os 4 | import argparse 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--rank', type=str, default="0") 9 | parser.add_argument('--node', type=str, default="0015") 10 | opt = parser.parse_args() 11 | 12 | return opt 13 | args = parse_args() 14 | 15 | os.system(f"CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_GAN.py \ 16 | -gen_bs 16 \ 17 | -dis_bs 16 \ 18 | --dist-url 'tcp://localhost:4321' \ 19 | --dist-backend 'nccl' \ 20 | --world-size 1 \ 21 | --rank {args.rank} \ 22 | --dataset UniMiB \ 23 | --bottom_width 8 \ 24 | --max_iter 500000 \ 25 | --img_size 32 \ 26 | --gen_model my_gen \ 27 | --dis_model my_dis \ 28 | --df_dim 384 \ 29 | --d_heads 4 \ 30 | --d_depth 3 \ 31 | --g_depth 5,4,2 \ 32 | --dropout 0 \ 33 | --latent_dim 100 \ 34 | --gf_dim 1024 \ 35 | --num_workers 16 \ 36 | --g_lr 0.0001 \ 37 | --d_lr 0.0003 \ 38 | --optimizer adam \ 39 | --loss lsgan \ 40 | --wd 1e-3 \ 41 | --beta1 0.9 \ 42 | --beta2 0.999 \ 43 | --phi 1 \ 44 | --batch_size 16 \ 45 | --num_eval_imgs 50000 \ 46 | --init_type xavier_uniform \ 47 | --n_critic 1 \ 48 | --val_freq 20 \ 49 | --print_freq 50 \ 50 | --grow_steps 0 0 \ 51 | --fade_in 0 \ 52 | --patch_size 2 \ 53 | --ema_kimg 500 \ 54 | --ema_warmup 0.1 \ 55 | --ema 0.9999 \ 56 | --diff_aug translation,cutout,color \ 57 | --class_name Jumping \ 58 | --exp_name Jumping") -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LoadRealRunningJumping.py: -------------------------------------------------------------------------------- 1 | #A binary classification dataset, Jumping or Running 2 | 3 | 4 | import os 5 | import shutil #https://docs.python.org/3/library/shutil.html 6 | from shutil import unpack_archive # to unzip 7 | #from shutil import make_archive # to create zip for storage 8 | import requests #for downloading zip file 9 | from scipy import io #for loadmat, matlab conversion 10 | import pandas as pd 11 | import numpy as np 12 | #import matplotlib.pyplot as plt # for plotting - pandas uses matplotlib 13 | from tabulate import tabulate # for verbose tables 14 | #from tensorflow.keras.utils import to_categorical # for one-hot encoding 15 | 16 | #credit https://stackoverflow.com/questions/9419162/download-returned-zip-file-from-url 17 | #many other methods I tried failed to download the file properly 18 | from torch.utils.data import Dataset, DataLoader 19 | 20 | class_dict = {'StandingUpFS':0,'StandingUpFL':1,'Walking':2,'Running':3,'GoingUpS':4,'Jumping':5,'GoingDownS':6,'LyingDownFS':7,'SittingDown':8} 21 | 22 | class Running_Or_Jumping(Dataset): 23 | def __init__(self, 24 | incl_xyz_accel = False, #include component accel_x/y/z in ____X data 25 | incl_rms_accel = True, #add rms value (total accel) of accel_x/y/z in ____X data 26 | is_normalize = False, 27 | split_subj = dict 28 | (train_subj = [4,5,6,7,8,10,11,12,14,15,19,20,21,22,24,26,27,29,1,9,16,23,25,28], 29 | test_subj = [2,3,13,17,18,30]), 30 | data_mode = 'Train'): 31 | 32 | self.incl_xyz_accel = incl_xyz_accel 33 | self.incl_rms_accel = incl_rms_accel 34 | self.split_subj = split_subj 35 | self.data_mode = data_mode 36 | self.is_normalize = is_normalize 37 | 38 | #Download and unzip original dataset 39 | if (not os.path.isfile('./UniMiB-SHAR.zip')): 40 | print("Downloading UniMiB-SHAR.zip file") 41 | #invoking the shell command fails when exported to .py file 42 | #redirect link https://www.dropbox.com/s/raw/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip 43 | #!wget https://www.dropbox.com/s/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip 44 | self.download_url('https://www.dropbox.com/s/raw/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip','./UniMiB-SHAR.zip') 45 | if (not os.path.isdir('./UniMiB-SHAR')): 46 | shutil.unpack_archive('./UniMiB-SHAR.zip','.','zip') 47 | #Convert .mat files to numpy ndarrays 48 | path_in = './UniMiB-SHAR/data' 49 | #loadmat loads matlab files as dictionary, keys: header, version, globals, data 50 | adl_data = io.loadmat(path_in + '/adl_data.mat')['adl_data'] 51 | adl_names = io.loadmat(path_in + '/adl_names.mat', chars_as_strings=True)['adl_names'] 52 | adl_labels = io.loadmat(path_in + '/adl_labels.mat')['adl_labels'] 53 | 54 | #Reshape data and compute total (rms) acceleration 55 | num_samples = 151 56 | #UniMiB SHAR has fixed size of 453 which is 151 accelX, 151 accely, 151 accelz 57 | adl_data = np.reshape(adl_data,(-1,num_samples,3), order='F') #uses Fortran order 58 | if (self.incl_rms_accel): 59 | rms_accel = 
np.sqrt((adl_data[:,:,0]**2) + (adl_data[:,:,1]**2) + (adl_data[:,:,2]**2)) 60 | adl_data = np.dstack((adl_data,rms_accel)) 61 | #remove component accel if needed 62 | if (not self.incl_xyz_accel): 63 | adl_data = np.delete(adl_data, [0,1,2], 2) 64 | 65 | #Split train/test sets, combine or make separate validation set 66 | #ref for this numpy gymnastics - find index of matching subject to sub_train/sub_test/sub_validate 67 | #https://numpy.org/doc/stable/reference/generated/numpy.isin.html 68 | 69 | 70 | act_num = (adl_labels[:,0])-1 #matlab source was 1 indexed, change to 0 indexed 71 | sub_num = (adl_labels[:,1]) #subject numbers are in column 1 of labels 72 | 73 | 74 | train_index = np.nonzero(np.isin(sub_num, self.split_subj['train_subj'])) 75 | x_train = adl_data[train_index] 76 | y_train = act_num[train_index] 77 | 78 | test_index = np.nonzero(np.isin(sub_num, self.split_subj['test_subj'])) 79 | x_test = adl_data[test_index] 80 | y_test = act_num[test_index] 81 | 82 | self.x_train = np.transpose(x_train, (0, 2, 1)) 83 | self.x_train = self.x_train.reshape(self.x_train.shape[0], self.x_train.shape[1], 1, self.x_train.shape[2]) 84 | self.x_train = self.x_train[:,:,:,:-1] 85 | self.y_train = y_train 86 | 87 | self.x_test = np.transpose(x_test, (0, 2, 1)) 88 | self.x_test = self.x_test.reshape(self.x_test.shape[0], self.x_test.shape[1], 1, self.x_test.shape[2]) 89 | self.x_test = self.x_test[:,:,:,:-1] 90 | self.y_test = y_test 91 | 92 | if self.is_normalize: 93 | self.x_train = self.normalization(self.x_train) 94 | self.x_test = self.normalization(self.x_test) 95 | 96 | #Select running and jumping data 97 | #Label running as 0 and jumping as 1 98 | 99 | Jumping_train_data = [] 100 | Running_train_data = [] 101 | Jumping_test_data = [] 102 | Running_test_data = [] 103 | 104 | 105 | for i, label in enumerate(y_train): 106 | if label == class_dict['Running']: 107 | Running_train_data.append(self.x_train[i]) 108 | elif label == class_dict['Jumping']: 109 | Jumping_train_data.append(self.x_train[i]) 110 | else: 111 | continue 112 | 113 | for i, label in enumerate(y_test): 114 | if label == class_dict['Running']: 115 | Running_test_data.append(self.x_test[i]) 116 | elif label == class_dict['Jumping']: 117 | Jumping_test_data.append(self.x_test[i]) 118 | else: 119 | continue 120 | 121 | self.Jumping_train_labels = np.ones(len(Jumping_train_data)) 122 | self.Jumping_test_labels = np.ones(len(Jumping_test_data)) 123 | self.Running_train_labels = np.zeros(len(Running_train_data)) 124 | self.Running_test_labels = np.zeros(len(Running_test_data)) 125 | 126 | self.Jumping_train_data = np.array(Jumping_train_data) 127 | self.Running_train_data = np.array(Running_train_data) 128 | self.Jumping_test_data = np.array(Jumping_test_data) 129 | self.Running_test_data = np.array(Running_test_data) 130 | 131 | 132 | #Crop Running to only 600 samples 133 | self.Running_train_data = self.Running_train_data[:600][:][:][:] 134 | self.Running_train_labels = self.Running_train_labels[:600] 135 | 136 | self.Running_test_data = self.Running_test_data[:146][:][:][:] 137 | self.Running_test_labels = self.Running_test_labels[:146] 138 | 139 | self.combined_train_data = np.concatenate((self.Jumping_train_data, self.Running_train_data), axis=0) 140 | self.combined_test_data = np.concatenate((self.Jumping_test_data, self.Running_test_data), axis=0) 141 | 142 | self.combined_train_label = np.concatenate((self.Jumping_train_labels, self.Running_train_labels), axis=0) 143 | self.combined_train_label = 
self.combined_train_label.reshape(self.combined_train_label.shape[0], 1) 144 | 145 | self.combined_test_label = np.concatenate((self.Jumping_test_labels, self.Running_test_labels), axis=0) 146 | self.combined_test_label = self.combined_test_label.reshape(self.combined_test_label.shape[0], 1) 147 | 148 | if self.data_mode == 'Train': 149 | print(f'data shape is {self.combined_train_data.shape}, label shape is {self.combined_train_label.shape}') 150 | print(f'Jumping label is 1, has {len(self.Jumping_train_labels)} samples, Running label is 0, has {len(self.Running_train_labels)} samples') 151 | else: 152 | print(f'data shape is {self.combined_test_data.shape}, label shape is {self.combined_test_label.shape}') 153 | print(f'Jumping label is 1, has {len(self.Jumping_test_labels)} samples, Running label is 0, has {len(self.Running_test_labels)} samples') 154 | 155 | 156 | def download_url(self, url, save_path, chunk_size=128): 157 | r = requests.get(url, stream=True) 158 | with open(save_path, 'wb') as fd: 159 | for chunk in r.iter_content(chunk_size=chunk_size): 160 | fd.write(chunk) 161 | 162 | def to_categorical(self, y, num_classes): 163 | """ 1-hot encodes a tensor """ 164 | return np.eye(num_classes, dtype='uint8')[y] 165 | 166 | 167 | def _normalize(self, epoch): 168 | """ A helper method for the normalization method. 169 | Returns 170 | result: a normalized epoch 171 | """ 172 | e = 1e-10 173 | result = (epoch - epoch.mean(axis=0)) / ((np.sqrt(epoch.var(axis=0)))+e) 174 | return result 175 | 176 | def _min_max_normalize(self, epoch): 177 | 178 | result = (epoch - min(epoch)) / (max(epoch) - min(epoch)) 179 | return result 180 | 181 | def normalization(self, epochs): 182 | """ Normalizes each epoch e s.t mean(e) = 0 and var(e) = 1 183 | Args: 184 | epochs - Numpy structure of epochs 185 | Returns: 186 | epochs_n - mne data structure of normalized epochs (mean=0, var=1) 187 | """ 188 | for i in range(epochs.shape[0]): 189 | for j in range(epochs.shape[1]): 190 | epochs[i,j,0,:] = self._normalize(epochs[i,j,0,:]) 191 | # epochs[i,j,0,:] = self._min_max_normalize(epochs[i,j,0,:]) 192 | 193 | return epochs 194 | 195 | 196 | def __len__(self): 197 | 198 | if self.data_mode == 'Train': 199 | return len(self.combined_train_label) 200 | else: 201 | return len(self.combined_test_label) 202 | 203 | def __getitem__(self, idx): 204 | 205 | if self.data_mode == 'Train': 206 | return self.combined_train_data[idx], self.combined_train_label[idx] 207 | else: 208 | return self.combined_test_data[idx], self.combined_test_label[idx] 209 | 210 | def collate_fn(self): 211 | pass 212 | 213 | -------------------------------------------------------------------------------- /LoadSyntheticRunningJumping.py: -------------------------------------------------------------------------------- 1 | # Generator synthetic Running and Jumping data 2 | # Made them to a Pytorch Dataset 3 | 4 | from torch.utils.data import Dataset, DataLoader 5 | import torch 6 | from GANModels import * 7 | import numpy as np 8 | import os 9 | 10 | class Synthetic_Dataset(Dataset): 11 | def __init__(self, 12 | Jumping_model_path = './pre-trained-models/JumpingGAN_checkpoint', 13 | Running_model_path = './pre-trained-models/RunningGAN_checkpoint', 14 | sample_size = 1000 15 | ): 16 | 17 | self.sample_size = sample_size 18 | 19 | #Generate Running Data 20 | running_gen_net = Generator(seq_len=150, channels=3, latent_dim=100) 21 | running_ckp = torch.load(Running_model_path) 22 | running_gen_net.load_state_dict(running_ckp['gen_state_dict']) 23 
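        # Note: the pre-trained checkpoints may have been written from a GPU run; on a
        # CPU-only machine torch.load can require map_location='cpu', and a state_dict
        # saved through nn.DataParallel carries a 'module.' prefix that has to be
        # stripped before load_state_dict accepts it (whether either applies depends on
        # how the checkpoints were saved).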
| 24 | #Generate Jumping Data 25 | jumping_gen_net = Generator(seq_len=150, channels=3, latent_dim=100) 26 | jumping_ckp = torch.load(Jumping_model_path) 27 | jumping_gen_net.load_state_dict(jumping_ckp['gen_state_dict']) 28 | 29 | 30 | #generate synthetic running data label is 0 31 | z = torch.FloatTensor(np.random.normal(0, 1, (self.sample_size, 100))) 32 | self.syn_running = running_gen_net(z) 33 | self.syn_running = self.syn_running.detach().numpy() 34 | self.running_label = np.zeros(len(self.syn_running)) 35 | 36 | #generate synthetic jumping data label is 1 37 | z = torch.FloatTensor(np.random.normal(0, 1, (self.sample_size, 100))) 38 | self.syn_jumping = jumping_gen_net(z) 39 | self.syn_jumping = self.syn_jumping.detach().numpy() 40 | self.jumping_label = np.ones(len(self.syn_jumping)) 41 | 42 | self.combined_train_data = np.concatenate((self.syn_running, self.syn_jumping), axis=0) 43 | self.combined_train_label = np.concatenate((self.running_label, self.jumping_label), axis=0) 44 | self.combined_train_label = self.combined_train_label.reshape(self.combined_train_label.shape[0], 1) 45 | 46 | print(self.combined_train_data.shape) 47 | print(self.combined_train_label.shape) 48 | 49 | 50 | def __len__(self): 51 | return self.sample_size * 2 52 | 53 | def __getitem__(self, idx): 54 | return self.combined_train_data[idx], self.combined_train_label[idx] 55 | 56 | 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network 2 | --- 3 | 4 | This repository contains code from the paper "TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network". 5 | 6 | The paper has been accepted to publish in the 20th International Conference on Artificial Intelligence in Medicine (AIME 2022). 7 | 8 | Please find the paper [here](https://arxiv.org/abs/2202.02691) 9 | 10 | --- 11 | 12 | **Abstract:** 13 | Time-series datasets used in machine learning applications often are small in size, making the training of deep neural network architectures ineffective. For time series, the suite of data augmentation tricks we can use to expand the size of the dataset is limited by the need to maintain the basic properties of the signal. Data generated by a Generative Adversarial Network (GAN) can be utilized as another data augmentation tool. RNN-based GANs suffer from the fact that they cannot effectively model long sequences of data points with irregular temporal relations. To tackle these problems, we introduce TTS-GAN, a transformer-based GAN which can successfully generate realistic synthetic time series data sequences of arbitrary length, similar to the original ones. Both the generator and discriminator networks of the GAN model are built using a pure transformer encoder architecture. We use visualizations to demonstrate the similarity of real and generated time series and a simple classification task that shows how we can use synthetically generated data to augment real data and improve classification accuracy. 14 | 15 | --- 16 | 17 | **Key Idea:** 18 | 19 | Transformer GAN generate synthetic time-series data 20 | 21 | **The TTS-GAN Architecture** 22 | 23 | ![The TTS-GAN Architecture](./images/TTS-GAN.png) 24 | 25 | The TTS-GAN model architecture is shown in the upper figure. It contains two main parts, a generator, and a discriminator. Both of them are built based on the transformer encoder architecture. 
An encoder is a composition of two compound blocks. A multi-head self-attention module constructs the first block and the second block is a feed-forward MLP with GELU activation function. The normalization layer is applied before both of the two blocks and the dropout layer is added after each block. Both blocks employ residual connections. 26 | 27 | 28 | **The time series data processing step** 29 | 30 | ![The time series data processing step](./images/PositionalEncoding.png) 31 | 32 | We view a time-series data sequence like an image with a height equal to 1. The number of time-steps is the width of an image, *W*. A time-series sequence can have a single channel or multiple channels, and those can be viewed as the number of channels (RGB) of an image, *C*. So an input sequence can be represented with the matrix of size *(Batch Size, C, 1, W)*. Then we choose a patch size *N* to divide a sequence into *W / N* patches. We then add a soft positional encoding value by the end of each patch, the positional value is learned during model training. Each patch will then have the data shape *(Batch Size, C, 1, (W/N) + 1)* This process is shown in the upper figure. 33 | 34 | --- 35 | 36 | **Repository structures:** 37 | 38 | > ./images 39 | 40 | Several images of the TTS-GAN project 41 | 42 | 43 | > ./pre-trained-models 44 | 45 | Saved pre-trained GAN model checkpoints 46 | 47 | 48 | > dataLoader.py 49 | 50 | The UniMiB dataset dataLoader used for loading GAN model training/testing data 51 | 52 | 53 | > LoadRealRunningJumping.py 54 | 55 | Load real running and jumping data from UniMiB dataset 56 | 57 | 58 | > LoadSyntheticRunningJumping.py 59 | 60 | Load Synthetic running and jumping data from the pre-trained GAN models 61 | 62 | 63 | > functions.py 64 | 65 | The GAN model training and evaluation functions 66 | 67 | 68 | > train_GAN.py 69 | 70 | The major GAN model training file 71 | 72 | 73 | > visualizationMetrics.py 74 | 75 | The help functions to draw T-SNE and PCA plots 76 | 77 | 78 | > adamw.py 79 | 80 | The adamw function file 81 | 82 | 83 | > cfg.py 84 | 85 | The parse function used for reading parameters to train_GAN.py file 86 | 87 | 88 | > JumpingGAN_Train.py 89 | 90 | Run this file to start training the Jumping GAN model 91 | 92 | 93 | > RunningGAN_Train.py 94 | 95 | Run this file to start training the Running GAN model 96 | 97 | 98 | --- 99 | 100 | **Code Instructions:** 101 | 102 | 103 | To train the Running data GAN model: 104 | ``` 105 | python RunningGAN_Train.py 106 | ``` 107 | 108 | To train the Jumping data GAN model: 109 | ``` 110 | python JumpingGAN_Train.py 111 | ``` 112 | 113 | A simple example of visualizing the similarity between the synthetic running&jumping data and the real running&jumping data: 114 | ``` 115 | Running&JumpingVisualization.ipynb 116 | ``` 117 | --- 118 | -------------------------------------------------------------------------------- /RunningGAN_Train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | import os 4 | import argparse 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--rank', type=str, default="0") 9 | parser.add_argument('--node', type=str, default="0015") 10 | opt = parser.parse_args() 11 | 12 | return opt 13 | args = parse_args() 14 | 15 | os.system(f"CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_GAN.py \ 16 | -gen_bs 16 \ 17 | -dis_bs 16 \ 18 | --dist-url 'tcp://localhost:4321' \ 19 | --dist-backend 'nccl' \ 20 | --world-size 1 
\ 21 | --rank {args.rank} \ 22 | --dataset UniMiB \ 23 | --bottom_width 8 \ 24 | --max_iter 500000 \ 25 | --img_size 32 \ 26 | --gen_model my_gen \ 27 | --dis_model my_dis \ 28 | --df_dim 384 \ 29 | --d_heads 4 \ 30 | --d_depth 3 \ 31 | --g_depth 5,4,2 \ 32 | --dropout 0 \ 33 | --latent_dim 100 \ 34 | --gf_dim 1024 \ 35 | --num_workers 16 \ 36 | --g_lr 0.0001 \ 37 | --d_lr 0.0003 \ 38 | --optimizer adam \ 39 | --loss lsgan \ 40 | --wd 1e-3 \ 41 | --beta1 0.9 \ 42 | --beta2 0.999 \ 43 | --phi 1 \ 44 | --batch_size 16 \ 45 | --num_eval_imgs 50000 \ 46 | --init_type xavier_uniform \ 47 | --n_critic 1 \ 48 | --val_freq 20 \ 49 | --print_freq 50 \ 50 | --grow_steps 0 0 \ 51 | --fade_in 0 \ 52 | --patch_size 2 \ 53 | --ema_kimg 500 \ 54 | --ema_warmup 0.1 \ 55 | --ema 0.9999 \ 56 | --diff_aug translation,cutout,color \ 57 | --class_name Running \ 58 | --exp_name Running") -------------------------------------------------------------------------------- /adamw.py: -------------------------------------------------------------------------------- 1 | """ AdamW Optimizer 2 | Impl copied from PyTorch master 3 | """ 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class AdamW(Optimizer): 10 | r"""Implements AdamW algorithm. 11 | 12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | (default: False) 27 | 28 | .. _Adam\: A Method for Stochastic Optimization: 29 | https://arxiv.org/abs/1412.6980 30 | .. _Decoupled Weight Decay Regularization: 31 | https://arxiv.org/abs/1711.05101 32 | .. _On the Convergence of Adam and Beyond: 33 | https://openreview.net/forum?id=ryQu7f-RZ 34 | """ 35 | 36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 37 | weight_decay=1e-2, amsgrad=False): 38 | if not 0.0 <= lr: 39 | raise ValueError("Invalid learning rate: {}".format(lr)) 40 | if not 0.0 <= eps: 41 | raise ValueError("Invalid epsilon value: {}".format(eps)) 42 | if not 0.0 <= betas[0] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 44 | if not 0.0 <= betas[1] < 1.0: 45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 46 | defaults = dict(lr=lr, betas=betas, eps=eps, 47 | weight_decay=weight_decay, amsgrad=amsgrad) 48 | super(AdamW, self).__init__(params, defaults) 49 | 50 | def __setstate__(self, state): 51 | super(AdamW, self).__setstate__(state) 52 | for group in self.param_groups: 53 | group.setdefault('amsgrad', False) 54 | 55 | def step(self, closure=None): 56 | """Performs a single optimization step. 57 | 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 
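        Note:
            The weight decay below is decoupled -- each parameter is first scaled by
            (1 - lr * weight_decay) and only then updated from the Adam moment
            estimates, rather than folding weight_decay * p into the gradient as
            Adam with L2 regularization would.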
61 | """ 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.grad is None: 69 | continue 70 | 71 | # Perform stepweight decay 72 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 73 | 74 | # Perform optimization step 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 78 | amsgrad = group['amsgrad'] 79 | 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | if amsgrad: 90 | # Maintains max of all exp. moving avg. of sq. grad. values 91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 92 | 93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 94 | if amsgrad: 95 | max_exp_avg_sq = state['max_exp_avg_sq'] 96 | beta1, beta2 = group['betas'] 97 | 98 | state['step'] += 1 99 | bias_correction1 = 1 - beta1 ** state['step'] 100 | bias_correction2 = 1 - beta2 ** state['step'] 101 | 102 | # Decay the first and second moment running average coefficient 103 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 104 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 105 | if amsgrad: 106 | # Maintains the maximum of all 2nd moment running avg. till now 107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 108 | # Use the max. for normalizing running avg. of gradient 109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 110 | else: 111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 112 | 113 | step_size = group['lr'] / bias_correction1 114 | 115 | p.data.addcdiv_(-step_size, exp_avg, denom) 116 | 117 | return loss -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-25 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | import argparse 8 | 9 | 10 | def str2bool(v): 11 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 12 | return True 13 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError('Boolean value expected.') 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--world-size', default=-1, type=int, 22 | help='number of nodes for distributed training') 23 | parser.add_argument('--rank', default=-1, type=int, 24 | help='node rank for distributed training') 25 | parser.add_argument('--loca_rank', default=-1, type=int, 26 | help='node rank for distributed training') 27 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 28 | help='url used to set up distributed training') 29 | parser.add_argument('--dist-backend', default='nccl', type=str, 30 | help='distributed backend') 31 | parser.add_argument('--seed', default=12345, type=int, 32 | help='seed for initializing training. 
') 33 | parser.add_argument('--gpu', default=None, type=int, 34 | help='GPU id to use.') 35 | parser.add_argument('--multiprocessing-distributed', action='store_true', 36 | help='Use multi-processing distributed training to launch ' 37 | 'N processes per node, which has N GPUs. This is the ' 38 | 'fastest way to use PyTorch for either single node or ' 39 | 'multi node data parallel training') 40 | parser.add_argument( 41 | '--max_epoch', 42 | type=int, 43 | default=200, 44 | help='number of epochs of training') 45 | parser.add_argument( 46 | '--max_iter', 47 | type=int, 48 | default=None, 49 | help='set the max iteration number') 50 | parser.add_argument( 51 | '-gen_bs', 52 | '--gen_batch_size', 53 | type=int, 54 | default=64, 55 | help='size of the batches') 56 | parser.add_argument( 57 | '-dis_bs', 58 | '--dis_batch_size', 59 | type=int, 60 | default=64, 61 | help='size of the batches') 62 | parser.add_argument( 63 | '-bs', 64 | '--batch_size', 65 | type=int, 66 | default=64, 67 | help='size of the batches to load dataset') 68 | parser.add_argument( 69 | '--g_lr', 70 | type=float, 71 | default=0.0002, 72 | help='adam: gen learning rate') 73 | parser.add_argument( 74 | '--wd', 75 | type=float, 76 | default=0, 77 | help='adamw: gen weight decay') 78 | parser.add_argument( 79 | '--d_lr', 80 | type=float, 81 | default=0.0002, 82 | help='adam: disc learning rate') 83 | parser.add_argument( 84 | '--ctrl_lr', 85 | type=float, 86 | default=3.5e-4, 87 | help='adam: ctrl learning rate') 88 | parser.add_argument( 89 | '--lr_decay', 90 | action='store_true', 91 | help='learning rate decay or not') 92 | parser.add_argument( 93 | '--beta1', 94 | type=float, 95 | default=0.0, 96 | help='adam: decay of first order momentum of gradient') 97 | parser.add_argument( 98 | '--beta2', 99 | type=float, 100 | default=0.9, 101 | help='adam: decay of first order momentum of gradient') 102 | parser.add_argument( 103 | '--num_workers', 104 | type=int, 105 | default=8, 106 | help='number of cpu threads to use during batch generation') 107 | parser.add_argument( 108 | '--latent_dim', 109 | type=int, 110 | default=128, 111 | help='dimensionality of the latent space') 112 | parser.add_argument( 113 | '--img_size', 114 | type=int, 115 | default=32, 116 | help='size of each image dimension') 117 | parser.add_argument( 118 | '--channels', 119 | type=int, 120 | default=3, 121 | help='number of image channels') 122 | parser.add_argument( 123 | '--n_critic', 124 | type=int, 125 | default=1, 126 | help='number of training steps for discriminator per iter') 127 | parser.add_argument( 128 | '--val_freq', 129 | type=int, 130 | default=20, 131 | help='interval between each validation') 132 | parser.add_argument( 133 | '--print_freq', 134 | type=int, 135 | default=100, 136 | help='interval between each verbose') 137 | parser.add_argument( 138 | '--load_path', 139 | type=str, 140 | help='The reload model path') 141 | parser.add_argument( 142 | '--class_name', 143 | type=str, 144 | help='The class name to load in UniMiB dataset') 145 | parser.add_argument( 146 | '--augment_times', 147 | type=int, 148 | default=None, 149 | help='The times of augment signals compare to original data') 150 | parser.add_argument( 151 | '--exp_name', 152 | type=str, 153 | help='The name of exp') 154 | parser.add_argument( 155 | '--d_spectral_norm', 156 | type=str2bool, 157 | default=False, 158 | help='add spectral_norm on discriminator?') 159 | parser.add_argument( 160 | '--g_spectral_norm', 161 | type=str2bool, 162 | default=False, 163 | help='add 
spectral_norm on generator?') 164 | parser.add_argument( 165 | '--dataset', 166 | type=str, 167 | default='cifar10', 168 | help='dataset type') 169 | parser.add_argument( 170 | '--data_path', 171 | type=str, 172 | default='./data', 173 | help='The path of data set') 174 | parser.add_argument('--init_type', type=str, default='normal', 175 | choices=['normal', 'orth', 'xavier_uniform', 'false'], 176 | help='The init type') 177 | parser.add_argument('--gf_dim', type=int, default=64, 178 | help='The base channel num of gen') 179 | parser.add_argument('--df_dim', type=int, default=64, 180 | help='The base channel num of disc') 181 | parser.add_argument( 182 | '--gen_model', 183 | type=str, 184 | help='path of gen model') 185 | parser.add_argument( 186 | '--dis_model', 187 | type=str, 188 | help='path of dis model') 189 | parser.add_argument( 190 | '--controller', 191 | type=str, 192 | default='controller', 193 | help='path of controller') 194 | parser.add_argument('--eval_batch_size', type=int, default=100) 195 | parser.add_argument('--num_eval_imgs', type=int, default=50000) 196 | parser.add_argument( 197 | '--bottom_width', 198 | type=int, 199 | default=4, 200 | help="the base resolution of the GAN") 201 | parser.add_argument('--random_seed', type=int, default=12345) 202 | 203 | # search 204 | parser.add_argument('--shared_epoch', type=int, default=15, 205 | help='the number of epoch to train the shared gan at each search iteration') 206 | parser.add_argument('--grow_step1', type=int, default=25, 207 | help='which iteration to grow the image size from 8 to 16') 208 | parser.add_argument('--grow_step2', type=int, default=55, 209 | help='which iteration to grow the image size from 16 to 32') 210 | parser.add_argument('--max_search_iter', type=int, default=90, 211 | help='max search iterations of this algorithm') 212 | parser.add_argument('--ctrl_step', type=int, default=30, 213 | help='number of steps to train the controller at each search iteration') 214 | parser.add_argument('--ctrl_sample_batch', type=int, default=1, 215 | help='sample size of controller of each step') 216 | parser.add_argument('--hid_size', type=int, default=100, 217 | help='the size of hidden vector') 218 | parser.add_argument('--baseline_decay', type=float, default=0.9, 219 | help='baseline decay rate in RL') 220 | parser.add_argument('--rl_num_eval_img', type=int, default=5000, 221 | help='number of images to be sampled in order to get the reward') 222 | parser.add_argument('--num_candidate', type=int, default=10, 223 | help='number of candidate architectures to be sampled') 224 | parser.add_argument('--topk', type=int, default=5, 225 | help='preserve topk models architectures after each stage' ) 226 | parser.add_argument('--entropy_coeff', type=float, default=1e-3, 227 | help='to encourage the exploration') 228 | parser.add_argument('--dynamic_reset_threshold', type=float, default=1e-3, 229 | help='var threshold') 230 | parser.add_argument('--dynamic_reset_window', type=int, default=500, 231 | help='the window size') 232 | parser.add_argument('--arch', nargs='+', type=int, 233 | help='the vector of a discovered architecture') 234 | parser.add_argument('--optimizer', type=str, default="adam", 235 | help='optimizer') 236 | parser.add_argument('--loss', type=str, default="hinge", 237 | help='loss function') 238 | parser.add_argument('--n_classes', type=int, default=0, 239 | help='classes') 240 | parser.add_argument('--phi', type=float, default=1, 241 | help='wgan-gp phi') 242 | parser.add_argument('--grow_steps', 
nargs='+', type=int, 243 | help='the vector of a discovered architecture') 244 | parser.add_argument('--D_downsample', type=str, default="avg", 245 | help='downsampling type') 246 | parser.add_argument('--fade_in', type=float, default=1, 247 | help='fade in step') 248 | parser.add_argument('--d_depth', type=int, default=7, 249 | help='Discriminator Depth') 250 | parser.add_argument('--g_depth', type=str, default="5,4,2", 251 | help='Generator Depth') 252 | parser.add_argument('--g_norm', type=str, default="ln", 253 | help='Generator Normalization') 254 | parser.add_argument('--d_norm', type=str, default="ln", 255 | help='Discriminator Normalization') 256 | parser.add_argument('--g_act', type=str, default="gelu", 257 | help='Generator activation Layer') 258 | parser.add_argument('--d_act', type=str, default="gelu", 259 | help='Discriminator activation layer') 260 | parser.add_argument('--patch_size', type=int, default=4, 261 | help='Discriminator Depth') 262 | parser.add_argument('--fid_stat', type=str, default="None", 263 | help='Discriminator Depth') 264 | parser.add_argument('--diff_aug', type=str, default="None", 265 | help='differentiable augmentation type') 266 | parser.add_argument('--accumulated_times', type=int, default=1, 267 | help='gradient accumulation') 268 | parser.add_argument('--g_accumulated_times', type=int, default=1, 269 | help='gradient accumulation') 270 | parser.add_argument('--num_landmarks', type=int, default=64, 271 | help='number of landmarks') 272 | parser.add_argument('--d_heads', type=int, default=4, 273 | help='number of heads') 274 | parser.add_argument('--dropout', type=float, default=0., 275 | help='dropout ratio') 276 | parser.add_argument('--ema', type=float, default=0.995, 277 | help='ema') 278 | parser.add_argument('--ema_warmup', type=float, default=0., 279 | help='ema warm up') 280 | parser.add_argument('--ema_kimg', type=int, default=500, 281 | help='ema thousand images') 282 | parser.add_argument('--latent_norm',action='store_true', 283 | help='latent vector normalization') 284 | parser.add_argument('--ministd',action='store_true', 285 | help='mini batch std') 286 | parser.add_argument('--g_mlp', type=int, default=4, 287 | help='generator mlp ratio') 288 | parser.add_argument('--d_mlp', type=int, default=4, 289 | help='discriminator mlp ratio') 290 | parser.add_argument('--g_window_size', type=int, default=8, 291 | help='generator mlp ratio') 292 | parser.add_argument('--d_window_size', type=int, default=8, 293 | help='discriminator mlp ratio') 294 | parser.add_argument('--show', action='store_true', 295 | help='show') 296 | 297 | opt = parser.parse_args() 298 | 299 | return opt -------------------------------------------------------------------------------- /dataLoader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """UniMiB_SHAR_ADL_load_dataset.ipynb 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1U1EY6cZsOFERD3Df1HRqjuTq5bDUGH03 8 | 9 | #UniMiB_SHAR_ADL_load_dataset.ipynb. 10 | Loads the A-9 (ADL) portion of the UniMiB dataset from the Internet repository and converts the data into numpy arrays while adhering to the general format of the [Keras MNIST load_data function](https://keras.io/api/datasets/mnist/#load_data-function). 
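In this repository the loader is wrapped as a PyTorch Dataset (class unimib_load_dataset
below), so it is normally constructed rather than called as a plain function, e.g.
(a sketch using the constructor arguments defined below):

    train_set = unimib_load_dataset(incl_xyz_accel=True, is_normalize=True,
                                    data_mode='Train', single_class=True,
                                    class_name='Running')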
11 | 12 | Arguments: see the keyword arguments of the unimib_load_dataset class below 13 | Returns: Tuple of Numpy arrays: 14 | (x_train, y_train), (x_validation, y_validation) [optional], (x_test, y_test) 15 | 16 | * x_train/validation/test: float64 arrays with shape (num_samples, 151, {3,4,1}) 17 | * y_train/validation/test: int8 arrays of activity labels with shape (num_samples,); one-hot encoded to (num_samples, 9) when one_hot_encode=True 18 | 19 | The train/test split is by subject. 20 | 21 | Example usage (this file wraps the loader as a PyTorch Dataset): 22 | train_set = unimib_load_dataset(data_mode='Train') 23 | 24 | Additional References 25 | If you use the dataset and/or code, please cite this paper (downloadable from [here](http://www.mdpi.com/2076-3417/7/10/1101/html)) 26 | 27 | Developed and tested using colab.research.google.com 28 | To save as .py version use File > Download .py 29 | 30 | Author: Lee B. Hinkle, IMICS Lab, Texas State University, 2021 31 | 32 | Creative Commons License
This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License. 33 | 34 | 35 | TODOs: 36 | * Fix document strings 37 | * Assign names to activities instead of numbers 38 | """ 39 | 40 | import os 41 | import shutil #https://docs.python.org/3/library/shutil.html 42 | from shutil import unpack_archive # to unzip 43 | #from shutil import make_archive # to create zip for storage 44 | import requests #for downloading zip file 45 | from scipy import io #for loadmat, matlab conversion 46 | import pandas as pd 47 | import numpy as np 48 | #import matplotlib.pyplot as plt # for plotting - pandas uses matplotlib 49 | from tabulate import tabulate # for verbose tables 50 | #from tensorflow.keras.utils import to_categorical # for one-hot encoding 51 | 52 | #credit https://stackoverflow.com/questions/9419162/download-returned-zip-file-from-url 53 | #many other methods I tried failed to download the file properly 54 | from torch.utils.data import Dataset, DataLoader 55 | 56 | #data augmentation 57 | import tsaug 58 | 59 | class_dict = {'StandingUpFS':0,'StandingUpFL':1,'Walking':2,'Running':3,'GoingUpS':4,'Jumping':5,'GoingDownS':6,'LyingDownFS':7,'SittingDown':8} 60 | 61 | class unimib_load_dataset(Dataset): 62 | def __init__(self, 63 | verbose = False, 64 | incl_xyz_accel = False, #include component accel_x/y/z in ____X data 65 | incl_rms_accel = True, #add rms value (total accel) of accel_x/y/z in ____X data 66 | incl_val_group = False, #True => returns x/y_test, x/y_validation, x/y_train 67 | #False => combine test & validation groups 68 | is_normalize = False, 69 | split_subj = dict 70 | (train_subj = [4,5,6,7,8,10,11,12,14,15,19,20,21,22,24,26,27,29], 71 | validation_subj = [1,9,16,23,25,28], 72 | test_subj = [2,3,13,17,18,30]), 73 | one_hot_encode = True, data_mode = 'Train', single_class = False, class_name= 'Walking', augment_times = None): 74 | 75 | self.verbose = verbose 76 | self.incl_xyz_accel = incl_xyz_accel 77 | self.incl_rms_accel = incl_rms_accel 78 | self.incl_val_group = incl_val_group 79 | self.split_subj = split_subj 80 | self.one_hot_encode = one_hot_encode 81 | self.data_mode = data_mode 82 | self.class_name = class_name 83 | self.single_class = single_class 84 | self.is_normalize = is_normalize 85 | 86 | 87 | #Download and unzip original dataset 88 | if (not os.path.isfile('./UniMiB-SHAR.zip')): 89 | print("Downloading UniMiB-SHAR.zip file") 90 | #invoking the shell command fails when exported to .py file 91 | #redirect link https://www.dropbox.com/s/raw/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip 92 | #!wget https://www.dropbox.com/s/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip 93 | self.download_url('https://www.dropbox.com/s/raw/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip','./UniMiB-SHAR.zip') 94 | if (not os.path.isdir('./UniMiB-SHAR')): 95 | shutil.unpack_archive('./UniMiB-SHAR.zip','.','zip') 96 | #Convert .mat files to numpy ndarrays 97 | path_in = './UniMiB-SHAR/data' 98 | #loadmat loads matlab files as dictionary, keys: header, version, globals, data 99 | adl_data = io.loadmat(path_in + '/adl_data.mat')['adl_data'] 100 | adl_names = io.loadmat(path_in + '/adl_names.mat', chars_as_strings=True)['adl_names'] 101 | adl_labels = io.loadmat(path_in + '/adl_labels.mat')['adl_labels'] 102 | 103 | if(self.verbose): 104 | headers = ("Raw data","shape", "object type", "data type") 105 | mydata = [("adl_data:", adl_data.shape, type(adl_data), adl_data.dtype), 106 | ("adl_labels:", adl_labels.shape ,type(adl_labels), adl_labels.dtype), 107 | ("adl_names:", adl_names.shape, 
type(adl_names), adl_names.dtype)] 108 | print(tabulate(mydata, headers=headers)) 109 | #Reshape data and compute total (rms) acceleration 110 | num_samples = 151 111 | #UniMiB SHAR has fixed size of 453 which is 151 accelX, 151 accely, 151 accelz 112 | adl_data = np.reshape(adl_data,(-1,num_samples,3), order='F') #uses Fortran order 113 | if (self.incl_rms_accel): 114 | rms_accel = np.sqrt((adl_data[:,:,0]**2) + (adl_data[:,:,1]**2) + (adl_data[:,:,2]**2)) 115 | adl_data = np.dstack((adl_data,rms_accel)) 116 | #remove component accel if needed 117 | if (not self.incl_xyz_accel): 118 | adl_data = np.delete(adl_data, [0,1,2], 2) 119 | if(verbose): 120 | headers = ("Reshaped data","shape", "object type", "data type") 121 | mydata = [("adl_data:", adl_data.shape, type(adl_data), adl_data.dtype), 122 | ("adl_labels:", adl_labels.shape ,type(adl_labels), adl_labels.dtype), 123 | ("adl_names:", adl_names.shape, type(adl_names), adl_names.dtype)] 124 | print(tabulate(mydata, headers=headers)) 125 | #Split train/test sets, combine or make separate validation set 126 | #ref for this numpy gymnastics - find index of matching subject to sub_train/sub_test/sub_validate 127 | #https://numpy.org/doc/stable/reference/generated/numpy.isin.html 128 | 129 | 130 | act_num = (adl_labels[:,0])-1 #matlab source was 1 indexed, change to 0 indexed 131 | sub_num = (adl_labels[:,1]) #subject numbers are in column 1 of labels 132 | 133 | if (not self.incl_val_group): 134 | train_index = np.nonzero(np.isin(sub_num, self.split_subj['train_subj'] + 135 | self.split_subj['validation_subj'])) 136 | x_train = adl_data[train_index] 137 | y_train = act_num[train_index] 138 | else: 139 | train_index = np.nonzero(np.isin(sub_num, self.split_subj['train_subj'])) 140 | x_train = adl_data[train_index] 141 | y_train = act_num[train_index] 142 | 143 | validation_index = np.nonzero(np.isin(sub_num, self.split_subj['validation_subj'])) 144 | x_validation = adl_data[validation_index] 145 | y_validation = act_num[validation_index] 146 | 147 | test_index = np.nonzero(np.isin(sub_num, self.split_subj['test_subj'])) 148 | x_test = adl_data[test_index] 149 | y_test = act_num[test_index] 150 | 151 | if (verbose): 152 | print("x/y_train shape ",x_train.shape,y_train.shape) 153 | if (self.incl_val_group): 154 | print("x/y_validation shape ",x_validation.shape,y_validation.shape) 155 | print("x/y_test shape ",x_test.shape,y_test.shape) 156 | #If selected one-hot encode y_* using keras to_categorical, reference: 157 | #https://keras.io/api/utils/python_utils/#to_categorical-function and 158 | #https://machinelearningmastery.com/how-to-one-hot-encode-sequence-data-in-python/ 159 | if (self.one_hot_encode): 160 | y_train = self.to_categorical(y_train, num_classes=9) 161 | if (self.incl_val_group): 162 | y_validation = self.to_categorical(y_validation, num_classes=9) 163 | y_test = self.to_categorical(y_test, num_classes=9) 164 | if (verbose): 165 | print("After one-hot encoding") 166 | print("x/y_train shape ",x_train.shape,y_train.shape) 167 | if (self.incl_val_group): 168 | print("x/y_validation shape ",x_validation.shape,y_validation.shape) 169 | print("x/y_test shape ",x_test.shape,y_test.shape) 170 | # if (self.incl_val_group): 171 | # return x_train, y_train, x_validation, y_validation, x_test, y_test 172 | # else: 173 | # return x_train, y_train, x_test, y_test 174 | 175 | # reshape x_train, x_test data shape from (BH, length, channel) to (BH, channel, 1, length) 176 | self.x_train = np.transpose(x_train, (0, 2, 1)) 177 | self.x_train = 
self.x_train.reshape(self.x_train.shape[0], self.x_train.shape[1], 1, self.x_train.shape[2]) 178 | self.x_train = self.x_train[:,:,:,:-1] 179 | self.y_train = y_train 180 | 181 | self.x_test = np.transpose(x_test, (0, 2, 1)) 182 | self.x_test = self.x_test.reshape(self.x_test.shape[0], self.x_test.shape[1], 1, self.x_test.shape[2]) 183 | self.x_test = self.x_test[:,:,:,:-1] 184 | self.y_test = y_test 185 | print(f'x_train shape is {self.x_train.shape}, x_test shape is {self.x_test.shape}') 186 | print(f'y_train shape is {self.y_train.shape}, y_test shape is {self.y_test.shape}') 187 | 188 | 189 | if self.is_normalize: 190 | self.x_train = self.normalization(self.x_train) 191 | self.x_test = self.normalization(self.x_test) 192 | 193 | #Return the give class train/test data & labels 194 | if self.single_class: 195 | one_class_train_data = [] 196 | one_class_train_labels = [] 197 | one_class_test_data = [] 198 | one_class_test_labels = [] 199 | 200 | for i, label in enumerate(y_train): 201 | if label == class_dict[self.class_name]: 202 | one_class_train_data.append(self.x_train[i]) 203 | one_class_train_labels.append(label) 204 | 205 | for i, label in enumerate(y_test): 206 | if label == class_dict[self.class_name]: 207 | one_class_test_data.append(self.x_test[i]) 208 | one_class_test_labels.append(label) 209 | self.one_class_train_data = np.array(one_class_train_data) 210 | self.one_class_train_labels = np.array(one_class_train_labels) 211 | self.one_class_test_data = np.array(one_class_test_data) 212 | self.one_class_test_labels = np.array(one_class_test_labels) 213 | 214 | if augment_times: 215 | augment_data = [] 216 | augment_labels = [] 217 | for data, label in zip(one_class_train_data, one_class_train_labels): 218 | # print(data.shape) # C, 1, T 219 | data = data.reshape(data.shape[0], data.shape[2]) # Channel, Timestep 220 | data = np.transpose(data, (1, 0)) # T, C 221 | data = np.asarray(data) 222 | for i in range(augment_times): 223 | 224 | aug_data = tsaug.Quantize(n_levels=[10, 20, 30]).augment(data) 225 | aug_data = tsaug.Drift(max_drift=(0.1, 0.5)).augment(aug_data) 226 | # aug_data = my_augmenter(data) # T, C 227 | aug_data = np.transpose(aug_data, (1, 0)) # C, T 228 | aug_data = aug_data.reshape(aug_data.shape[0], 1, aug_data.shape[1]) # C, 1, T 229 | augment_data.append(aug_data) 230 | augment_labels.append(label) 231 | 232 | augment_data = np.array(augment_data) 233 | augment_labels = np.array(augment_labels) 234 | print(f'augment_data shape is {augment_data.shape}') 235 | print(f'augment_labels shape is {augment_labels.shape}') 236 | self.one_class_train_data = np.concatenate((augment_data, self.one_class_train_data), axis = 0) 237 | self.one_class_train_labels = np.concatenate((augment_labels, self.one_class_train_labels), axis = 0) 238 | 239 | print(f'return single class data and labels, class is {self.class_name}') 240 | print(f'train_data shape is {self.one_class_train_data.shape}, test_data shape is {self.one_class_test_data.shape}') 241 | print(f'train label shape is {self.one_class_train_labels.shape}, test data shape is {self.one_class_test_labels.shape}') 242 | 243 | def download_url(self, url, save_path, chunk_size=128): 244 | r = requests.get(url, stream=True) 245 | with open(save_path, 'wb') as fd: 246 | for chunk in r.iter_content(chunk_size=chunk_size): 247 | fd.write(chunk) 248 | 249 | def to_categorical(self, y, num_classes): 250 | """ 1-hot encodes a tensor """ 251 | return np.eye(num_classes, dtype='uint8')[y] 252 | 253 | 254 | def _normalize(self, epoch): 
255 | """ A helper method for the normalization method. 256 | Returns 257 | result: a normalized epoch 258 | """ 259 | e = 1e-10 260 | result = (epoch - epoch.mean(axis=0)) / ((np.sqrt(epoch.var(axis=0)))+e) 261 | return result 262 | 263 | def _min_max_normalize(self, epoch): 264 | 265 | result = (epoch - min(epoch)) / (max(epoch) - min(epoch)) 266 | return result 267 | 268 | def normalization(self, epochs): 269 | """ Normalizes each epoch e s.t mean(e) = 0 and var(e) = 1 270 | Args: 271 | epochs - Numpy structure of epochs 272 | Returns: 273 | epochs_n - mne data structure of normalized epochs (mean=0, var=1) 274 | """ 275 | for i in range(epochs.shape[0]): 276 | for j in range(epochs.shape[1]): 277 | epochs[i,j,0,:] = self._normalize(epochs[i,j,0,:]) 278 | # epochs[i,j,0,:] = self._min_max_normalize(epochs[i,j,0,:]) 279 | 280 | return epochs 281 | 282 | def __len__(self): 283 | 284 | if self.data_mode == 'Train': 285 | if self.single_class: 286 | return len(self.one_class_train_labels) 287 | else: 288 | return len(self.y_train) 289 | else: 290 | if self.single_class: 291 | return len(self.one_class_test_labels) 292 | else: 293 | return len(self.y_test) 294 | 295 | def __getitem__(self, idx): 296 | if self.data_mode == 'Train': 297 | if self.single_class: 298 | return self.one_class_train_data[idx], self.one_class_train_labels[idx] 299 | else: 300 | return self.x_train[idx], self.y_train[idx] 301 | else: 302 | if self.single_class: 303 | return self.one_class_test_data[idx], self.one_class_test_labels[idx] 304 | else: 305 | return self.x_test[idx], self.y_test[idx] 306 | 307 | -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-25 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | import logging 8 | import operator 9 | import os 10 | from copy import deepcopy 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from imageio import imsave 16 | from utils.utils import make_grid, save_image 17 | from tqdm import tqdm 18 | import cv2 19 | 20 | # from utils.fid_score import calculate_fid_given_paths 21 | from utils.torch_fid_score import get_fid 22 | # from utils.inception_score import get_inception_scorepython exps/dist1_new_church256.py --node 0022 --rank 0sample 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def cur_stages(iter, args): 27 | """ 28 | Return current stage. 29 | :param epoch: current epoch. 
30 | :return: current stage 31 | """ 32 | # if search_iter < self.grow_step1: 33 | # return 0 34 | # elif self.grow_step1 <= search_iter < self.grow_step2: 35 | # return 1 36 | # else: 37 | # return 2 38 | # for idx, grow_step in enumerate(args.grow_steps): 39 | # if iter < grow_step: 40 | # return idx 41 | # return len(args.grow_steps) 42 | idx = 0 43 | for i in range(len(args.grow_steps)): 44 | if iter >= args.grow_steps[i]: 45 | idx = i+1 46 | return idx 47 | 48 | def compute_gradient_penalty(D, real_samples, fake_samples, phi): 49 | """Calculates the gradient penalty loss for WGAN GP""" 50 | # Random weight term for interpolation between real and fake samples 51 | alpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).to(real_samples.get_device()) 52 | # Get random interpolation between real and fake samples 53 | interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) 54 | d_interpolates = D(interpolates) 55 | fake = torch.ones([real_samples.shape[0], 1], requires_grad=False).to(real_samples.get_device()) 56 | # Get gradient w.r.t. interpolates 57 | gradients = torch.autograd.grad( 58 | outputs=d_interpolates, 59 | inputs=interpolates, 60 | grad_outputs=fake, 61 | create_graph=True, 62 | retain_graph=True, 63 | only_inputs=True, 64 | )[0] 65 | gradients = gradients.reshape(gradients.size(0), -1) 66 | gradient_penalty = ((gradients.norm(2, dim=1) - phi) ** 2).mean() 67 | return gradient_penalty 68 | 69 | 70 | def train_d(args, gen_net: nn.Module, dis_net: nn.Module, dis_optimizer, train_loader, epoch, writer_dict,fixed_z, schedulers=None): 71 | writer = writer_dict['writer'] 72 | # gen_step = 0 73 | # train mode 74 | dis_net.train() 75 | 76 | dis_optimizer.zero_grad() 77 | 78 | for iter_idx, (imgs, _) in enumerate(tqdm(train_loader)): 79 | global_steps = writer_dict['train_global_steps'] 80 | 81 | 82 | # Adversarial ground truths 83 | real_imgs = imgs.type(torch.cuda.FloatTensor).cuda(args.gpu, non_blocking=True) 84 | 85 | # Sample noise as generator input 86 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], args.latent_dim))).cuda(args.gpu, non_blocking=True) 87 | 88 | # --------------------- 89 | # Train Discriminator 90 | # --------------------- 91 | 92 | 93 | real_validity = dis_net(real_imgs) 94 | fake_imgs = gen_net(z).detach() 95 | 96 | assert fake_imgs.size() == real_imgs.size(), f"fake_imgs.size(): {fake_imgs.size()} real_imgs.size(): {real_imgs.size()}" 97 | 98 | fake_validity = dis_net(fake_imgs) 99 | 100 | # cal loss 101 | if args.loss == 'hinge': 102 | d_loss = 0 103 | d_loss = torch.mean(nn.ReLU(inplace=True)(1.0 - real_validity)) + \ 104 | torch.mean(nn.ReLU(inplace=True)(1 + fake_validity)) 105 | elif args.loss == 'standard': 106 | #soft label 107 | real_label = torch.full((imgs.shape[0],), 0.9, dtype=torch.float, device=real_imgs.get_device()) 108 | fake_label = torch.full((imgs.shape[0],), 0.1, dtype=torch.float, device=real_imgs.get_device()) 109 | real_validity = nn.Sigmoid()(real_validity.view(-1)) 110 | fake_validity = nn.Sigmoid()(fake_validity.view(-1)) 111 | d_real_loss = nn.BCELoss()(real_validity, real_label) 112 | d_fake_loss = nn.BCELoss()(fake_validity, fake_label) 113 | d_loss = d_real_loss + d_fake_loss 114 | elif args.loss == 'lsgan': 115 | if isinstance(fake_validity, list): 116 | d_loss = 0 117 | for real_validity_item, fake_validity_item in zip(real_validity, fake_validity): 118 | real_label = torch.full((real_validity_item.shape[0],real_validity_item.shape[1]), 1., 
dtype=torch.float, device=real_imgs.get_device()) 119 | fake_label = torch.full((real_validity_item.shape[0],real_validity_item.shape[1]), 0., dtype=torch.float, device=real_imgs.get_device()) 120 | d_real_loss = nn.MSELoss()(real_validity_item, real_label) 121 | d_fake_loss = nn.MSELoss()(fake_validity_item, fake_label) 122 | d_loss += d_real_loss + d_fake_loss 123 | else: 124 | real_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device()) 125 | fake_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 0., dtype=torch.float, device=real_imgs.get_device()) 126 | d_real_loss = nn.MSELoss()(real_validity, real_label) 127 | d_fake_loss = nn.MSELoss()(fake_validity, fake_label) 128 | d_loss = d_real_loss + d_fake_loss 129 | elif args.loss == 'wgangp': 130 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi) 131 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / ( 132 | args.phi ** 2) 133 | elif args.loss == 'wgangp-mode': 134 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi) 135 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / ( 136 | args.phi ** 2) 137 | elif args.loss == 'wgangp-eps': 138 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi) 139 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / ( 140 | args.phi ** 2) 141 | d_loss += (torch.mean(real_validity) ** 2) * 1e-3 142 | else: 143 | raise NotImplementedError(args.loss) 144 | d_loss = d_loss/float(args.accumulated_times) 145 | d_loss.backward() 146 | 147 | if (iter_idx + 1) % args.accumulated_times == 0: 148 | torch.nn.utils.clip_grad_norm_(dis_net.parameters(), 5.) 149 | dis_optimizer.step() 150 | dis_optimizer.zero_grad() 151 | 152 | writer.add_scalar('d_loss', d_loss.item(), global_steps) if args.rank == 0 else 0 153 | 154 | 155 | # # adjust learning rate 156 | # if schedulers: 157 | # gen_scheduler, dis_scheduler = schedulers 158 | # # g_lr = gen_scheduler.step(global_steps) 159 | # d_lr = dis_scheduler.step(global_steps) 160 | # # writer.add_scalar('LR/g_lr', g_lr, global_steps) 161 | # writer.add_scalar('LR/d_lr', d_lr, global_steps) 162 | 163 | # # moving average weight 164 | # ema_nimg = args.ema_kimg * 1000 165 | # cur_nimg = args.dis_batch_size * args.world_size * global_steps 166 | # if args.ema_warmup != 0: 167 | # ema_nimg = min(ema_nimg, cur_nimg * args.ema_warmup) 168 | # ema_beta = 0.5 ** (float(args.dis_batch_size * args.world_size) / max(ema_nimg, 1e-8)) 169 | # else: 170 | # ema_beta = args.ema 171 | 172 | # # moving average weight 173 | # for p, avg_p in zip(gen_net.parameters(), gen_avg_param): 174 | # cpu_p = deepcopy(p) 175 | # avg_p.mul_(ema_beta).add_(1. 
- ema_beta, cpu_p.cpu().data) 176 | # del cpu_p 177 | 178 | # # writer.add_scalar('g_loss', g_loss.item(), global_steps) if args.rank == 0 else 0 179 | # # gen_step += 1 180 | 181 | # # verbose 182 | # if gen_step and iter_idx % args.print_freq == 0 and args.rank == 0: 183 | # sample_imgs = torch.cat((gen_imgs[:16], real_imgs[:16]), dim=0) 184 | # # scale_factor = args.img_size // int(sample_imgs.size(3)) 185 | # # sample_imgs = torch.nn.functional.interpolate(sample_imgs, scale_factor=2) 186 | # # img_grid = make_grid(sample_imgs, nrow=4, normalize=True, scale_each=True) 187 | # # save_image(sample_imgs, f'sampled_images_{args.exp_name}.jpg', nrow=4, normalize=True, scale_each=True) 188 | # # writer.add_image(f'sampled_images_{args.exp_name}', img_grid, global_steps) 189 | # tqdm.write( 190 | # "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [ema: %f] " % 191 | # (epoch, args.max_epoch, iter_idx % len(train_loader), len(train_loader), d_loss.item(), ema_beta)) 192 | # del gen_imgs 193 | # del real_imgs 194 | # del fake_validity 195 | # del real_validity 196 | # del d_loss 197 | tqdm.write( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f]" % 198 | (epoch, args.max_epoch, iter_idx % len(train_loader), len(train_loader), d_loss.item())) 199 | 200 | writer_dict['train_global_steps'] = global_steps + 1 201 | 202 | 203 | def train(args, gen_net: nn.Module, dis_net: nn.Module, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, 204 | epoch, writer_dict, fixed_z, schedulers=None): 205 | writer = writer_dict['writer'] 206 | gen_step = 0 207 | # train mode 208 | gen_net.train() 209 | dis_net.train() 210 | 211 | dis_optimizer.zero_grad() 212 | gen_optimizer.zero_grad() 213 | for iter_idx, (imgs, _) in enumerate(tqdm(train_loader)): 214 | global_steps = writer_dict['train_global_steps'] 215 | 216 | 217 | # Adversarial ground truths 218 | real_imgs = imgs.type(torch.cuda.FloatTensor).cuda(args.gpu, non_blocking=True) 219 | 220 | # Sample noise as generator input 221 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], args.latent_dim))).cuda(args.gpu, non_blocking=True) 222 | 223 | # --------------------- 224 | # Train Discriminator 225 | # --------------------- 226 | 227 | 228 | real_validity = dis_net(real_imgs) 229 | fake_imgs = gen_net(z).detach() 230 | 231 | assert fake_imgs.size() == real_imgs.size(), f"fake_imgs.size(): {fake_imgs.size()} real_imgs.size(): {real_imgs.size()}" 232 | 233 | fake_validity = dis_net(fake_imgs) 234 | 235 | # cal loss 236 | if args.loss == 'hinge': 237 | d_loss = 0 238 | d_loss = torch.mean(nn.ReLU(inplace=True)(1.0 - real_validity)) + \ 239 | torch.mean(nn.ReLU(inplace=True)(1 + fake_validity)) 240 | elif args.loss == 'standard': 241 | #soft label 242 | real_label = torch.full((imgs.shape[0],), 0.9, dtype=torch.float, device=real_imgs.get_device()) 243 | fake_label = torch.full((imgs.shape[0],), 0.1, dtype=torch.float, device=real_imgs.get_device()) 244 | real_validity = nn.Sigmoid()(real_validity.view(-1)) 245 | fake_validity = nn.Sigmoid()(fake_validity.view(-1)) 246 | d_real_loss = nn.BCELoss()(real_validity, real_label) 247 | d_fake_loss = nn.BCELoss()(fake_validity, fake_label) 248 | d_loss = d_real_loss + d_fake_loss 249 | elif args.loss == 'lsgan': 250 | if isinstance(fake_validity, list): 251 | d_loss = 0 252 | for real_validity_item, fake_validity_item in zip(real_validity, fake_validity): 253 | real_label = torch.full((real_validity_item.shape[0],real_validity_item.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device()) 254 | fake_label = 
torch.full((real_validity_item.shape[0],real_validity_item.shape[1]), 0., dtype=torch.float, device=real_imgs.get_device()) 255 | d_real_loss = nn.MSELoss()(real_validity_item, real_label) 256 | d_fake_loss = nn.MSELoss()(fake_validity_item, fake_label) 257 | d_loss += d_real_loss + d_fake_loss 258 | else: 259 | real_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device()) 260 | fake_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 0., dtype=torch.float, device=real_imgs.get_device()) 261 | d_real_loss = nn.MSELoss()(real_validity, real_label) 262 | d_fake_loss = nn.MSELoss()(fake_validity, fake_label) 263 | d_loss = d_real_loss + d_fake_loss 264 | elif args.loss == 'wgangp': 265 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi) 266 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / ( 267 | args.phi ** 2) 268 | elif args.loss == 'wgangp-mode': 269 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi) 270 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / ( 271 | args.phi ** 2) 272 | elif args.loss == 'wgangp-eps': 273 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi) 274 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / ( 275 | args.phi ** 2) 276 | d_loss += (torch.mean(real_validity) ** 2) * 1e-3 277 | else: 278 | raise NotImplementedError(args.loss) 279 | d_loss = d_loss/float(args.accumulated_times) 280 | d_loss.backward() 281 | 282 | if (iter_idx + 1) % args.accumulated_times == 0: 283 | torch.nn.utils.clip_grad_norm_(dis_net.parameters(), 5.) 
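# Gradient accumulation: d_loss above was scaled by 1/args.accumulated_times, so the discriminator
# gradients are clipped (max-norm 5) and the optimizer steps only once every accumulated_times
# mini-batches, after which the accumulated gradients are cleared.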
284 | dis_optimizer.step() 285 | dis_optimizer.zero_grad() 286 | 287 | writer.add_scalar('d_loss', d_loss.item(), global_steps) if args.rank == 0 else 0 288 | 289 | # ----------------- 290 | # Train Generator 291 | # ----------------- 292 | if global_steps % (args.n_critic * args.accumulated_times) == 0: 293 | 294 | for accumulated_idx in range(args.g_accumulated_times): 295 | gen_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.gen_batch_size, args.latent_dim))) 296 | gen_imgs = gen_net(gen_z) 297 | fake_validity = dis_net(gen_imgs) 298 | 299 | # cal loss 300 | loss_lz = torch.tensor(0) 301 | if args.loss == "standard": 302 | real_label = torch.full((args.gen_batch_size,), 1., dtype=torch.float, device=real_imgs.get_device()) 303 | fake_validity = nn.Sigmoid()(fake_validity.view(-1)) 304 | g_loss = nn.BCELoss()(fake_validity.view(-1), real_label) 305 | if args.loss == "lsgan": 306 | if isinstance(fake_validity, list): 307 | g_loss = 0 308 | for fake_validity_item in fake_validity: 309 | real_label = torch.full((fake_validity_item.shape[0],fake_validity_item.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device()) 310 | g_loss += nn.MSELoss()(fake_validity_item, real_label) 311 | else: 312 | real_label = torch.full((fake_validity.shape[0],fake_validity.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device()) 313 | # fake_validity = nn.Sigmoid()(fake_validity.view(-1)) 314 | g_loss = nn.MSELoss()(fake_validity, real_label) 315 | elif args.loss == 'wgangp-mode': 316 | fake_image1, fake_image2 = gen_imgs[:args.gen_batch_size//2], gen_imgs[args.gen_batch_size//2:] 317 | z_random1, z_random2 = gen_z[:args.gen_batch_size//2], gen_z[args.gen_batch_size//2:] 318 | lz = torch.mean(torch.abs(fake_image2 - fake_image1)) / torch.mean( 319 | torch.abs(z_random2 - z_random1)) 320 | eps = 1 * 1e-5 321 | loss_lz = 1 / (lz + eps) 322 | 323 | g_loss = -torch.mean(fake_validity) + loss_lz 324 | else: 325 | g_loss = -torch.mean(fake_validity) 326 | g_loss = g_loss/float(args.g_accumulated_times) 327 | g_loss.backward() 328 | 329 | torch.nn.utils.clip_grad_norm_(gen_net.parameters(), 5.) 330 | gen_optimizer.step() 331 | gen_optimizer.zero_grad() 332 | 333 | # adjust learning rate 334 | if schedulers: 335 | gen_scheduler, dis_scheduler = schedulers 336 | g_lr = gen_scheduler.step(global_steps) 337 | d_lr = dis_scheduler.step(global_steps) 338 | writer.add_scalar('LR/g_lr', g_lr, global_steps) 339 | writer.add_scalar('LR/d_lr', d_lr, global_steps) 340 | 341 | # moving average weight 342 | ema_nimg = args.ema_kimg * 1000 343 | cur_nimg = args.dis_batch_size * args.world_size * global_steps 344 | if args.ema_warmup != 0: 345 | ema_nimg = min(ema_nimg, cur_nimg * args.ema_warmup) 346 | ema_beta = 0.5 ** (float(args.dis_batch_size * args.world_size) / max(ema_nimg, 1e-8)) 347 | else: 348 | ema_beta = args.ema 349 | 350 | # moving average weight 351 | for p, avg_p in zip(gen_net.parameters(), gen_avg_param): 352 | cpu_p = deepcopy(p) 353 | avg_p.mul_(ema_beta).add_(1. 
- ema_beta, cpu_p.cpu().data) 354 | del cpu_p 355 | 356 | writer.add_scalar('g_loss', g_loss.item(), global_steps) if args.rank == 0 else 0 357 | gen_step += 1 358 | 359 | # verbose 360 | if gen_step and iter_idx % args.print_freq == 0 and args.rank == 0: 361 | sample_imgs = torch.cat((gen_imgs[:16], real_imgs[:16]), dim=0) 362 | # scale_factor = args.img_size // int(sample_imgs.size(3)) 363 | # sample_imgs = torch.nn.functional.interpolate(sample_imgs, scale_factor=2) 364 | # img_grid = make_grid(sample_imgs, nrow=4, normalize=True, scale_each=True) 365 | # save_image(sample_imgs, f'sampled_images_{args.exp_name}.jpg', nrow=4, normalize=True, scale_each=True) 366 | # writer.add_image(f'sampled_images_{args.exp_name}', img_grid, global_steps) 367 | tqdm.write( 368 | "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [ema: %f] " % 369 | (epoch, args.max_epoch, iter_idx % len(train_loader), len(train_loader), d_loss.item(), g_loss.item(), ema_beta)) 370 | del gen_imgs 371 | del real_imgs 372 | del fake_validity 373 | del real_validity 374 | del g_loss 375 | del d_loss 376 | 377 | writer_dict['train_global_steps'] = global_steps + 1 378 | 379 | 380 | 381 | 382 | 383 | def get_is(args, gen_net: nn.Module, num_img): 384 | """ 385 | Get inception score. 386 | :param args: 387 | :param gen_net: 388 | :param num_img: 389 | :return: Inception score 390 | """ 391 | 392 | # eval mode 393 | gen_net = gen_net.eval() 394 | 395 | eval_iter = num_img // args.eval_batch_size 396 | img_list = list() 397 | for _ in range(eval_iter): 398 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim))) 399 | 400 | # Generate a batch of images 401 | gen_imgs = gen_net(z).mul_(127.5).add_(127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to('cpu', 402 | torch.uint8).numpy() 403 | img_list.extend(list(gen_imgs)) 404 | 405 | # get inception score 406 | logger.info('calculate Inception score...') 407 | mean, std = get_inception_score(img_list) 408 | 409 | return mean 410 | 411 | 412 | def validate(args, fixed_z, fid_stat, epoch, gen_net: nn.Module, writer_dict, clean_dir=True): 413 | writer = writer_dict['writer'] 414 | global_steps = writer_dict['valid_global_steps'] 415 | 416 | # eval mode 417 | gen_net.eval() 418 | 419 | # generate images 420 | # with torch.no_grad(): 421 | # sample_imgs = gen_net(fixed_z, epoch) 422 | # img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True) 423 | 424 | # get fid and inception score 425 | # if args.gpu == 0: 426 | # fid_buffer_dir = os.path.join(args.path_helper['sample_path'], 'fid_buffer') 427 | # os.makedirs(fid_buffer_dir, exist_ok=True) if args.gpu == 0 else 0 428 | 429 | # eval_iter = args.num_eval_imgs // args.eval_batch_size 430 | # img_list = list() 431 | # for iter_idx in tqdm(range(eval_iter), desc='sample images'): 432 | # z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim))) 433 | 434 | # # Generate a batch of images 435 | # gen_imgs = gen_net(z, epoch).mul_(127.5).add_(127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to('cpu', 436 | # torch.uint8).numpy() 437 | # for img_idx, img in enumerate(gen_imgs): 438 | # file_name = os.path.join(fid_buffer_dir, f'iter{iter_idx}_b{img_idx}.png') 439 | # imsave(file_name, img) 440 | # img_list.extend(list(gen_imgs)) 441 | 442 | # get inception score 443 | logger.info('=> calculate inception score') if args.rank == 0 else 0 444 | if args.rank == 0: 445 | # mean, std = get_inception_score(img_list) 446 | mean, std = 0, 0 447 | else: 448 | mean, 
std = 0, 0 449 | print(f"Inception score: {mean}") if args.rank == 0 else 0 450 | # mean, std = 0, 0 451 | # get fid score 452 | print('=> calculate fid score') if args.rank == 0 else 0 453 | if args.rank == 0: 454 | fid_score = get_fid(args, fid_stat, epoch, gen_net, args.num_eval_imgs, args.gen_batch_size, args.eval_batch_size, writer_dict=writer_dict, cls_idx=None) 455 | else: 456 | fid_score = 10000 457 | # fid_score = 10000 458 | print(f"FID score: {fid_score}") if args.rank == 0 else 0 459 | 460 | # if args.gpu == 0: 461 | # if clean_dir: 462 | # os.system('rm -r {}'.format(fid_buffer_dir)) 463 | # else: 464 | # logger.info(f'=> sampled images are saved to {fid_buffer_dir}') 465 | 466 | # writer.add_image('sampled_images', img_grid, global_steps) 467 | if args.rank == 0: 468 | writer.add_scalar('Inception_score/mean', mean, global_steps) 469 | writer.add_scalar('Inception_score/std', std, global_steps) 470 | writer.add_scalar('FID_score', fid_score, global_steps) 471 | 472 | writer_dict['valid_global_steps'] = global_steps + 1 473 | 474 | return mean, fid_score 475 | 476 | 477 | def save_samples(args, fixed_z, fid_stat, epoch, gen_net: nn.Module, writer_dict, clean_dir=True): 478 | 479 | # eval mode 480 | gen_net.eval() 481 | with torch.no_grad(): 482 | # generate images 483 | batch_size = fixed_z.size(0) 484 | sample_imgs = [] 485 | for i in range(fixed_z.size(0)): 486 | sample_img = gen_net(fixed_z[i:(i+1)], epoch) 487 | sample_imgs.append(sample_img) 488 | sample_imgs = torch.cat(sample_imgs, dim=0) 489 | os.makedirs(f"./samples/{args.exp_name}", exist_ok=True) 490 | save_image(sample_imgs, f'./samples/{args.exp_name}/sampled_images_{epoch}.png', nrow=10, normalize=True, scale_each=True) 491 | return 0 492 | 493 | 494 | def get_topk_arch_hidden(args, controller, gen_net, prev_archs, prev_hiddens): 495 | """ 496 | ~ 497 | :param args: 498 | :param controller: 499 | :param gen_net: 500 | :param prev_archs: previous architecture 501 | :param prev_hiddens: previous hidden vector 502 | :return: a list of topk archs and hiddens. 
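Note: each candidate architecture sampled from the controller is scored by its Inception score on args.rl_num_eval_img generated images, and only the args.topk best-scoring architectures (with their hidden states) are returned.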
503 | """ 504 | logger.info(f'=> get top{args.topk} archs out of {args.num_candidate} candidate archs...') 505 | assert args.num_candidate >= args.topk 506 | controller.eval() 507 | cur_stage = controller.cur_stage 508 | archs, _, _, hiddens = controller.sample(args.num_candidate, with_hidden=True, prev_archs=prev_archs, 509 | prev_hiddens=prev_hiddens) 510 | hxs, cxs = hiddens 511 | arch_idx_perf_table = {} 512 | for arch_idx in range(len(archs)): 513 | logger.info(f'arch: {archs[arch_idx]}') 514 | gen_net.set_arch(archs[arch_idx], cur_stage) 515 | is_score = get_is(args, gen_net, args.rl_num_eval_img) 516 | logger.info(f'get Inception score of {is_score}') 517 | arch_idx_perf_table[arch_idx] = is_score 518 | topk_arch_idx_perf = sorted(arch_idx_perf_table.items(), key=operator.itemgetter(1))[::-1][:args.topk] 519 | topk_archs = [] 520 | topk_hxs = [] 521 | topk_cxs = [] 522 | logger.info(f'top{args.topk} archs:') 523 | for arch_idx_perf in topk_arch_idx_perf: 524 | logger.info(arch_idx_perf) 525 | arch_idx = arch_idx_perf[0] 526 | topk_archs.append(archs[arch_idx]) 527 | topk_hxs.append(hxs[arch_idx].detach().requires_grad_(False)) 528 | topk_cxs.append(cxs[arch_idx].detach().requires_grad_(False)) 529 | 530 | return topk_archs, (topk_hxs, topk_cxs) 531 | 532 | 533 | class LinearLrDecay(object): 534 | def __init__(self, optimizer, start_lr, end_lr, decay_start_step, decay_end_step): 535 | 536 | assert start_lr > end_lr 537 | self.optimizer = optimizer 538 | self.delta = (start_lr - end_lr) / (decay_end_step - decay_start_step) 539 | self.decay_start_step = decay_start_step 540 | self.decay_end_step = decay_end_step 541 | self.start_lr = start_lr 542 | self.end_lr = end_lr 543 | 544 | def step(self, current_step): 545 | if current_step <= self.decay_start_step: 546 | lr = self.start_lr 547 | elif current_step >= self.decay_end_step: 548 | lr = self.end_lr 549 | else: 550 | lr = self.start_lr - self.delta * (current_step - self.decay_start_step) 551 | for param_group in self.optimizer.param_groups: 552 | param_group['lr'] = lr 553 | return lr 554 | 555 | def load_params(model, new_param, args, mode="gpu"): 556 | if mode == "cpu": 557 | for p, new_p in zip(model.parameters(), new_param): 558 | cpu_p = deepcopy(new_p) 559 | # p.data.copy_(cpu_p.cuda().to(f"cuda:{args.gpu}")) 560 | p.data.copy_(cpu_p.cuda().to("cpu")) 561 | del cpu_p 562 | 563 | else: 564 | for p, new_p in zip(model.parameters(), new_param): 565 | p.data.copy_(new_p) 566 | 567 | 568 | def copy_params(model, mode='cpu'): 569 | if mode == 'gpu': 570 | flatten = [] 571 | for p in model.parameters(): 572 | cpu_p = deepcopy(p).cpu() 573 | flatten.append(cpu_p.data) 574 | else: 575 | flatten = deepcopy(list(p.data for p in model.parameters())) 576 | return flatten -------------------------------------------------------------------------------- /images/PositionalEncoding.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/images/PositionalEncoding.pdf -------------------------------------------------------------------------------- /images/PositionalEncoding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/images/PositionalEncoding.png -------------------------------------------------------------------------------- /images/TTS-GAN.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/images/TTS-GAN.pdf -------------------------------------------------------------------------------- /images/TTS-GAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/images/TTS-GAN.png -------------------------------------------------------------------------------- /pre-trained-models/JumpingGAN_checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/pre-trained-models/JumpingGAN_checkpoint -------------------------------------------------------------------------------- /pre-trained-models/RunningGAN_checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/pre-trained-models/RunningGAN_checkpoint -------------------------------------------------------------------------------- /train_GAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import cfg 6 | # import models_search 7 | # import datasets 8 | from dataLoader import * 9 | from GANModels import * 10 | from functions import train, train_d, validate, save_samples, LinearLrDecay, load_params, copy_params, cur_stages 11 | from utils.utils import set_log_dir, save_checkpoint, create_logger 12 | # from utils.inception_score import _init_inception 13 | # from utils.fid_score import create_inception_graph, check_or_download_inception 14 | 15 | import torch 16 | import torch.multiprocessing as mp 17 | import torch.distributed as dist 18 | import torch.utils.data.distributed 19 | from torch.utils import data 20 | import os 21 | import numpy as np 22 | import torch.nn as nn 23 | # from tensorboardX import SummaryWriter 24 | from torch.utils.tensorboard import SummaryWriter 25 | from tqdm import tqdm 26 | from copy import deepcopy 27 | from adamw import AdamW 28 | import random 29 | import matplotlib.pyplot as plt 30 | import io 31 | import PIL.Image 32 | from torchvision.transforms import ToTensor 33 | 34 | # torch.backends.cudnn.enabled = True 35 | # torch.backends.cudnn.benchmark = True 36 | 37 | 38 | def main(): 39 | args = cfg.parse_args() 40 | 41 | # _init_inception() 42 | # inception_path = check_or_download_inception(None) 43 | # create_inception_graph(inception_path) 44 | 45 | if args.seed is not None: 46 | torch.manual_seed(args.random_seed) 47 | torch.cuda.manual_seed(args.random_seed) 48 | torch.cuda.manual_seed_all(args.random_seed) 49 | np.random.seed(args.random_seed) 50 | random.seed(args.random_seed) 51 | torch.backends.cudnn.benchmark = False 52 | torch.backends.cudnn.deterministic = True 53 | 54 | if args.gpu is not None: 55 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 56 | 'disable data parallelism.') 57 | 58 | if args.dist_url == "env://" and args.world_size == -1: 59 | args.world_size = int(os.environ["WORLD_SIZE"]) 60 | 61 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 62 | 63 | ngpus_per_node = torch.cuda.device_count() 64 | if args.multiprocessing_distributed: 65 | # Since we have ngpus_per_node processes per node, the total world_size 66 | # needs to be adjusted accordingly 67 | args.world_size = ngpus_per_node * args.world_size 68 | # Use torch.multiprocessing.spawn to launch distributed processes: the 69 | # main_worker process function 70 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 71 | else: 72 | # Simply call main_worker function 73 | main_worker(args.gpu, ngpus_per_node, args) 74 | 75 | def main_worker(gpu, ngpus_per_node, args): 76 | args.gpu = gpu 77 | 78 | if args.gpu is not None: 79 | print("Use GPU: {} for training".format(args.gpu)) 80 | 81 | if args.distributed: 82 | if args.dist_url == "env://" and args.rank == -1: 83 | args.rank = int(os.environ["RANK"]) 84 | if args.multiprocessing_distributed: 85 | # For multiprocessing distributed training, rank needs to be the 86 | # global rank among all the processes 87 | args.rank = args.rank * ngpus_per_node + gpu 88 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 89 | world_size=args.world_size, rank=args.rank) 90 | # weight init 91 | def weights_init(m): 92 | classname = m.__class__.__name__ 93 | if classname.find('Conv2d') != -1: 94 | if args.init_type == 'normal': 95 | nn.init.normal_(m.weight.data, 0.0, 0.02) 96 | elif args.init_type == 'orth': 97 | nn.init.orthogonal_(m.weight.data) 98 | elif args.init_type == 'xavier_uniform': 99 | nn.init.xavier_uniform(m.weight.data, 1.) 100 | else: 101 | raise NotImplementedError('{} unknown inital type'.format(args.init_type)) 102 | # elif classname.find('Linear') != -1: 103 | # if args.init_type == 'normal': 104 | # nn.init.normal_(m.weight.data, 0.0, 0.02) 105 | # elif args.init_type == 'orth': 106 | # nn.init.orthogonal_(m.weight.data) 107 | # elif args.init_type == 'xavier_uniform': 108 | # nn.init.xavier_uniform(m.weight.data, 1.) 109 | # else: 110 | # raise NotImplementedError('{} unknown inital type'.format(args.init_type)) 111 | elif classname.find('BatchNorm2d') != -1: 112 | nn.init.normal_(m.weight.data, 1.0, 0.02) 113 | nn.init.constant_(m.bias.data, 0.0) 114 | 115 | # import network 116 | 117 | gen_net = Generator() 118 | print(gen_net) 119 | dis_net = Discriminator() 120 | print(dis_net) 121 | if not torch.cuda.is_available(): 122 | print('using CPU, this will be slow') 123 | elif args.distributed: 124 | # For multiprocessing distributed, DistributedDataParallel constructor 125 | # should always set the single device scope, otherwise, 126 | # DistributedDataParallel will use all available devices. 
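# One process per GPU: the per-process batch sizes and worker count are divided by ngpus_per_node
# below, and each network is wrapped in DistributedDataParallel pinned to its own device.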
127 | if args.gpu is not None: 128 | torch.cuda.set_device(args.gpu) 129 | # gen_net = eval('models_search.'+args.gen_model+'.Generator')(args=args) 130 | # dis_net = eval('models_search.'+args.dis_model+'.Discriminator')(args=args) 131 | 132 | gen_net.apply(weights_init) 133 | dis_net.apply(weights_init) 134 | gen_net.cuda(args.gpu) 135 | dis_net.cuda(args.gpu) 136 | # When using a single GPU per process and per 137 | # DistributedDataParallel, we need to divide the batch size 138 | # ourselves based on the total number of GPUs we have 139 | args.dis_batch_size = int(args.dis_batch_size / ngpus_per_node) 140 | args.gen_batch_size = int(args.gen_batch_size / ngpus_per_node) 141 | args.batch_size = args.dis_batch_size 142 | 143 | args.num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node) 144 | gen_net = torch.nn.parallel.DistributedDataParallel(gen_net, device_ids=[args.gpu], find_unused_parameters=True) 145 | dis_net = torch.nn.parallel.DistributedDataParallel(dis_net, device_ids=[args.gpu], find_unused_parameters=True) 146 | else: 147 | gen_net.cuda() 148 | dis_net.cuda() 149 | # DistributedDataParallel will divide and allocate batch_size to all 150 | # available GPUs if device_ids are not set 151 | gen_net = torch.nn.parallel.DistributedDataParallel(gen_net) 152 | dis_net = torch.nn.parallel.DistributedDataParallel(dis_net) 153 | elif args.gpu is not None: 154 | torch.cuda.set_device(args.gpu) 155 | gen_net.cuda(args.gpu) 156 | dis_net.cuda(args.gpu) 157 | else: 158 | gen_net = torch.nn.DataParallel(gen_net).cuda() 159 | dis_net = torch.nn.DataParallel(dis_net).cuda() 160 | print(dis_net) if args.rank == 0 else 0 161 | 162 | 163 | # set optimizer 164 | if args.optimizer == "adam": 165 | gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gen_net.parameters()), 166 | args.g_lr, (args.beta1, args.beta2)) 167 | dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, dis_net.parameters()), 168 | args.d_lr, (args.beta1, args.beta2)) 169 | elif args.optimizer == "adamw": 170 | gen_optimizer = AdamW(filter(lambda p: p.requires_grad, gen_net.parameters()), 171 | args.g_lr, weight_decay=args.wd) 172 | dis_optimizer = AdamW(filter(lambda p: p.requires_grad, dis_net.parameters()), 173 | args.g_lr, weight_decay=args.wd) 174 | 175 | gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic) 176 | dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic) 177 | 178 | # fid stat 179 | # if args.dataset.lower() == 'cifar10': 180 | # fid_stat = 'fid_stat/fid_stats_cifar10_train.npz' 181 | # elif args.dataset.lower() == 'stl10': 182 | # fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz' 183 | # elif args.fid_stat is not None: 184 | # fid_stat = args.fid_stat 185 | # else: 186 | # raise NotImplementedError(f'no fid stat for {args.dataset.lower()}') 187 | # assert os.path.exists(fid_stat) 188 | 189 | 190 | # epoch number for dis_net 191 | args.max_epoch = args.max_epoch * args.n_critic 192 | # dataset = datasets.ImageDataset(args, cur_img_size=8) 193 | # train_loader = dataset.train 194 | # train_sampler = dataset.train_sampler 195 | 196 | # train_set = unimib_load_dataset(incl_xyz_accel = True, incl_rms_accel = False, incl_val_group = False, one_hot_encode = False, data_mode = 'Train') 197 | # test_set = unimib_load_dataset(incl_xyz_accel = True, incl_rms_accel = False, incl_val_group = False, one_hot_encode = False, data_mode = 'Test') 198 | # train_loader = 
data.DataLoader(train_set, batch_size=args.dis_batch_size, num_workers=args.num_workers, shuffle=True) 199 | # test_loader = data.DataLoader(test_set, batch_size=args.dis_batch_size, num_workers=args.num_workers, shuffle=True) 200 | 201 | train_set = unimib_load_dataset(incl_xyz_accel = True, incl_rms_accel = False, incl_val_group = False, is_normalize = True, one_hot_encode = False, data_mode = 'Train', single_class = True, class_name = args.class_name, augment_times=args.augment_times) 202 | train_loader = data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.num_workers, shuffle = True) 203 | test_set = unimib_load_dataset(incl_xyz_accel = True, incl_rms_accel = False, incl_val_group = False, is_normalize = True, one_hot_encode = False, data_mode = 'Test', single_class = True, class_name = args.class_name) 204 | test_loader = data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.num_workers, shuffle = True) 205 | 206 | print(len(train_loader)) 207 | 208 | if args.max_iter: 209 | args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader)) 210 | 211 | # initial 212 | fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (100, args.latent_dim))) 213 | avg_gen_net = deepcopy(gen_net).cpu() 214 | gen_avg_param = copy_params(avg_gen_net) 215 | del avg_gen_net 216 | start_epoch = 0 217 | best_fid = 1e4 218 | 219 | # set writer 220 | writer = None 221 | if args.load_path: 222 | print(f'=> resuming from {args.load_path}') 223 | assert os.path.exists(args.load_path) 224 | checkpoint_file = os.path.join(args.load_path) 225 | assert os.path.exists(checkpoint_file) 226 | loc = 'cuda:{}'.format(args.gpu) 227 | checkpoint = torch.load(checkpoint_file, map_location=loc) 228 | start_epoch = checkpoint['epoch'] 229 | best_fid = checkpoint['best_fid'] 230 | 231 | 232 | dis_net.load_state_dict(checkpoint['dis_state_dict']) 233 | gen_optimizer.load_state_dict(checkpoint['gen_optimizer']) 234 | dis_optimizer.load_state_dict(checkpoint['dis_optimizer']) 235 | 236 | # avg_gen_net = deepcopy(gen_net) 237 | gen_net.load_state_dict(checkpoint['avg_gen_state_dict']) 238 | gen_avg_param = copy_params(gen_net, mode='gpu') 239 | gen_net.load_state_dict(checkpoint['gen_state_dict']) 240 | fixed_z = checkpoint['fixed_z'] 241 | # del avg_gen_net 242 | # gen_avg_param = list(p.cuda().to(f"cuda:{args.gpu}") for p in gen_avg_param) 243 | 244 | 245 | 246 | args.path_helper = checkpoint['path_helper'] 247 | logger = create_logger(args.path_helper['log_path']) if args.rank == 0 else None 248 | print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})') 249 | writer = SummaryWriter(args.path_helper['log_path']) if args.rank == 0 else None 250 | del checkpoint 251 | else: 252 | # create new log dir 253 | assert args.exp_name 254 | if args.rank == 0: 255 | args.path_helper = set_log_dir('logs', args.exp_name) 256 | logger = create_logger(args.path_helper['log_path']) 257 | writer = SummaryWriter(args.path_helper['log_path']) 258 | 259 | if args.rank == 0: 260 | logger.info(args) 261 | writer_dict = { 262 | 'writer': writer, 263 | 'train_global_steps': start_epoch * len(train_loader), 264 | 'valid_global_steps': start_epoch // args.val_freq, 265 | } 266 | 267 | # train loop 268 | for epoch in range(int(start_epoch), int(args.max_epoch)): 269 | # train_sampler.set_epoch(epoch) 270 | lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None 271 | cur_stage = cur_stages(epoch, args) 272 | print("cur_stage " + str(cur_stage)) if args.rank==0 else 0 273 | 
print(f"path: {args.path_helper['prefix']}") if args.rank==0 else 0 274 | 275 | # if (epoch+1) % 3 == 0: 276 | # # train discriminator and generator both 277 | # train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch, writer_dict,fixed_z, lr_schedulers) 278 | # else: 279 | # #only train discriminator 280 | # train_d(args, gen_net, dis_net, dis_optimizer, train_loader, epoch, writer_dict,fixed_z, lr_schedulers) 281 | train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch, writer_dict,fixed_z, lr_schedulers) 282 | 283 | if args.rank == 0 and args.show: 284 | backup_param = copy_params(gen_net) 285 | load_params(gen_net, gen_avg_param, args, mode="cpu") 286 | save_samples(args, fixed_z, fid_stat, epoch, gen_net, writer_dict) 287 | load_params(gen_net, backup_param, args) 288 | 289 | #fid_stat is not defined It doesn't make sense to use image evaluate matrics 290 | # if epoch and epoch % args.val_freq == 0 or epoch == int(args.max_epoch)-1: 291 | # backup_param = copy_params(gen_net) 292 | # load_params(gen_net, gen_avg_param, args, mode="cpu") 293 | # inception_score, fid_score = validate(args, fixed_z, fid_stat, epoch, gen_net, writer_dict) 294 | # if args.rank==0: 295 | # logger.info(f'Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}.') 296 | # load_params(gen_net, backup_param, args) 297 | # if fid_score < best_fid: 298 | # best_fid = fid_score 299 | # is_best = True 300 | # else: 301 | # is_best = False 302 | # else: 303 | # is_best = False 304 | 305 | #TO DO: Validate add synthetic data plot in tensorboard 306 | #Plot synthetic data every 5 epochs 307 | # if epoch and epoch % 1 == 0: 308 | gen_net.eval() 309 | plot_buf = gen_plot(gen_net, epoch, args.class_name) 310 | image = PIL.Image.open(plot_buf) 311 | image = ToTensor()(image).unsqueeze(0) 312 | #writer = SummaryWriter(comment='synthetic signals') 313 | writer.add_image('Image', image[0], epoch) 314 | 315 | is_best = False 316 | avg_gen_net = deepcopy(gen_net) 317 | load_params(avg_gen_net, gen_avg_param, args) 318 | # if not args.multiprocessing_distributed or (args.multiprocessing_distributed 319 | # and args.rank == 0): 320 | # Add module in model saving code exp'gen_net.module.state_dict()' to solve the model loading unpaired name problem 321 | save_checkpoint({ 322 | 'epoch': epoch + 1, 323 | 'gen_model': args.gen_model, 324 | 'dis_model': args.dis_model, 325 | 'gen_state_dict': gen_net.module.state_dict(), 326 | 'dis_state_dict': dis_net.module.state_dict(), 327 | 'avg_gen_state_dict': avg_gen_net.module.state_dict(), 328 | 'gen_optimizer': gen_optimizer.state_dict(), 329 | 'dis_optimizer': dis_optimizer.state_dict(), 330 | 'best_fid': best_fid, 331 | 'path_helper': args.path_helper, 332 | 'fixed_z': fixed_z 333 | }, is_best, args.path_helper['ckpt_path'], filename="checkpoint") 334 | del avg_gen_net 335 | 336 | def gen_plot(gen_net, epoch, class_name): 337 | """Create a pyplot plot and save to buffer.""" 338 | synthetic_data = [] 339 | 340 | for i in range(10): 341 | fake_noise = torch.FloatTensor(np.random.normal(0, 1, (1, 100))) 342 | fake_sigs = gen_net(fake_noise).to('cpu').detach().numpy() 343 | synthetic_data.append(fake_sigs) 344 | 345 | fig, axs = plt.subplots(2, 5, figsize=(20,5)) 346 | fig.suptitle(f'Synthetic {class_name} at epoch {epoch}', fontsize=30) 347 | for i in range(2): 348 | for j in range(5): 349 | axs[i, j].plot(synthetic_data[i*5+j][0][0][0][:]) 350 | axs[i, 
j].plot(synthetic_data[i*5+j][0][1][0][:]) 351 | axs[i, j].plot(synthetic_data[i*5+j][0][2][0][:]) 352 | buf = io.BytesIO() 353 | plt.savefig(buf, format='jpeg') 354 | buf.seek(0) 355 | return buf 356 | 357 | if __name__ == '__main__': 358 | main() 359 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-25 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from utils import utils 12 | -------------------------------------------------------------------------------- /utils/cal_fid_stat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-26 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | 8 | import os 9 | import glob 10 | import argparse 11 | import numpy as np 12 | from imageio import imread 13 | import tensorflow as tf 14 | 15 | import utils.fid_score as fid 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | '--data_path', 22 | type=str, 23 | required=True, 24 | help='set path to training set jpg images dir') 25 | parser.add_argument( 26 | '--output_file', 27 | type=str, 28 | default='fid_stat/fid_stats_cifar10_train.npz', 29 | help='path for where to store the statistics') 30 | 31 | opt = parser.parse_args() 32 | print(opt) 33 | return opt 34 | 35 | 36 | def main(): 37 | args = parse_args() 38 | 39 | ######## 40 | # PATHS 41 | ######## 42 | data_path = args.data_path 43 | output_path = args.output_file 44 | # if you have downloaded and extracted 45 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 46 | # set this path to the directory where the extracted files are, otherwise 47 | # just set it to None and the script will later download the files for you 48 | inception_path = None 49 | print("check for inception model..", end=" ", flush=True) 50 | inception_path = fid.check_or_download_inception(inception_path) # download inception if necessary 51 | print("ok") 52 | 53 | # loads all images into memory (this might require a lot of RAM!) 
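# The mu/sigma statistics saved at the end of this script can be reloaded for a later FID run;
# a minimal sketch, assuming the default output path:
#   stats = np.load('fid_stat/fid_stats_cifar10_train.npz')
#   mu_real, sigma_real = stats['mu'], stats['sigma']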
54 | print("load images..", end=" ", flush=True) 55 | image_list = glob.glob(os.path.join(data_path, '*.jpg')) 56 | images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list]) 57 | print("%d images found and loaded" % len(images)) 58 | 59 | print("create inception graph..", end=" ", flush=True) 60 | fid.create_inception_graph(inception_path) # load the graph into the current TF graph 61 | print("ok") 62 | 63 | print("calculte FID stats..", end=" ", flush=True) 64 | config = tf.ConfigProto() 65 | config.gpu_options.allow_growth = True 66 | with tf.Session(config=config) as sess: 67 | sess.run(tf.global_variables_initializer()) 68 | mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100) 69 | np.savez_compressed(output_path, mu=mu, sigma=sigma) 70 | print("finished") 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /utils/fid_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ Calculates the Frechet Inception Distance (FID) to evaluate GANs. 3 | 4 | The FID metric calculates the distance between two distributions of images. 5 | Typically, we have summary statistics (mean & covariance matrix) of one 6 | of these distributions, while the 2nd distribution is given by a GAN. 7 | 8 | When run as a stand-alone program, it compares the distribution of 9 | images that are stored as PNG/JPEG at a specified location with a 10 | distribution given by summary statistics (in pickle format). 11 | 12 | The FID is calculated by assuming that X_1 and X_2 are the activations of 13 | the pool_3 layer of the inception net for generated samples and real world 14 | samples respectively. 15 | 16 | See --help to see further details. 17 | """ 18 | 19 | from __future__ import absolute_import, division, print_function 20 | 21 | import os 22 | import pathlib 23 | import warnings 24 | 25 | import numpy as np 26 | import tensorflow.compat.v1 as tf 27 | tf.disable_v2_behavior() 28 | 29 | from scipy import linalg 30 | from imageio import imread 31 | from tqdm import tqdm 32 | 33 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 34 | 35 | 36 | class InvalidFIDException(Exception): 37 | pass 38 | 39 | 40 | def create_inception_graph(pth): 41 | """Creates a graph from saved GraphDef file.""" 42 | # Creates graph from saved graph_def.pb. 43 | with tf.gfile.FastGFile(pth, 'rb') as f: 44 | graph_def = tf.GraphDef() 45 | graph_def.ParseFromString(f.read()) 46 | _ = tf.import_graph_def(graph_def, name='FID_Inception_Net') 47 | 48 | 49 | # ------------------------------------------------------------------------------- 50 | 51 | 52 | # code for handling inception net derived from 53 | # https://github.com/openai/improved-gan/blob/master/inception_score/model.py 54 | def _get_inception_layer(sess): 55 | """Prepares inception net for batched usage and returns pool_3 layer. 
""" 56 | layername = 'FID_Inception_Net/pool_3:0' 57 | pool3 = sess.graph.get_tensor_by_name(layername) 58 | ops = pool3.graph.get_operations() 59 | for op_idx, op in enumerate(ops): 60 | for o in op.outputs: 61 | shape = o.get_shape() 62 | if shape._dims != []: 63 | shape = [s.value for s in shape] 64 | new_shape = [] 65 | for j, s in enumerate(shape): 66 | if s == 1 and j == 0: 67 | new_shape.append(None) 68 | else: 69 | new_shape.append(s) 70 | o.__dict__['_shape_val'] = tf.TensorShape(new_shape) 71 | return pool3 72 | 73 | 74 | # ------------------------------------------------------------------------------- 75 | 76 | 77 | def get_activations(images, sess, batch_size=16, verbose=False): 78 | """Calculates the activations of the pool_3 layer for all images. 79 | 80 | Params: 81 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values 82 | must lie between 0 and 256. 83 | -- sess : current session 84 | -- batch_size : the images numpy array is split into batches with batch size 85 | batch_size. A reasonable batch size depends on the disposable hardware. 86 | -- verbose : If set to True and parameter out_step is given, the number of calculated 87 | batches is reported. 88 | Returns: 89 | -- A numpy array of dimension (num images, 2048) that contains the 90 | activations of the given tensor when feeding inception with the query tensor. 91 | """ 92 | inception_layer = _get_inception_layer(sess) 93 | d0 = len(images) 94 | if batch_size > d0: 95 | print("warning: batch size is bigger than the data size. setting batch size to data size") 96 | batch_size = d0 97 | n_batches = d0 // batch_size 98 | n_used_imgs = n_batches * batch_size 99 | pred_arr = np.empty((n_used_imgs, 2048)) 100 | for i in tqdm(range(n_batches)): 101 | if verbose: 102 | print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True) 103 | start = i * batch_size 104 | end = start + batch_size 105 | batch = images[start:end] 106 | pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch}) 107 | pred_arr[start:end] = pred.reshape(batch_size, -1) 108 | if verbose: 109 | print(" done") 110 | return pred_arr 111 | 112 | 113 | # ------------------------------------------------------------------------------- 114 | 115 | 116 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 117 | """Numpy implementation of the Frechet Distance. 118 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 119 | and X_2 ~ N(mu_2, C_2) is 120 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 121 | 122 | Stable version by Dougal J. Sutherland. 123 | 124 | Params: 125 | -- mu1 : Numpy array containing the activations of the pool_3 layer of the 126 | inception net ( like returned by the function 'get_predictions') 127 | for generated samples. 128 | -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted 129 | on an representive data set. 130 | -- sigma1: The covariance matrix over activations of the pool_3 layer for 131 | generated samples. 132 | -- sigma2: The covariance matrix over activations of the pool_3 layer, 133 | precalcualted on an representive data set. 134 | 135 | Returns: 136 | -- : The Frechet Distance. 
137 | """ 138 | 139 | mu1 = np.atleast_1d(mu1) 140 | mu2 = np.atleast_1d(mu2) 141 | 142 | sigma1 = np.atleast_2d(sigma1) 143 | sigma2 = np.atleast_2d(sigma2) 144 | 145 | assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" 146 | assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" 147 | 148 | diff = mu1 - mu2 149 | 150 | # product might be almost singular 151 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 152 | if not np.isfinite(covmean).all(): 153 | msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps 154 | warnings.warn(msg) 155 | offset = np.eye(sigma1.shape[0]) * eps 156 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 157 | 158 | # numerical error might give slight imaginary component 159 | if np.iscomplexobj(covmean): 160 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 161 | m = np.max(np.abs(covmean.imag)) 162 | raise ValueError("Imaginary component {}".format(m)) 163 | covmean = covmean.real 164 | 165 | tr_covmean = np.trace(covmean) 166 | 167 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 168 | 169 | 170 | # ------------------------------------------------------------------------------- 171 | 172 | 173 | def calculate_activation_statistics(images, sess, batch_size=16, verbose=False): 174 | """Calculation of the statistics used by the FID. 175 | Params: 176 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values 177 | must lie between 0 and 255. 178 | -- sess : current session 179 | -- batch_size : the images numpy array is split into batches with batch size 180 | batch_size. A reasonable batch size depends on the available hardware. 181 | -- verbose : If set to True and parameter out_step is given, the number of calculated 182 | batches is reported. 183 | Returns: 184 | -- mu : The mean over samples of the activations of the pool_3 layer of 185 | the incption model. 186 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 187 | the incption model. 188 | """ 189 | act = get_activations(images, sess, batch_size, verbose) 190 | mu = np.mean(act, axis=0) 191 | sigma = np.cov(act, rowvar=False) 192 | return mu, sigma 193 | 194 | 195 | # ------------------ 196 | # The following methods are implemented to obtain a batched version of the activations. 197 | # This has the advantage to reduce memory requirements, at the cost of slightly reduced efficiency. 198 | # - Pyrestone 199 | # ------------------ 200 | 201 | 202 | def load_image_batch(files): 203 | """Convenience method for batch-loading images 204 | Params: 205 | -- files : list of paths to image files. Images need to have same dimensions for all files. 206 | Returns: 207 | -- A numpy array of dimensions (num_images,hi, wi, 3) representing the image pixel values. 208 | """ 209 | return np.array([imread(str(fn)).astype(np.float32) for fn in files]) 210 | 211 | 212 | def get_activations_from_files(files, sess, batch_size=16, verbose=False): 213 | """Calculates the activations of the pool_3 layer for all images. 214 | 215 | Params: 216 | -- files : list of paths to image files. Images need to have same dimensions for all files. 217 | -- sess : current session 218 | -- batch_size : the images numpy array is split into batches with batch size 219 | batch_size. A reasonable batch size depends on the disposable hardware. 
220 | -- verbose : If set to True and parameter out_step is given, the number of calculated 221 | batches is reported. 222 | Returns: 223 | -- A numpy array of dimension (num images, 2048) that contains the 224 | activations of the given tensor when feeding inception with the query tensor. 225 | """ 226 | inception_layer = _get_inception_layer(sess) 227 | d0 = len(files) 228 | if batch_size > d0: 229 | print("warning: batch size is bigger than the data size. setting batch size to data size") 230 | batch_size = d0 231 | n_batches = d0 // batch_size 232 | n_used_imgs = n_batches * batch_size 233 | pred_arr = np.empty((n_used_imgs, 2048)) 234 | for i in range(n_batches): 235 | if verbose: 236 | print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True) 237 | start = i * batch_size 238 | end = start + batch_size 239 | batch = load_image_batch(files[start:end]) 240 | pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch}) 241 | pred_arr[start:end] = pred.reshape(batch_size, -1) 242 | del batch # clean up memory 243 | if verbose: 244 | print(" done") 245 | return pred_arr 246 | 247 | 248 | def calculate_activation_statistics_from_files(files, sess, batch_size=1, verbose=False): 249 | """Calculation of the statistics used by the FID. 250 | Params: 251 | -- files : list of paths to image files. Images need to have same dimensions for all files. 252 | -- sess : current session 253 | -- batch_size : the images numpy array is split into batches with batch size 254 | batch_size. A reasonable batch size depends on the available hardware. 255 | -- verbose : If set to True and parameter out_step is given, the number of calculated 256 | batches is reported. 257 | Returns: 258 | -- mu : The mean over samples of the activations of the pool_3 layer of 259 | the incption model. 260 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 261 | the incption model. 262 | """ 263 | act = get_activations_from_files(files, sess, batch_size, verbose) 264 | mu = np.mean(act, axis=0) 265 | sigma = np.cov(act, rowvar=False) 266 | return mu, sigma 267 | 268 | 269 | # ------------------------------------------------------------------------------- 270 | 271 | 272 | # ------------------------------------------------------------------------------- 273 | # The following functions aren't needed for calculating the FID 274 | # they're just here to make this module work as a stand-alone script 275 | # for calculating FID scores 276 | # ------------------------------------------------------------------------------- 277 | def check_or_download_inception(inception_path): 278 | """ Checks if the path to the inception file is valid, or downloads 279 | the file if it is not present. 
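    For example, check_or_download_inception(None) stores the graph under /tmp
    (downloading it if necessary) and returns '/tmp/classify_image_graph_def.pb'.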
""" 280 | INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 281 | if inception_path is None: 282 | inception_path = '/tmp' 283 | inception_path = pathlib.Path(inception_path) 284 | model_file = inception_path / 'classify_image_graph_def.pb' 285 | if not model_file.exists(): 286 | print("Downloading Inception model") 287 | from urllib import request 288 | import tarfile 289 | fn, _ = request.urlretrieve(INCEPTION_URL) 290 | with tarfile.open(fn, mode='r') as f: 291 | f.extract('classify_image_graph_def.pb', str(model_file.parent)) 292 | return str(model_file) 293 | 294 | 295 | def _handle_path(path, sess, low_profile=False): 296 | if isinstance(path, str): 297 | f = np.load(path) 298 | m, s = f['mu'][:], f['sigma'][:] 299 | f.close() 300 | else: 301 | # path = pathlib.Path(path) 302 | files = path 303 | if low_profile: 304 | m, s = calculate_activation_statistics_from_files(files, sess) 305 | else: 306 | # x = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 307 | x = path 308 | m, s = calculate_activation_statistics(x, sess) 309 | del x # clean up memory 310 | return m, s 311 | 312 | 313 | def calculate_fid_given_paths(paths, inception_path, low_profile=False): 314 | """ Calculates the FID of two paths. """ 315 | # inception_path = check_or_download_inception(inception_path) 316 | 317 | # for p in paths: 318 | # if not os.path.exists(p): 319 | # raise RuntimeError("Invalid path: %s" % p) 320 | 321 | config = tf.ConfigProto() 322 | config.gpu_options.allow_growth = True 323 | with tf.Session(config=config) as sess: 324 | sess.run(tf.global_variables_initializer()) 325 | m1, s1 = _handle_path(paths[0], sess, low_profile=low_profile) 326 | m2, s2 = _handle_path(paths[1], sess, low_profile=low_profile) 327 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 328 | sess.close() 329 | 330 | return fid_value 331 | -------------------------------------------------------------------------------- /utils/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | Parameters 39 | ---------- 40 | output_blocks : list of int 41 | Indices of blocks to return features of. 
Possible values are: 42 | - 0: corresponds to output of first max pooling 43 | - 1: corresponds to output of second max pooling 44 | - 2: corresponds to output which is fed to aux classifier 45 | - 3: corresponds to output of final average pooling 46 | resize_input : bool 47 | If true, bilinearly resizes input to width and height 299 before 48 | feeding input to model. As the network without fully connected 49 | layers is fully convolutional, it should be able to handle inputs 50 | of arbitrary size, so resizing might not be strictly needed 51 | normalize_input : bool 52 | If true, scales the input from range (0, 1) to the range the 53 | pretrained Inception network expects, namely (-1, 1) 54 | requires_grad : bool 55 | If true, parameters of the model require gradients. Possibly useful 56 | for finetuning the network 57 | use_fid_inception : bool 58 | If true, uses the pretrained Inception model used in Tensorflow's 59 | FID implementation. If false, uses the pretrained Inception model 60 | available in torchvision. The FID Inception model has different 61 | weights and a slightly different structure from torchvision's 62 | Inception model. If you want to compute FID scores, you are 63 | strongly advised to set this parameter to true to get comparable 64 | results. 65 | """ 66 | super(InceptionV3, self).__init__() 67 | 68 | self.resize_input = resize_input 69 | self.normalize_input = normalize_input 70 | self.output_blocks = sorted(output_blocks) 71 | self.last_needed_block = max(output_blocks) 72 | 73 | assert self.last_needed_block <= 3, \ 74 | 'Last possible output block index is 3' 75 | 76 | self.blocks = nn.ModuleList() 77 | 78 | if use_fid_inception: 79 | inception = fid_inception_v3() 80 | else: 81 | inception = models.inception_v3(pretrained=True) 82 | 83 | # Block 0: input to maxpool1 84 | block0 = [ 85 | inception.Conv2d_1a_3x3, 86 | inception.Conv2d_2a_3x3, 87 | inception.Conv2d_2b_3x3, 88 | nn.MaxPool2d(kernel_size=3, stride=2) 89 | ] 90 | self.blocks.append(nn.Sequential(*block0)) 91 | 92 | # Block 1: maxpool1 to maxpool2 93 | if self.last_needed_block >= 1: 94 | block1 = [ 95 | inception.Conv2d_3b_1x1, 96 | inception.Conv2d_4a_3x3, 97 | nn.MaxPool2d(kernel_size=3, stride=2) 98 | ] 99 | self.blocks.append(nn.Sequential(*block1)) 100 | 101 | # Block 2: maxpool2 to aux classifier 102 | if self.last_needed_block >= 2: 103 | block2 = [ 104 | inception.Mixed_5b, 105 | inception.Mixed_5c, 106 | inception.Mixed_5d, 107 | inception.Mixed_6a, 108 | inception.Mixed_6b, 109 | inception.Mixed_6c, 110 | inception.Mixed_6d, 111 | inception.Mixed_6e, 112 | ] 113 | self.blocks.append(nn.Sequential(*block2)) 114 | 115 | # Block 3: aux classifier to final avgpool 116 | if self.last_needed_block >= 3: 117 | block3 = [ 118 | inception.Mixed_7a, 119 | inception.Mixed_7b, 120 | inception.Mixed_7c, 121 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 122 | ] 123 | self.blocks.append(nn.Sequential(*block3)) 124 | 125 | for param in self.parameters(): 126 | param.requires_grad = requires_grad 127 | 128 | def forward(self, inp): 129 | """Get Inception feature maps 130 | Parameters 131 | ---------- 132 | inp : torch.autograd.Variable 133 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 134 | range (0, 1) 135 | Returns 136 | ------- 137 | List of torch.autograd.Variable, corresponding to the selected output 138 | block, sorted ascending by index 139 | """ 140 | outp = [] 141 | x = inp 142 | 143 | if self.resize_input: 144 | x = F.interpolate(x, 145 | size=(299, 299), 146 | mode='bilinear', 147 | align_corners=False) 148 | 149 | if self.normalize_input: 150 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 151 | 152 | for idx, block in enumerate(self.blocks): 153 | x = block(x) 154 | if idx in self.output_blocks: 155 | outp.append(x) 156 | 157 | if idx == self.last_needed_block: 158 | break 159 | 160 | return outp 161 | 162 | 163 | def fid_inception_v3(): 164 | """Build pretrained Inception model for FID computation 165 | The Inception model for FID computation uses a different set of weights 166 | and has a slightly different structure than torchvision's Inception. 167 | This method first constructs torchvision's Inception and then patches the 168 | necessary parts that are different in the FID Inception model. 169 | """ 170 | inception = models.inception_v3(num_classes=1008, 171 | aux_logits=False, 172 | pretrained=False, 173 | init_weights=False) 174 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 175 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 176 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 177 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 178 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 179 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 180 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 181 | inception.Mixed_7b = FIDInceptionE_1(1280) 182 | inception.Mixed_7c = FIDInceptionE_2(2048) 183 | 184 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 185 | inception.load_state_dict(state_dict) 186 | return inception 187 | 188 | 189 | class FIDInceptionA(models.inception.InceptionA): 190 | """InceptionA block patched for FID computation""" 191 | 192 | def __init__(self, in_channels, pool_features): 193 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 194 | 195 | def forward(self, x): 196 | branch1x1 = self.branch1x1(x) 197 | 198 | branch5x5 = self.branch5x5_1(x) 199 | branch5x5 = self.branch5x5_2(branch5x5) 200 | 201 | branch3x3dbl = self.branch3x3dbl_1(x) 202 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 203 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 204 | 205 | # Patch: Tensorflow's average pool does not use the padded zero's in 206 | # its average calculation 207 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 208 | count_include_pad=False) 209 | branch_pool = self.branch_pool(branch_pool) 210 | 211 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 212 | return torch.cat(outputs, 1) 213 | 214 | 215 | class FIDInceptionC(models.inception.InceptionC): 216 | """InceptionC block patched for FID computation""" 217 | 218 | def __init__(self, in_channels, channels_7x7): 219 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 220 | 221 | def forward(self, x): 222 | branch1x1 = self.branch1x1(x) 223 | 224 | branch7x7 = self.branch7x7_1(x) 225 | branch7x7 = self.branch7x7_2(branch7x7) 226 | branch7x7 = self.branch7x7_3(branch7x7) 227 | 228 | branch7x7dbl = self.branch7x7dbl_1(x) 229 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 230 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 231 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 232 | 
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 233 | 234 | # Patch: Tensorflow's average pool does not use the padded zero's in 235 | # its average calculation 236 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 237 | count_include_pad=False) 238 | branch_pool = self.branch_pool(branch_pool) 239 | 240 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 241 | return torch.cat(outputs, 1) 242 | 243 | 244 | class FIDInceptionE_1(models.inception.InceptionE): 245 | """First InceptionE block patched for FID computation""" 246 | 247 | def __init__(self, in_channels): 248 | super(FIDInceptionE_1, self).__init__(in_channels) 249 | 250 | def forward(self, x): 251 | branch1x1 = self.branch1x1(x) 252 | 253 | branch3x3 = self.branch3x3_1(x) 254 | branch3x3 = [ 255 | self.branch3x3_2a(branch3x3), 256 | self.branch3x3_2b(branch3x3), 257 | ] 258 | branch3x3 = torch.cat(branch3x3, 1) 259 | 260 | branch3x3dbl = self.branch3x3dbl_1(x) 261 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 262 | branch3x3dbl = [ 263 | self.branch3x3dbl_3a(branch3x3dbl), 264 | self.branch3x3dbl_3b(branch3x3dbl), 265 | ] 266 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 267 | 268 | # Patch: Tensorflow's average pool does not use the padded zero's in 269 | # its average calculation 270 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 271 | count_include_pad=False) 272 | branch_pool = self.branch_pool(branch_pool) 273 | 274 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 275 | return torch.cat(outputs, 1) 276 | 277 | 278 | class FIDInceptionE_2(models.inception.InceptionE): 279 | """Second InceptionE block patched for FID computation""" 280 | 281 | def __init__(self, in_channels): 282 | super(FIDInceptionE_2, self).__init__(in_channels) 283 | 284 | def forward(self, x): 285 | branch1x1 = self.branch1x1(x) 286 | 287 | branch3x3 = self.branch3x3_1(x) 288 | branch3x3 = [ 289 | self.branch3x3_2a(branch3x3), 290 | self.branch3x3_2b(branch3x3), 291 | ] 292 | branch3x3 = torch.cat(branch3x3, 1) 293 | 294 | branch3x3dbl = self.branch3x3dbl_1(x) 295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 296 | branch3x3dbl = [ 297 | self.branch3x3dbl_3a(branch3x3dbl), 298 | self.branch3x3dbl_3b(branch3x3dbl), 299 | ] 300 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 301 | 302 | # Patch: The FID Inception model uses max pooling instead of average 303 | # pooling. This is likely an error in this specific Inception 304 | # implementation, as other Inception models use average pooling here 305 | # (which matches the description in the paper). 
306 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 307 | branch_pool = self.branch_pool(branch_pool) 308 | 309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 310 | return torch.cat(outputs, 1) 311 | -------------------------------------------------------------------------------- /utils/inception_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | Parameters 39 | ---------- 40 | output_blocks : list of int 41 | Indices of blocks to return features of. Possible values are: 42 | - 0: corresponds to output of first max pooling 43 | - 1: corresponds to output of second max pooling 44 | - 2: corresponds to output which is fed to aux classifier 45 | - 3: corresponds to output of final average pooling 46 | resize_input : bool 47 | If true, bilinearly resizes input to width and height 299 before 48 | feeding input to model. As the network without fully connected 49 | layers is fully convolutional, it should be able to handle inputs 50 | of arbitrary size, so resizing might not be strictly needed 51 | normalize_input : bool 52 | If true, scales the input from range (0, 1) to the range the 53 | pretrained Inception network expects, namely (-1, 1) 54 | requires_grad : bool 55 | If true, parameters of the model require gradients. Possibly useful 56 | for finetuning the network 57 | use_fid_inception : bool 58 | If true, uses the pretrained Inception model used in Tensorflow's 59 | FID implementation. If false, uses the pretrained Inception model 60 | available in torchvision. The FID Inception model has different 61 | weights and a slightly different structure from torchvision's 62 | Inception model. If you want to compute FID scores, you are 63 | strongly advised to set this parameter to true to get comparable 64 | results. 
65 | """ 66 | super(InceptionV3, self).__init__() 67 | 68 | self.resize_input = resize_input 69 | self.normalize_input = normalize_input 70 | self.output_blocks = sorted(output_blocks) 71 | self.last_needed_block = max(output_blocks) 72 | 73 | assert self.last_needed_block <= 3, \ 74 | 'Last possible output block index is 3' 75 | 76 | self.blocks = nn.ModuleList() 77 | 78 | if use_fid_inception: 79 | inception = fid_inception_v3() 80 | else: 81 | inception = models.inception_v3(pretrained=True) 82 | 83 | # Block 0: input to maxpool1 84 | block0 = [ 85 | inception.Conv2d_1a_3x3, 86 | inception.Conv2d_2a_3x3, 87 | inception.Conv2d_2b_3x3, 88 | nn.MaxPool2d(kernel_size=3, stride=2) 89 | ] 90 | self.blocks.append(nn.Sequential(*block0)) 91 | 92 | # Block 1: maxpool1 to maxpool2 93 | if self.last_needed_block >= 1: 94 | block1 = [ 95 | inception.Conv2d_3b_1x1, 96 | inception.Conv2d_4a_3x3, 97 | nn.MaxPool2d(kernel_size=3, stride=2) 98 | ] 99 | self.blocks.append(nn.Sequential(*block1)) 100 | 101 | # Block 2: maxpool2 to aux classifier 102 | if self.last_needed_block >= 2: 103 | block2 = [ 104 | inception.Mixed_5b, 105 | inception.Mixed_5c, 106 | inception.Mixed_5d, 107 | inception.Mixed_6a, 108 | inception.Mixed_6b, 109 | inception.Mixed_6c, 110 | inception.Mixed_6d, 111 | inception.Mixed_6e, 112 | ] 113 | self.blocks.append(nn.Sequential(*block2)) 114 | 115 | # Block 3: aux classifier to final avgpool 116 | if self.last_needed_block >= 3: 117 | block3 = [ 118 | inception.Mixed_7a, 119 | inception.Mixed_7b, 120 | inception.Mixed_7c, 121 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 122 | ] 123 | self.blocks.append(nn.Sequential(*block3)) 124 | 125 | for param in self.parameters(): 126 | param.requires_grad = requires_grad 127 | 128 | def forward(self, inp): 129 | """Get Inception feature maps 130 | Parameters 131 | ---------- 132 | inp : torch.autograd.Variable 133 | Input tensor of shape Bx3xHxW. Values are expected to be in 134 | range (0, 1) 135 | Returns 136 | ------- 137 | List of torch.autograd.Variable, corresponding to the selected output 138 | block, sorted ascending by index 139 | """ 140 | outp = [] 141 | x = inp 142 | 143 | if self.resize_input: 144 | x = F.interpolate(x, 145 | size=(299, 299), 146 | mode='bilinear', 147 | align_corners=False) 148 | 149 | if self.normalize_input: 150 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 151 | 152 | for idx, block in enumerate(self.blocks): 153 | x = block(x) 154 | if idx in self.output_blocks: 155 | outp.append(x) 156 | 157 | if idx == self.last_needed_block: 158 | break 159 | 160 | return outp 161 | 162 | 163 | def fid_inception_v3(): 164 | """Build pretrained Inception model for FID computation 165 | The Inception model for FID computation uses a different set of weights 166 | and has a slightly different structure than torchvision's Inception. 167 | This method first constructs torchvision's Inception and then patches the 168 | necessary parts that are different in the FID Inception model. 
169 | """ 170 | inception = models.inception_v3(num_classes=1008, 171 | aux_logits=False, 172 | pretrained=False) 173 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 174 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 175 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 176 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 177 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 178 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 179 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 180 | inception.Mixed_7b = FIDInceptionE_1(1280) 181 | inception.Mixed_7c = FIDInceptionE_2(2048) 182 | 183 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 184 | inception.load_state_dict(state_dict) 185 | return inception 186 | 187 | 188 | class FIDInceptionA(models.inception.InceptionA): 189 | """InceptionA block patched for FID computation""" 190 | 191 | def __init__(self, in_channels, pool_features): 192 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 193 | 194 | def forward(self, x): 195 | branch1x1 = self.branch1x1(x) 196 | 197 | branch5x5 = self.branch5x5_1(x) 198 | branch5x5 = self.branch5x5_2(branch5x5) 199 | 200 | branch3x3dbl = self.branch3x3dbl_1(x) 201 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 202 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 203 | 204 | # Patch: Tensorflow's average pool does not use the padded zero's in 205 | # its average calculation 206 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 207 | count_include_pad=False) 208 | branch_pool = self.branch_pool(branch_pool) 209 | 210 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 211 | return torch.cat(outputs, 1) 212 | 213 | 214 | class FIDInceptionC(models.inception.InceptionC): 215 | """InceptionC block patched for FID computation""" 216 | 217 | def __init__(self, in_channels, channels_7x7): 218 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 219 | 220 | def forward(self, x): 221 | branch1x1 = self.branch1x1(x) 222 | 223 | branch7x7 = self.branch7x7_1(x) 224 | branch7x7 = self.branch7x7_2(branch7x7) 225 | branch7x7 = self.branch7x7_3(branch7x7) 226 | 227 | branch7x7dbl = self.branch7x7dbl_1(x) 228 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 229 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 230 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 231 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 232 | 233 | # Patch: Tensorflow's average pool does not use the padded zero's in 234 | # its average calculation 235 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 236 | count_include_pad=False) 237 | branch_pool = self.branch_pool(branch_pool) 238 | 239 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 240 | return torch.cat(outputs, 1) 241 | 242 | 243 | class FIDInceptionE_1(models.inception.InceptionE): 244 | """First InceptionE block patched for FID computation""" 245 | 246 | def __init__(self, in_channels): 247 | super(FIDInceptionE_1, self).__init__(in_channels) 248 | 249 | def forward(self, x): 250 | branch1x1 = self.branch1x1(x) 251 | 252 | branch3x3 = self.branch3x3_1(x) 253 | branch3x3 = [ 254 | self.branch3x3_2a(branch3x3), 255 | self.branch3x3_2b(branch3x3), 256 | ] 257 | branch3x3 = torch.cat(branch3x3, 1) 258 | 259 | branch3x3dbl = self.branch3x3dbl_1(x) 260 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 261 | branch3x3dbl = [ 262 | self.branch3x3dbl_3a(branch3x3dbl), 263 | 
self.branch3x3dbl_3b(branch3x3dbl), 264 | ] 265 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 266 | 267 | # Patch: Tensorflow's average pool does not use the padded zero's in 268 | # its average calculation 269 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 270 | count_include_pad=False) 271 | branch_pool = self.branch_pool(branch_pool) 272 | 273 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 274 | return torch.cat(outputs, 1) 275 | 276 | 277 | class FIDInceptionE_2(models.inception.InceptionE): 278 | """Second InceptionE block patched for FID computation""" 279 | 280 | def __init__(self, in_channels): 281 | super(FIDInceptionE_2, self).__init__(in_channels) 282 | 283 | def forward(self, x): 284 | branch1x1 = self.branch1x1(x) 285 | 286 | branch3x3 = self.branch3x3_1(x) 287 | branch3x3 = [ 288 | self.branch3x3_2a(branch3x3), 289 | self.branch3x3_2b(branch3x3), 290 | ] 291 | branch3x3 = torch.cat(branch3x3, 1) 292 | 293 | branch3x3dbl = self.branch3x3dbl_1(x) 294 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 295 | branch3x3dbl = [ 296 | self.branch3x3dbl_3a(branch3x3dbl), 297 | self.branch3x3dbl_3b(branch3x3dbl), 298 | ] 299 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 300 | 301 | # Patch: The FID Inception model uses max pooling instead of average 302 | # pooling. This is likely an error in this specific Inception 303 | # implementation, as other Inception models use average pooling here 304 | # (which matches the description in the paper). 305 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 306 | branch_pool = self.branch_pool(branch_pool) 307 | 308 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 309 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /utils/inception_score.py: -------------------------------------------------------------------------------- 1 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import math 7 | import os 8 | import os.path 9 | import sys 10 | import tarfile 11 | 12 | import numpy as np 13 | import tensorflow.compat.v1 as tf 14 | tf.disable_v2_behavior() 15 | from six.moves import urllib 16 | from tqdm import tqdm 17 | 18 | 19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 20 | MODEL_DIR = '/tmp/imagenet' 21 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 22 | softmax = None 23 | config = tf.ConfigProto() 24 | # config = tf.ConfigProto(device_count = {'GPU': 0}) 25 | config.gpu_options.visible_device_list= '0' 26 | config.gpu_options.allow_growth = True 27 | 28 | 29 | # Call this function with list of images. Each of elements should be a 30 | # numpy array with values ranging from 0 to 255. 
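# A minimal calling sketch (illustrative only; `fake_imgs` is a hypothetical list of
# H x W x 3 arrays with values in [0, 255], e.g. plots rendered from generated signals):
#   _init_inception()                                   # build the softmax graph first
#   mean_is, std_is = get_inception_score(fake_imgs)    # (mean, std) over the splits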
31 | def get_inception_score(images, splits=10): 32 | assert (type(images) == list) 33 | assert (type(images[0]) == np.ndarray) 34 | assert (len(images[0].shape) == 3) 35 | assert (np.max(images[0]) > 10) 36 | assert (np.min(images[0]) >= 0.0) 37 | inps = [] 38 | for img in images: 39 | img = img.astype(np.float32) 40 | inps.append(np.expand_dims(img, 0)) 41 | bs = 128 42 | with tf.Session(config=config) as sess: 43 | preds = [] 44 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 45 | for i in tqdm(range(n_batches), desc="Calculate inception score"): 46 | sys.stdout.flush() 47 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 48 | inp = np.concatenate(inp, 0) 49 | pred = sess.run(softmax, {'ExpandDims:0': inp}) 50 | preds.append(pred) 51 | preds = np.concatenate(preds, 0) 52 | scores = [] 53 | for i in range(splits): 54 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 55 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 56 | kl = np.mean(np.sum(kl, 1)) 57 | scores.append(np.exp(kl)) 58 | 59 | sess.close() 60 | return np.mean(scores), np.std(scores) 61 | 62 | 63 | # This function is called automatically. 64 | def _init_inception(): 65 | global softmax 66 | if not os.path.exists(MODEL_DIR): 67 | os.makedirs(MODEL_DIR) 68 | filename = DATA_URL.split('/')[-1] 69 | filepath = os.path.join(MODEL_DIR, filename) 70 | if not os.path.exists(filepath): 71 | def _progress(count, block_size, total_size): 72 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 73 | filename, float(count * block_size) / float(total_size) * 100.0)) 74 | sys.stdout.flush() 75 | 76 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 77 | print() 78 | statinfo = os.stat(filepath) 79 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 80 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 81 | with tf.gfile.FastGFile(os.path.join( 82 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 83 | graph_def = tf.GraphDef() 84 | graph_def.ParseFromString(f.read()) 85 | _ = tf.import_graph_def(graph_def, name='') 86 | # Works with an arbitrary minibatch size. 87 | with tf.Session(config=config) as sess: 88 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 89 | ops = pool3.graph.get_operations() 90 | for op_idx, op in enumerate(ops): 91 | for o in op.outputs: 92 | shape = o.get_shape() 93 | if shape._dims != []: 94 | shape = [s.value for s in shape] 95 | new_shape = [] 96 | for j, s in enumerate(shape): 97 | if s == 1 and j == 0: 98 | new_shape.append(None) 99 | else: 100 | new_shape.append(s) 101 | o.__dict__['_shape_val'] = tf.TensorShape(new_shape) 102 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 103 | logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w) 104 | softmax = tf.nn.softmax(logits) 105 | sess.close() 106 | -------------------------------------------------------------------------------- /utils/torch_fid_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 
6 | When run as a stand-alone program, it compares the distribution of 7 | images that are stored as PNG/JPEG at a specified location with a 8 | distribution given by summary statistics (in pickle format). 9 | The FID is calculated by assuming that X_1 and X_2 are the activations of 10 | the pool_3 layer of the inception net for generated samples and real world 11 | samples respectively. 12 | See --help to see further details. 13 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 14 | of Tensorflow 15 | Copyright 2018 Institute of Bioinformatics, JKU Linz 16 | Licensed under the Apache License, Version 2.0 (the "License"); 17 | you may not use this file except in compliance with the License. 18 | You may obtain a copy of the License at 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | Unless required by applicable law or agreed to in writing, software 21 | distributed under the License is distributed on an "AS IS" BASIS, 22 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | See the License for the specific language governing permissions and 24 | limitations under the License. 25 | """ 26 | import os 27 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 28 | import contextlib  # needed by _get_no_grad_ctx_mgr below 29 | import numpy as np 30 | import torch 31 | from utils.inception import InceptionV3 32 | from torch.nn.functional import adaptive_avg_pool2d 33 | 34 | try: 35 | from tqdm import tqdm 36 | except ImportError: 37 | # If tqdm is not available, provide a mock version of it 38 | def tqdm(x): 39 | return x 40 | 41 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 42 | parser.add_argument('path', type=str, nargs=2, 43 | help=('Path to the generated images or ' 44 | 'to .npz statistic files')) 45 | parser.add_argument('--batch-size', type=int, default=50, 46 | help='Batch size to use') 47 | parser.add_argument('--dims', type=int, default=2048, 48 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 49 | help=('Dimensionality of Inception features to use. ' 50 | 'By default, uses pool3 features')) 51 | parser.add_argument('-c', '--gpu', default='1', type=str, 52 | help='GPU to use (leave blank for CPU only)') 53 | 54 | def _get_no_grad_ctx_mgr(require_grad): 55 | """Returns the `torch.no_grad` context manager for PyTorch version >= 56 | 0.4, or a no-op context manager otherwise. 57 | """ 58 | if not require_grad and float(torch.__version__[0:3]) >= 0.4: 59 | return torch.no_grad() 60 | 61 | return contextlib.suppress() 62 | 63 | # Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji 64 | # https://github.com/msubhransu/matrix-sqrt 65 | def sqrt_newton_schulz(A, numIters, dtype=None): 66 | if dtype is None: 67 | dtype = A.type() 68 | batchSize = A.shape[0] 69 | dim = A.shape[1] 70 | normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt() 71 | Y = A.div(normA.view(batchSize, 1, 1).expand_as(A)).to("cuda:0") 72 | I = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype).to("cuda:0") 73 | Z = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype).to("cuda:0") 74 | for i in range(numIters): 75 | T = 0.5 * (3.0 * I - Z.bmm(Y)) 76 | Y = Y.bmm(T) 77 | Z = T.bmm(Z) 78 | sA = Y * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A) 79 | return sA 80 | 81 | 82 | # A pytorch implementation of cov, from Modar M. Alfadly 83 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 84 | def torch_cov(m, rowvar=False): 85 | '''Estimate a covariance matrix given data. 
86 | Covariance indicates the level to which two variables vary together. 87 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 88 | then the covariance matrix element `C_{ij}` is the covariance of 89 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 90 | Args: 91 | m: A 1-D or 2-D array containing multiple variables and observations. 92 | Each row of `m` represents a variable, and each column a single 93 | observation of all those variables. 94 | rowvar: If `rowvar` is True, then each row represents a 95 | variable, with observations in the columns. Otherwise, the 96 | relationship is transposed: each column represents a variable, 97 | while the rows contain observations. 98 | Returns: 99 | The covariance matrix of the variables. 100 | ''' 101 | if m.dim() > 2: 102 | raise ValueError('m has more than 2 dimensions') 103 | if m.dim() < 2: 104 | m = m.view(1, -1) 105 | if not rowvar and m.size(0) != 1: 106 | m = m.t() 107 | # m = m.type(torch.double) # uncomment this line if desired 108 | fact = 1.0 / (m.size(1) - 1) 109 | m -= torch.mean(m, dim=1, keepdim=True) 110 | mt = m.t() # if complex: mt = m.t().conj() 111 | return fact * m.matmul(mt).squeeze() 112 | 113 | 114 | def get_activations(args, gen_net, model, batch_size=50, dims=2048, 115 | cuda=False, verbose=False): 116 | """Calculates the activations of the pool_3 layer for all images. 117 | Params: 118 | -- files : List of image files paths 119 | -- model : Instance of inception model 120 | -- batch_size : Batch size of images for the model to process at once. 121 | Make sure that the number of samples is a multiple of 122 | the batch size, otherwise some samples are ignored. This 123 | behavior is retained to match the original FID score 124 | implementation. 125 | -- dims : Dimensionality of features returned by Inception 126 | -- cuda : If set to True, use GPU 127 | -- verbose : If set to True and parameter out_step is given, the number 128 | of calculated batches is reported. 129 | Returns: 130 | -- A numpy array of dimension (num images, dims) that contains the 131 | activations of the given tensor when feeding inception with the 132 | query tensor. 133 | """ 134 | with torch.no_grad(): 135 | gen_net.eval() 136 | model.eval() 137 | 138 | # if gen_imgs.shape[0] % batch_size != 0: 139 | # print(('Warning: number of images is not a multiple of the ' 140 | # 'batch size. Some samples are going to be ignored.')) 141 | # if batch_size > gen_imgs.shape[0]: 142 | # print(('Warning: batch size is bigger than the data size. ' 143 | # 'Setting batch size to data size')) 144 | # batch_size = gen_imgs.shape[0] 145 | 146 | n_batches = args.num_eval_imgs // batch_size 147 | 148 | # normalize 149 | 150 | pred_arr = [] 151 | for i in tqdm(range(n_batches)): 152 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, args.latent_dim))) 153 | gen_imgs = gen_net(z, 200) 154 | 155 | if verbose: 156 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 157 | end='', flush=True) 158 | start = i * batch_size 159 | end = start + batch_size 160 | 161 | images = (gen_imgs + 1.0) / 2.0 162 | model.to("cuda:0") 163 | pred = model(images.to("cuda:0"))[0] 164 | 165 | # If model output is not scalar, apply global spatial average pooling. 166 | # This happens if you choose a dimensionality not equal 2048. 
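            # Added note: with the default dims=2048 the selected pool_3 block already
            # returns N x 2048 x 1 x 1 features, so this pooling is effectively a no-op;
            # it only matters when a lower-dimensional feature block is chosen.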
167 | if pred.shape[2] != 1 or pred.shape[3] != 1: 168 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 169 | 170 | pred_arr += [pred.view(batch_size, -1)] 171 | 172 | if verbose: 173 | print('done') 174 | del images 175 | 176 | return torch.cat(pred_arr, dim=0) 177 | 178 | 179 | def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 180 | """Pytorch implementation of the Frechet Distance. 181 | Taken from https://github.com/bioinf-jku/TTUR 182 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 183 | and X_2 ~ N(mu_2, C_2) is 184 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 185 | Stable version by Dougal J. Sutherland. 186 | Params: 187 | -- mu1 : Numpy array containing the activations of a layer of the 188 | inception net (like returned by the function 'get_predictions') 189 | for generated samples. 190 | -- mu2 : The sample mean over activations, precalculated on an 191 | representive data set. 192 | -- sigma1: The covariance matrix over activations for generated samples. 193 | -- sigma2: The covariance matrix over activations, precalculated on an 194 | representive data set. 195 | Returns: 196 | -- : The Frechet Distance. 197 | """ 198 | 199 | assert mu1.shape == mu2.shape, \ 200 | 'Training and test mean vectors have different lengths' 201 | assert sigma1.shape == sigma2.shape, \ 202 | 'Training and test covariances have different dimensions' 203 | 204 | diff = mu1 - mu2 205 | # Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2 206 | covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze() 207 | out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) 208 | - 2 * torch.trace(covmean)) 209 | return out 210 | 211 | 212 | def calculate_activation_statistics(gen_net, model, batch_size=50, 213 | dims=2048, cuda=False, verbose=False): 214 | """Calculation of the statistics used by the FID. 215 | Params: 216 | -- gen_imgs : gen_imgs, tensor 217 | -- model : Instance of inception model 218 | -- batch_size : The images numpy array is split into batches with 219 | batch size batch_size. A reasonable batch size 220 | depends on the hardware. 221 | -- dims : Dimensionality of features returned by Inception 222 | -- cuda : If set to True, use GPU 223 | -- verbose : If set to True and parameter out_step is given, the 224 | number of calculated batches is reported. 225 | Returns: 226 | -- mu : The mean over samples of the activations of the pool_3 layer of 227 | the inception model. 228 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 229 | the inception model. 
230 | """ 231 | act = get_activations(gen_net, model, batch_size, dims, cuda, verbose) 232 | mu = torch.mean(act, dim=0) 233 | sigma = torch_cov(act, rowvar=False) 234 | return mu, sigma 235 | 236 | 237 | def _compute_statistics_of_path(args, path, model, batch_size, dims, cuda): 238 | if isinstance(path, str): 239 | assert path.endswith('.npz') 240 | f = np.load(path) 241 | if 'mean' in f: 242 | m, s = f['mean'][:], f['cov'][:] 243 | else: 244 | m, s = f['mu'][:], f['sigma'][:] 245 | f.close() 246 | else: 247 | # a tensor 248 | gen_net = path 249 | m, s = calculate_activation_statistics(args, gen_net, model, batch_size, 250 | dims, cuda) 251 | 252 | return m, s 253 | 254 | 255 | def calculate_fid_given_paths_torch(args, gen_net, path, require_grad=False, gen_batch_size=1, batch_size=1, cuda=True, dims=2048): 256 | """ 257 | Calculates the FID of two paths 258 | :param gen_imgs: The value range of gen_imgs should be (-1, 1). Just the output of tanh. 259 | :param path: fid file path. *.npz. 260 | :param batch_size: 261 | :param cuda: 262 | :param dims: 263 | :return: 264 | """ 265 | if not os.path.exists(path): 266 | raise RuntimeError('Invalid path: %s' % path) 267 | 268 | assert args.num_eval_imgs >= dims, f'gen_imgs size: {args.num_eval_imgs}' # or will lead to nan 269 | 270 | with _get_no_grad_ctx_mgr(require_grad=require_grad): 271 | 272 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 273 | 274 | model = InceptionV3([block_idx]) 275 | if cuda: 276 | model.cuda() 277 | 278 | m1, s1 = _compute_statistics_of_path(args, gen_net, model, batch_size, 279 | dims, cuda) 280 | # print(f'generated stat: {m1}, {s1}') 281 | m2, s2 = _compute_statistics_of_path(args, path, model, batch_size, 282 | dims, cuda) 283 | # print(f'GT stat: {m2}, {s2}') 284 | fid_value = torch_calculate_frechet_distance(m1.to("cuda:0"), s1.to("cuda:0"), torch.tensor(m2).float().cuda().to("cuda:0"), 285 | torch.tensor(s2).float().cuda().to("cuda:0")) 286 | del model 287 | 288 | return fid_value 289 | 290 | 291 | def get_fid(args, fid_stat, epoch, gen_net, num_img, gen_batch_size, val_batch_size, writer_dict=None, cls_idx=None): 292 | gen_net.eval() 293 | with torch.no_grad(): 294 | # eval mode 295 | gen_net.eval() 296 | 297 | # eval_iter = num_img // gen_batch_size 298 | # img_list = [] 299 | # for _ in tqdm(range(eval_iter), desc='sample images'): 300 | # z = torch.cuda.FloatTensor(np.random.normal(0, 1, (gen_batch_size, args.latent_dim))) 301 | 302 | # # Generate a batch of images 303 | # if args.n_classes > 0: 304 | # if cls_idx is not None: 305 | # label = torch.ones(z.shape[0]) * cls_idx 306 | # label = label.type(torch.cuda.LongTensor) 307 | # else: 308 | # label = torch.randint(low=0, high=args.n_classes, size=(z.shape[0],), device='cuda') 309 | # gen_imgs = gen_net(z, epoch) 310 | # else: 311 | # gen_imgs = gen_net(z, epoch) 312 | # if isinstance(gen_imgs, tuple): 313 | # gen_imgs = gen_imgs[0] 314 | # img_list += [gen_imgs] 315 | 316 | # img_list = torch.cat(img_list, 0) 317 | fid_score = calculate_fid_given_paths_torch(args, gen_net, fid_stat, gen_batch_size=gen_batch_size, batch_size=val_batch_size) 318 | 319 | if writer_dict: 320 | writer = writer_dict['writer'] 321 | global_steps = writer_dict['valid_global_steps'] 322 | writer.add_scalar('FID_score', fid_score, global_steps) 323 | writer_dict['valid_global_steps'] = global_steps + 1 324 | 325 | return fid_score -------------------------------------------------------------------------------- /utils/utils.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-25 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | import collections 8 | import logging 9 | import math 10 | import os 11 | import time 12 | from datetime import datetime 13 | 14 | import dateutil.tz 15 | import torch 16 | 17 | from typing import Union, Optional, List, Tuple, Text, BinaryIO 18 | import pathlib 19 | import torch 20 | import math 21 | import warnings 22 | import numpy as np 23 | from PIL import Image, ImageDraw, ImageFont, ImageColor 24 | 25 | @torch.no_grad() 26 | def make_grid( 27 | tensor: Union[torch.Tensor, List[torch.Tensor]], 28 | nrow: int = 8, 29 | padding: int = 2, 30 | normalize: bool = False, 31 | value_range: Optional[Tuple[int, int]] = None, 32 | scale_each: bool = False, 33 | pad_value: int = 0, 34 | **kwargs 35 | ) -> torch.Tensor: 36 | """ 37 | Make a grid of images. 38 | Args: 39 | tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) 40 | or a list of images all of the same size. 41 | nrow (int, optional): Number of images displayed in each row of the grid. 42 | The final grid size is ``(B / nrow, nrow)``. Default: ``8``. 43 | padding (int, optional): amount of padding. Default: ``2``. 44 | normalize (bool, optional): If True, shift the image to the range (0, 1), 45 | by the min and max values specified by :attr:`range`. Default: ``False``. 46 | value_range (tuple, optional): tuple (min, max) where min and max are numbers, 47 | then these numbers are used to normalize the image. By default, min and max 48 | are computed from the tensor. 49 | scale_each (bool, optional): If ``True``, scale each image in the batch of 50 | images separately rather than the (min, max) over all images. Default: ``False``. 51 | pad_value (float, optional): Value for the padded pixels. Default: ``0``. 52 | Returns: 53 | grid (Tensor): the tensor containing grid of images. 54 | Example: 55 | See this notebook 56 | `here `_ 57 | """ 58 | if not (torch.is_tensor(tensor) or 59 | (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if "range" in kwargs.keys(): 63 | warning = "range will be deprecated, please use value_range instead." 64 | warnings.warn(warning) 65 | value_range = kwargs["range"] 66 | 67 | # if list of tensors, convert to a 4D mini-batch Tensor 68 | if isinstance(tensor, list): 69 | tensor = torch.stack(tensor, dim=0) 70 | 71 | if tensor.dim() == 2: # single image H x W 72 | tensor = tensor.unsqueeze(0) 73 | if tensor.dim() == 3: # single image 74 | if tensor.size(0) == 1: # if single-channel, convert to 3-channel 75 | tensor = torch.cat((tensor, tensor, tensor), 0) 76 | tensor = tensor.unsqueeze(0) 77 | 78 | if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images 79 | tensor = torch.cat((tensor, tensor, tensor), 1) 80 | 81 | if normalize is True: 82 | tensor = tensor.clone() # avoid modifying tensor in-place 83 | if value_range is not None: 84 | assert isinstance(value_range, tuple), \ 85 | "value_range has to be a tuple (min, max) if specified. 
min and max are numbers" 86 | 87 | def norm_ip(img, low, high): 88 | img.clamp_(min=low, max=high) 89 | img.sub_(low).div_(max(high - low, 1e-5)) 90 | 91 | def norm_range(t, value_range): 92 | if value_range is not None: 93 | norm_ip(t, value_range[0], value_range[1]) 94 | else: 95 | norm_ip(t, float(t.min()), float(t.max())) 96 | 97 | if scale_each is True: 98 | for t in tensor: # loop over mini-batch dimension 99 | norm_range(t, value_range) 100 | else: 101 | norm_range(tensor, value_range) 102 | 103 | if tensor.size(0) == 1: 104 | return tensor.squeeze(0) 105 | 106 | # make the mini-batch of images into a grid 107 | nmaps = tensor.size(0) 108 | xmaps = min(nrow, nmaps) 109 | ymaps = int(math.ceil(float(nmaps) / xmaps)) 110 | height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) 111 | num_channels = tensor.size(1) 112 | grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) 113 | k = 0 114 | for y in range(ymaps): 115 | for x in range(xmaps): 116 | if k >= nmaps: 117 | break 118 | # Tensor.copy_() is a valid method but seems to be missing from the stubs 119 | # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ 120 | grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] 121 | 2, x * width + padding, width - padding 122 | ).copy_(tensor[k]) 123 | k = k + 1 124 | return grid 125 | 126 | 127 | @torch.no_grad() 128 | def save_image( 129 | tensor: Union[torch.Tensor, List[torch.Tensor]], 130 | fp: Union[Text, pathlib.Path, BinaryIO], 131 | format: Optional[str] = None, 132 | **kwargs 133 | ) -> None: 134 | """ 135 | Save a given Tensor into an image file. 136 | Args: 137 | tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, 138 | saves the tensor as a grid of images by calling ``make_grid``. 139 | fp (string or file object): A filename or a file object 140 | format (Optional): If omitted, the format to use is determined from the filename extension. 141 | If a file object was used instead of a filename, this parameter should always be used. 142 | **kwargs: Other arguments are documented in ``make_grid``. 
143 | """ 144 | 145 | grid = make_grid(tensor, **kwargs) 146 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 147 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 148 | im = Image.fromarray(ndarr) 149 | im.save(fp, format=format) 150 | 151 | 152 | def create_logger(log_dir, phase='train'): 153 | time_str = time.strftime('%Y-%m-%d-%H-%M') 154 | log_file = '{}_{}.log'.format(time_str, phase) 155 | final_log_file = os.path.join(log_dir, log_file) 156 | head = '%(asctime)-15s %(message)s' 157 | logging.basicConfig(filename=str(final_log_file), 158 | format=head) 159 | logger = logging.getLogger() 160 | logger.setLevel(logging.INFO) 161 | console = logging.StreamHandler() 162 | logging.getLogger('').addHandler(console) 163 | 164 | return logger 165 | 166 | 167 | def set_log_dir(root_dir, exp_name): 168 | path_dict = {} 169 | os.makedirs(root_dir, exist_ok=True) 170 | 171 | # set log path 172 | exp_path = os.path.join(root_dir, exp_name) 173 | now = datetime.now(dateutil.tz.tzlocal()) 174 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 175 | prefix = exp_path + '_' + timestamp 176 | os.makedirs(prefix) 177 | path_dict['prefix'] = prefix 178 | 179 | # set checkpoint path 180 | ckpt_path = os.path.join(prefix, 'Model') 181 | os.makedirs(ckpt_path) 182 | path_dict['ckpt_path'] = ckpt_path 183 | 184 | log_path = os.path.join(prefix, 'Log') 185 | os.makedirs(log_path) 186 | path_dict['log_path'] = log_path 187 | 188 | # set sample image path for fid calculation 189 | sample_path = os.path.join(prefix, 'Samples') 190 | os.makedirs(sample_path) 191 | path_dict['sample_path'] = sample_path 192 | 193 | return path_dict 194 | 195 | 196 | def save_checkpoint(states, is_best, output_dir, 197 | filename='checkpoint.pth'): 198 | torch.save(states, os.path.join(output_dir, filename)) 199 | if is_best: 200 | torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth')) 201 | 202 | 203 | class RunningStats: 204 | def __init__(self, WIN_SIZE): 205 | self.mean = 0 206 | self.run_var = 0 207 | self.WIN_SIZE = WIN_SIZE 208 | 209 | self.window = collections.deque(maxlen=WIN_SIZE) 210 | 211 | def clear(self): 212 | self.window.clear() 213 | self.mean = 0 214 | self.run_var = 0 215 | 216 | def is_full(self): 217 | return len(self.window) == self.WIN_SIZE 218 | 219 | def push(self, x): 220 | 221 | if len(self.window) == self.WIN_SIZE: 222 | # Adjusting variance 223 | x_removed = self.window.popleft() 224 | self.window.append(x) 225 | old_m = self.mean 226 | self.mean += (x - x_removed) / self.WIN_SIZE 227 | self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed) 228 | else: 229 | # Calculating first variance 230 | self.window.append(x) 231 | delta = x - self.mean 232 | self.mean += delta / len(self.window) 233 | self.run_var += delta * (x - self.mean) 234 | 235 | def get_mean(self): 236 | return self.mean if len(self.window) else 0.0 237 | 238 | def get_var(self): 239 | return self.run_var / len(self.window) if len(self.window) > 1 else 0.0 240 | 241 | def get_std(self): 242 | return math.sqrt(self.get_var()) 243 | 244 | def get_all(self): 245 | return list(self.window) 246 | 247 | def __str__(self): 248 | return "Current window values: {}".format(list(self.window)) 249 | -------------------------------------------------------------------------------- /visualizationMetrics.py: -------------------------------------------------------------------------------- 1 | """Time-series Generative Adversarial Networks (TimeGAN) Codebase. 
2 | Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar, 3 | "Time-series Generative Adversarial Networks," 4 | Neural Information Processing Systems (NeurIPS), 2019. 5 | Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks 6 | Last updated Date: April 24th 2020 7 | Code author: Jinsung Yoon (jsyoon0823@gmail.com) 8 | ----------------------------- 9 | visualization_metrics.py 10 | Note: Use PCA or tSNE for generated and original data visualization 11 | """ 12 | 13 | # Necessary packages 14 | from sklearn.manifold import TSNE 15 | from sklearn.decomposition import PCA 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | 19 | 20 | def visualization(ori_data, generated_data, analysis, save_name): 21 | """Using PCA or tSNE for generated and original data visualization. 22 | 23 | Args: 24 | - ori_data: original data 25 | - generated_data: generated synthetic data 26 | - analysis: tsne or pca - save_name: base name of the figure saved as ./images/<save_name>.pdf 27 | """ 28 | # Analysis sample size (for faster computation) 29 | anal_sample_no = min([1000, len(ori_data)]) 30 | idx = np.random.permutation(len(ori_data))[:anal_sample_no] 31 | 32 | # Data preprocessing 33 | ori_data = np.asarray(ori_data) 34 | generated_data = np.asarray(generated_data) 35 | 36 | ori_data = ori_data[idx] 37 | generated_data = generated_data[idx] 38 | 39 | no, seq_len, dim = ori_data.shape 40 | 41 | for i in range(anal_sample_no): 42 | if (i == 0): 43 | prep_data = np.reshape(np.mean(ori_data[0,:,:], 1), [1,seq_len]) 44 | prep_data_hat = np.reshape(np.mean(generated_data[0,:,:],1), [1,seq_len]) 45 | else: 46 | prep_data = np.concatenate((prep_data, 47 | np.reshape(np.mean(ori_data[i,:,:],1), [1,seq_len]))) 48 | prep_data_hat = np.concatenate((prep_data_hat, 49 | np.reshape(np.mean(generated_data[i,:,:],1), [1,seq_len]))) 50 | 51 | # Visualization parameter 52 | colors = ["red" for i in range(anal_sample_no)] + ["blue" for i in range(anal_sample_no)] 53 | 54 | if analysis == 'pca': 55 | # PCA Analysis 56 | pca = PCA(n_components = 2) 57 | pca.fit(prep_data) 58 | pca_results = pca.transform(prep_data) 59 | pca_hat_results = pca.transform(prep_data_hat) 60 | 61 | # Plotting 62 | f, ax = plt.subplots(1) 63 | plt.scatter(pca_results[:,0], pca_results[:,1], 64 | c = colors[:anal_sample_no], alpha = 0.2, label = "Original") 65 | plt.scatter(pca_hat_results[:,0], pca_hat_results[:,1], 66 | c = colors[anal_sample_no:], alpha = 0.2, label = "Synthetic") 67 | 68 | ax.legend() 69 | plt.title('PCA plot') 70 | plt.xlabel('x-pca') 71 | plt.ylabel('y-pca') 72 | # plt.show() 73 | 74 | elif analysis == 'tsne': 75 | 76 | # Do t-SNE Analysis together 77 | prep_data_final = np.concatenate((prep_data, prep_data_hat), axis = 0) 78 | 79 | # t-SNE analysis 80 | tsne = TSNE(n_components = 2, verbose = 1, perplexity = 40, n_iter = 300) 81 | tsne_results = tsne.fit_transform(prep_data_final) 82 | 83 | # Plotting 84 | f, ax = plt.subplots(1) 85 | 86 | plt.scatter(tsne_results[:anal_sample_no,0], tsne_results[:anal_sample_no,1], 87 | c = colors[:anal_sample_no], alpha = 0.2, label = "Original") 88 | plt.scatter(tsne_results[anal_sample_no:,0], tsne_results[anal_sample_no:,1], 89 | c = colors[anal_sample_no:], alpha = 0.2, label = "Synthetic") 90 | 91 | ax.legend() 92 | 93 | plt.title('t-SNE plot') 94 | plt.xlabel('x-tsne') 95 | plt.ylabel('y-tsne') 96 | # plt.show() 97 | 98 | plt.savefig(f'./images/{save_name}.pdf', format="pdf") 99 | plt.show() --------------------------------------------------------------------------------
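
Usage note: a minimal sketch of how the visualization() function above can be called. The arrays, sizes, and save names below are illustrative placeholders, not values taken from this repository's training scripts; visualization() expects inputs shaped (num_samples, seq_len, num_features) and writes the figure to ./images/<save_name>.pdf, so that directory must exist.

import os
import numpy as np
from visualizationMetrics import visualization

# visualization() saves figures under ./images/<save_name>.pdf
os.makedirs('./images', exist_ok=True)

# Placeholder data: replace with real sequences and generator output,
# both shaped (num_samples, seq_len, num_features).
real_data = np.random.randn(500, 150, 3)
synthetic_data = np.random.randn(500, 150, 3)

visualization(real_data, synthetic_data, 'pca', 'example_pca')    # PCA projection
visualization(real_data, synthetic_data, 'tsne', 'example_tsne')  # t-SNE projection (slower)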