├── .gitignore
├── GANModels.py
├── JumpingGAN_Train.py
├── LICENSE
├── LoadRealRunningJumping.py
├── LoadSyntheticRunningJumping.py
├── README.md
├── Running&JumpingVisualization.ipynb
├── RunningGAN_Train.py
├── adamw.py
├── cfg.py
├── dataLoader.py
├── functions.py
├── images
│   ├── PositionalEncoding.pdf
│   ├── PositionalEncoding.png
│   ├── TTS-GAN.pdf
│   └── TTS-GAN.png
├── pre-trained-models
│   ├── JumpingGAN_checkpoint
│   └── RunningGAN_checkpoint
├── train_GAN.py
├── utils
│   ├── __init__.py
│   ├── cal_fid_stat.py
│   ├── fid_score.py
│   ├── inception.py
│   ├── inception_model.py
│   ├── inception_score.py
│   ├── torch_fid_score.py
│   └── utils.py
└── visualizationMetrics.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/GANModels.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 | import math
6 | import numpy as np
7 |
8 | from torchvision.transforms import Compose, Resize, ToTensor
9 | from einops import rearrange, reduce, repeat
10 | from einops.layers.torch import Rearrange, Reduce
11 | from torchsummary import summary
12 |
13 |
14 | class Generator(nn.Module):
15 | def __init__(self, seq_len=150, patch_size=15, channels=3, num_classes=9, latent_dim=100, embed_dim=10, depth=3,
16 | num_heads=5, forward_drop_rate=0.5, attn_drop_rate=0.5):
17 | super(Generator, self).__init__()
18 | self.channels = channels
19 | self.latent_dim = latent_dim
20 | self.seq_len = seq_len
21 | self.embed_dim = embed_dim
22 | self.patch_size = patch_size
23 | self.depth = depth
24 | self.attn_drop_rate = attn_drop_rate
25 | self.forward_drop_rate = forward_drop_rate
26 |
27 | self.l1 = nn.Linear(self.latent_dim, self.seq_len * self.embed_dim)
28 | self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, self.embed_dim))
29 | self.blocks = Gen_TransformerEncoder(
30 | depth=self.depth,
31 | emb_size = self.embed_dim,
32 | drop_p = self.attn_drop_rate,
33 | forward_drop_p=self.forward_drop_rate
34 | )
35 |
36 | self.deconv = nn.Sequential(
37 | nn.Conv2d(self.embed_dim, self.channels, 1, 1, 0)
38 | )
39 |
40 | def forward(self, z):
41 | x = self.l1(z).view(-1, self.seq_len, self.embed_dim)
42 | x = x + self.pos_embed
43 | H, W = 1, self.seq_len
44 | x = self.blocks(x)
45 | x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
46 | output = self.deconv(x.permute(0, 3, 1, 2))
47 | output = output.view(-1, self.channels, H, W)
48 | return output
49 |
50 |
51 | class Gen_TransformerEncoderBlock(nn.Sequential):
52 | def __init__(self,
53 | emb_size,
54 | num_heads=5,
55 | drop_p=0.5,
56 | forward_expansion=4,
57 | forward_drop_p=0.5):
58 | super().__init__(
59 | ResidualAdd(nn.Sequential(
60 | nn.LayerNorm(emb_size),
61 | MultiHeadAttention(emb_size, num_heads, drop_p),
62 | nn.Dropout(drop_p)
63 | )),
64 | ResidualAdd(nn.Sequential(
65 | nn.LayerNorm(emb_size),
66 | FeedForwardBlock(
67 | emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
68 | nn.Dropout(drop_p)
69 | )
70 | ))
71 |
72 |
73 | class Gen_TransformerEncoder(nn.Sequential):
74 | def __init__(self, depth=8, **kwargs):
75 | super().__init__(*[Gen_TransformerEncoderBlock(**kwargs) for _ in range(depth)])
76 |
77 |
78 | class MultiHeadAttention(nn.Module):
79 | def __init__(self, emb_size, num_heads, dropout):
80 | super().__init__()
81 | self.emb_size = emb_size
82 | self.num_heads = num_heads
83 | self.keys = nn.Linear(emb_size, emb_size)
84 | self.queries = nn.Linear(emb_size, emb_size)
85 | self.values = nn.Linear(emb_size, emb_size)
86 | self.att_drop = nn.Dropout(dropout)
87 | self.projection = nn.Linear(emb_size, emb_size)
88 |
89 | def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
90 | queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
91 | keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
92 | values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
93 | energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
94 | if mask is not None:
95 | fill_value = torch.finfo(torch.float32).min
 96 |             energy = energy.masked_fill(~mask, fill_value)
97 |
98 | scaling = self.emb_size ** (1 / 2)
99 | att = F.softmax(energy / scaling, dim=-1)
100 | att = self.att_drop(att)
101 | out = torch.einsum('bhal, bhlv -> bhav ', att, values)
102 | out = rearrange(out, "b h n d -> b n (h d)")
103 | out = self.projection(out)
104 | return out
105 |
106 |
107 | class ResidualAdd(nn.Module):
108 | def __init__(self, fn):
109 | super().__init__()
110 | self.fn = fn
111 |
112 | def forward(self, x, **kwargs):
113 | res = x
114 | x = self.fn(x, **kwargs)
115 | x += res
116 | return x
117 |
118 |
119 | class FeedForwardBlock(nn.Sequential):
120 | def __init__(self, emb_size, expansion, drop_p):
121 | super().__init__(
122 | nn.Linear(emb_size, expansion * emb_size),
123 | nn.GELU(),
124 | nn.Dropout(drop_p),
125 | nn.Linear(expansion * emb_size, emb_size),
126 | )
127 |
128 |
129 |
130 | class Dis_TransformerEncoderBlock(nn.Sequential):
131 | def __init__(self,
132 | emb_size=100,
133 | num_heads=5,
134 | drop_p=0.,
135 | forward_expansion=4,
136 | forward_drop_p=0.):
137 | super().__init__(
138 | ResidualAdd(nn.Sequential(
139 | nn.LayerNorm(emb_size),
140 | MultiHeadAttention(emb_size, num_heads, drop_p),
141 | nn.Dropout(drop_p)
142 | )),
143 | ResidualAdd(nn.Sequential(
144 | nn.LayerNorm(emb_size),
145 | FeedForwardBlock(
146 | emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
147 | nn.Dropout(drop_p)
148 | )
149 | ))
150 |
151 |
152 | class Dis_TransformerEncoder(nn.Sequential):
153 | def __init__(self, depth=8, **kwargs):
154 | super().__init__(*[Dis_TransformerEncoderBlock(**kwargs) for _ in range(depth)])
155 |
156 |
157 | class ClassificationHead(nn.Sequential):
158 | def __init__(self, emb_size=100, n_classes=2):
159 | super().__init__()
160 | self.clshead = nn.Sequential(
161 | Reduce('b n e -> b e', reduction='mean'),
162 | nn.LayerNorm(emb_size),
163 | nn.Linear(emb_size, n_classes)
164 | )
165 |
166 | def forward(self, x):
167 | out = self.clshead(x)
168 | return out
169 |
170 |
171 | class PatchEmbedding_Linear(nn.Module):
172 | #what are the proper parameters set here?
173 | def __init__(self, in_channels = 21, patch_size = 16, emb_size = 100, seq_length = 1024):
174 | # self.patch_size = patch_size
175 | super().__init__()
176 | #change the conv2d parameters here
177 | self.projection = nn.Sequential(
178 | Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)',s1 = 1, s2 = patch_size),
179 | nn.Linear(patch_size*in_channels, emb_size)
180 | )
181 | self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
182 | self.positions = nn.Parameter(torch.randn((seq_length // patch_size) + 1, emb_size))
183 |
184 | def forward(self, x: Tensor) -> Tensor:
185 | b, _, _, _ = x.shape
186 | x = self.projection(x)
187 | cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
188 | #prepend the cls token to the input
189 | x = torch.cat([cls_tokens, x], dim=1)
190 | # position
191 | x += self.positions
192 | return x
193 |
194 |
195 | class Discriminator(nn.Sequential):
196 | def __init__(self,
197 | in_channels=3,
198 | patch_size=15,
199 | emb_size=50,
200 | seq_length = 150,
201 | depth=3,
202 | n_classes=1,
203 | **kwargs):
204 | super().__init__(
205 | PatchEmbedding_Linear(in_channels, patch_size, emb_size, seq_length),
206 | Dis_TransformerEncoder(depth, emb_size=emb_size, drop_p=0.5, forward_drop_p=0.5, **kwargs),
207 | ClassificationHead(emb_size, n_classes)
208 | )
209 |
210 |
--------------------------------------------------------------------------------
/JumpingGAN_Train.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
2 |
3 | import os
4 | import argparse
5 |
6 | def parse_args():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--rank', type=str, default="0")
9 | parser.add_argument('--node', type=str, default="0015")
10 | opt = parser.parse_args()
11 |
12 | return opt
13 | args = parse_args()
14 |
15 | os.system(f"CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_GAN.py \
16 | -gen_bs 16 \
17 | -dis_bs 16 \
18 | --dist-url 'tcp://localhost:4321' \
19 | --dist-backend 'nccl' \
20 | --world-size 1 \
21 | --rank {args.rank} \
22 | --dataset UniMiB \
23 | --bottom_width 8 \
24 | --max_iter 500000 \
25 | --img_size 32 \
26 | --gen_model my_gen \
27 | --dis_model my_dis \
28 | --df_dim 384 \
29 | --d_heads 4 \
30 | --d_depth 3 \
31 | --g_depth 5,4,2 \
32 | --dropout 0 \
33 | --latent_dim 100 \
34 | --gf_dim 1024 \
35 | --num_workers 16 \
36 | --g_lr 0.0001 \
37 | --d_lr 0.0003 \
38 | --optimizer adam \
39 | --loss lsgan \
40 | --wd 1e-3 \
41 | --beta1 0.9 \
42 | --beta2 0.999 \
43 | --phi 1 \
44 | --batch_size 16 \
45 | --num_eval_imgs 50000 \
46 | --init_type xavier_uniform \
47 | --n_critic 1 \
48 | --val_freq 20 \
49 | --print_freq 50 \
50 | --grow_steps 0 0 \
51 | --fade_in 0 \
52 | --patch_size 2 \
53 | --ema_kimg 500 \
54 | --ema_warmup 0.1 \
55 | --ema 0.9999 \
56 | --diff_aug translation,cutout,color \
57 | --class_name Jumping \
58 | --exp_name Jumping")
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/LoadRealRunningJumping.py:
--------------------------------------------------------------------------------
1 | #A binary classification dataset, Jumping or Running
2 |
3 |
4 | import os
5 | import shutil #https://docs.python.org/3/library/shutil.html
6 | from shutil import unpack_archive # to unzip
7 | #from shutil import make_archive # to create zip for storage
8 | import requests #for downloading zip file
9 | from scipy import io #for loadmat, matlab conversion
10 | import pandas as pd
11 | import numpy as np
12 | #import matplotlib.pyplot as plt # for plotting - pandas uses matplotlib
13 | from tabulate import tabulate # for verbose tables
14 | #from tensorflow.keras.utils import to_categorical # for one-hot encoding
15 |
16 | #credit https://stackoverflow.com/questions/9419162/download-returned-zip-file-from-url
17 | #many other methods I tried failed to download the file properly
18 | from torch.utils.data import Dataset, DataLoader
19 |
20 | class_dict = {'StandingUpFS':0,'StandingUpFL':1,'Walking':2,'Running':3,'GoingUpS':4,'Jumping':5,'GoingDownS':6,'LyingDownFS':7,'SittingDown':8}
21 |
22 | class Running_Or_Jumping(Dataset):
23 | def __init__(self,
24 | incl_xyz_accel = False, #include component accel_x/y/z in ____X data
25 | incl_rms_accel = True, #add rms value (total accel) of accel_x/y/z in ____X data
26 | is_normalize = False,
27 | split_subj = dict
28 | (train_subj = [4,5,6,7,8,10,11,12,14,15,19,20,21,22,24,26,27,29,1,9,16,23,25,28],
29 | test_subj = [2,3,13,17,18,30]),
30 | data_mode = 'Train'):
31 |
32 | self.incl_xyz_accel = incl_xyz_accel
33 | self.incl_rms_accel = incl_rms_accel
34 | self.split_subj = split_subj
35 | self.data_mode = data_mode
36 | self.is_normalize = is_normalize
37 |
38 | #Download and unzip original dataset
39 | if (not os.path.isfile('./UniMiB-SHAR.zip')):
40 | print("Downloading UniMiB-SHAR.zip file")
41 | #invoking the shell command fails when exported to .py file
42 | #redirect link https://www.dropbox.com/s/raw/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip
43 | #!wget https://www.dropbox.com/s/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip
44 | self.download_url('https://www.dropbox.com/s/raw/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip','./UniMiB-SHAR.zip')
45 | if (not os.path.isdir('./UniMiB-SHAR')):
46 | shutil.unpack_archive('./UniMiB-SHAR.zip','.','zip')
47 | #Convert .mat files to numpy ndarrays
48 | path_in = './UniMiB-SHAR/data'
49 | #loadmat loads matlab files as dictionary, keys: header, version, globals, data
50 | adl_data = io.loadmat(path_in + '/adl_data.mat')['adl_data']
51 | adl_names = io.loadmat(path_in + '/adl_names.mat', chars_as_strings=True)['adl_names']
52 | adl_labels = io.loadmat(path_in + '/adl_labels.mat')['adl_labels']
53 |
54 | #Reshape data and compute total (rms) acceleration
55 | num_samples = 151
56 | #UniMiB SHAR has fixed size of 453 which is 151 accelX, 151 accely, 151 accelz
57 | adl_data = np.reshape(adl_data,(-1,num_samples,3), order='F') #uses Fortran order
58 | if (self.incl_rms_accel):
59 | rms_accel = np.sqrt((adl_data[:,:,0]**2) + (adl_data[:,:,1]**2) + (adl_data[:,:,2]**2))
60 | adl_data = np.dstack((adl_data,rms_accel))
61 | #remove component accel if needed
62 | if (not self.incl_xyz_accel):
63 | adl_data = np.delete(adl_data, [0,1,2], 2)
64 |
65 | #Split train/test sets, combine or make separate validation set
66 | #ref for this numpy gymnastics - find index of matching subject to sub_train/sub_test/sub_validate
67 | #https://numpy.org/doc/stable/reference/generated/numpy.isin.html
68 |
69 |
70 | act_num = (adl_labels[:,0])-1 #matlab source was 1 indexed, change to 0 indexed
71 | sub_num = (adl_labels[:,1]) #subject numbers are in column 1 of labels
72 |
73 |
74 | train_index = np.nonzero(np.isin(sub_num, self.split_subj['train_subj']))
75 | x_train = adl_data[train_index]
76 | y_train = act_num[train_index]
77 |
78 | test_index = np.nonzero(np.isin(sub_num, self.split_subj['test_subj']))
79 | x_test = adl_data[test_index]
80 | y_test = act_num[test_index]
81 |
82 | self.x_train = np.transpose(x_train, (0, 2, 1))
83 | self.x_train = self.x_train.reshape(self.x_train.shape[0], self.x_train.shape[1], 1, self.x_train.shape[2])
84 | self.x_train = self.x_train[:,:,:,:-1]
85 | self.y_train = y_train
86 |
87 | self.x_test = np.transpose(x_test, (0, 2, 1))
88 | self.x_test = self.x_test.reshape(self.x_test.shape[0], self.x_test.shape[1], 1, self.x_test.shape[2])
89 | self.x_test = self.x_test[:,:,:,:-1]
90 | self.y_test = y_test
91 |
92 | if self.is_normalize:
93 | self.x_train = self.normalization(self.x_train)
94 | self.x_test = self.normalization(self.x_test)
95 |
96 | #Select running and jumping data
97 | #Label running as 0 and jumping as 1
98 |
99 | Jumping_train_data = []
100 | Running_train_data = []
101 | Jumping_test_data = []
102 | Running_test_data = []
103 |
104 |
105 | for i, label in enumerate(y_train):
106 | if label == class_dict['Running']:
107 | Running_train_data.append(self.x_train[i])
108 | elif label == class_dict['Jumping']:
109 | Jumping_train_data.append(self.x_train[i])
110 | else:
111 | continue
112 |
113 | for i, label in enumerate(y_test):
114 | if label == class_dict['Running']:
115 | Running_test_data.append(self.x_test[i])
116 | elif label == class_dict['Jumping']:
117 | Jumping_test_data.append(self.x_test[i])
118 | else:
119 | continue
120 |
121 | self.Jumping_train_labels = np.ones(len(Jumping_train_data))
122 | self.Jumping_test_labels = np.ones(len(Jumping_test_data))
123 | self.Running_train_labels = np.zeros(len(Running_train_data))
124 | self.Running_test_labels = np.zeros(len(Running_test_data))
125 |
126 | self.Jumping_train_data = np.array(Jumping_train_data)
127 | self.Running_train_data = np.array(Running_train_data)
128 | self.Jumping_test_data = np.array(Jumping_test_data)
129 | self.Running_test_data = np.array(Running_test_data)
130 |
131 |
132 | #Crop Running to only 600 samples
133 |         self.Running_train_data = self.Running_train_data[:600]
134 | self.Running_train_labels = self.Running_train_labels[:600]
135 |
136 |         self.Running_test_data = self.Running_test_data[:146]
137 | self.Running_test_labels = self.Running_test_labels[:146]
138 |
139 | self.combined_train_data = np.concatenate((self.Jumping_train_data, self.Running_train_data), axis=0)
140 | self.combined_test_data = np.concatenate((self.Jumping_test_data, self.Running_test_data), axis=0)
141 |
142 | self.combined_train_label = np.concatenate((self.Jumping_train_labels, self.Running_train_labels), axis=0)
143 | self.combined_train_label = self.combined_train_label.reshape(self.combined_train_label.shape[0], 1)
144 |
145 | self.combined_test_label = np.concatenate((self.Jumping_test_labels, self.Running_test_labels), axis=0)
146 | self.combined_test_label = self.combined_test_label.reshape(self.combined_test_label.shape[0], 1)
147 |
148 | if self.data_mode == 'Train':
149 | print(f'data shape is {self.combined_train_data.shape}, label shape is {self.combined_train_label.shape}')
150 | print(f'Jumping label is 1, has {len(self.Jumping_train_labels)} samples, Running label is 0, has {len(self.Running_train_labels)} samples')
151 | else:
152 | print(f'data shape is {self.combined_test_data.shape}, label shape is {self.combined_test_label.shape}')
153 | print(f'Jumping label is 1, has {len(self.Jumping_test_labels)} samples, Running label is 0, has {len(self.Running_test_labels)} samples')
154 |
155 |
156 | def download_url(self, url, save_path, chunk_size=128):
157 | r = requests.get(url, stream=True)
158 | with open(save_path, 'wb') as fd:
159 | for chunk in r.iter_content(chunk_size=chunk_size):
160 | fd.write(chunk)
161 |
162 | def to_categorical(self, y, num_classes):
163 | """ 1-hot encodes a tensor """
164 | return np.eye(num_classes, dtype='uint8')[y]
165 |
166 |
167 | def _normalize(self, epoch):
168 | """ A helper method for the normalization method.
169 | Returns
170 | result: a normalized epoch
171 | """
172 | e = 1e-10
173 | result = (epoch - epoch.mean(axis=0)) / ((np.sqrt(epoch.var(axis=0)))+e)
174 | return result
175 |
176 | def _min_max_normalize(self, epoch):
177 |
178 | result = (epoch - min(epoch)) / (max(epoch) - min(epoch))
179 | return result
180 |
181 | def normalization(self, epochs):
182 |         """ Normalizes each epoch e s.t. mean(e) = 0 and var(e) = 1
183 |         Args:
184 |             epochs - Numpy structure of epochs
185 |         Returns:
186 |             epochs_n - Numpy structure of normalized epochs (mean=0, var=1)
187 | """
188 | for i in range(epochs.shape[0]):
189 | for j in range(epochs.shape[1]):
190 | epochs[i,j,0,:] = self._normalize(epochs[i,j,0,:])
191 | # epochs[i,j,0,:] = self._min_max_normalize(epochs[i,j,0,:])
192 |
193 | return epochs
194 |
195 |
196 | def __len__(self):
197 |
198 | if self.data_mode == 'Train':
199 | return len(self.combined_train_label)
200 | else:
201 | return len(self.combined_test_label)
202 |
203 | def __getitem__(self, idx):
204 |
205 | if self.data_mode == 'Train':
206 | return self.combined_train_data[idx], self.combined_train_label[idx]
207 | else:
208 | return self.combined_test_data[idx], self.combined_test_label[idx]
209 |
210 | def collate_fn(self):
211 | pass
212 |
213 |
--------------------------------------------------------------------------------
/LoadSyntheticRunningJumping.py:
--------------------------------------------------------------------------------
 1 | # Generate synthetic Running and Jumping data
 2 | # Wrap them in a PyTorch Dataset
3 |
4 | from torch.utils.data import Dataset, DataLoader
5 | import torch
6 | from GANModels import *
7 | import numpy as np
8 | import os
9 |
10 | class Synthetic_Dataset(Dataset):
11 | def __init__(self,
12 | Jumping_model_path = './pre-trained-models/JumpingGAN_checkpoint',
13 | Running_model_path = './pre-trained-models/RunningGAN_checkpoint',
14 | sample_size = 1000
15 | ):
16 |
17 | self.sample_size = sample_size
18 |
19 | #Generate Running Data
20 | running_gen_net = Generator(seq_len=150, channels=3, latent_dim=100)
21 | running_ckp = torch.load(Running_model_path)
22 | running_gen_net.load_state_dict(running_ckp['gen_state_dict'])
23 |
24 | #Generate Jumping Data
25 | jumping_gen_net = Generator(seq_len=150, channels=3, latent_dim=100)
26 | jumping_ckp = torch.load(Jumping_model_path)
27 | jumping_gen_net.load_state_dict(jumping_ckp['gen_state_dict'])
28 |
29 |
30 | #generate synthetic running data label is 0
31 | z = torch.FloatTensor(np.random.normal(0, 1, (self.sample_size, 100)))
32 | self.syn_running = running_gen_net(z)
33 | self.syn_running = self.syn_running.detach().numpy()
34 | self.running_label = np.zeros(len(self.syn_running))
35 |
36 | #generate synthetic jumping data label is 1
37 | z = torch.FloatTensor(np.random.normal(0, 1, (self.sample_size, 100)))
38 | self.syn_jumping = jumping_gen_net(z)
39 | self.syn_jumping = self.syn_jumping.detach().numpy()
40 | self.jumping_label = np.ones(len(self.syn_jumping))
41 |
42 | self.combined_train_data = np.concatenate((self.syn_running, self.syn_jumping), axis=0)
43 | self.combined_train_label = np.concatenate((self.running_label, self.jumping_label), axis=0)
44 | self.combined_train_label = self.combined_train_label.reshape(self.combined_train_label.shape[0], 1)
45 |
46 | print(self.combined_train_data.shape)
47 | print(self.combined_train_label.shape)
48 |
49 |
50 | def __len__(self):
51 | return self.sample_size * 2
52 |
53 | def __getitem__(self, idx):
54 | return self.combined_train_data[idx], self.combined_train_label[idx]
55 |
56 |
57 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network
2 | ---
3 |
4 | This repository contains code from the paper "TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network".
5 |
 6 | The paper has been accepted for publication at the 20th International Conference on Artificial Intelligence in Medicine (AIME 2022).
7 |
8 | Please find the paper [here](https://arxiv.org/abs/2202.02691)
9 |
10 | ---
11 |
12 | **Abstract:**
13 | Time-series datasets used in machine learning applications often are small in size, making the training of deep neural network architectures ineffective. For time series, the suite of data augmentation tricks we can use to expand the size of the dataset is limited by the need to maintain the basic properties of the signal. Data generated by a Generative Adversarial Network (GAN) can be utilized as another data augmentation tool. RNN-based GANs suffer from the fact that they cannot effectively model long sequences of data points with irregular temporal relations. To tackle these problems, we introduce TTS-GAN, a transformer-based GAN which can successfully generate realistic synthetic time series data sequences of arbitrary length, similar to the original ones. Both the generator and discriminator networks of the GAN model are built using a pure transformer encoder architecture. We use visualizations to demonstrate the similarity of real and generated time series and a simple classification task that shows how we can use synthetically generated data to augment real data and improve classification accuracy.
14 |
15 | ---
16 |
17 | **Key Idea:**
18 |
19 | A transformer-based GAN that generates synthetic time-series data
20 |
21 | **The TTS-GAN Architecture**
22 |
23 | 
24 |
25 | The TTS-GAN model architecture is shown in the figure above. It contains two main parts, a generator and a discriminator, both built on the transformer encoder architecture. An encoder is a composition of two compound blocks: the first is a multi-head self-attention module, and the second is a feed-forward MLP with a GELU activation function. Layer normalization is applied before each block, dropout is added after each block, and both blocks employ residual connections.
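
As a concrete sketch of how these two parts are instantiated, the snippet below builds the generator and discriminator with the default sizes used in `GANModels.py` (150 time-steps, 3 channels, a 100-dimensional latent space); the batch size of 16 is arbitrary, and the shapes in the comments follow from that code:

```
import torch
from GANModels import Generator, Discriminator

gen = Generator(seq_len=150, channels=3, latent_dim=100)
dis = Discriminator(in_channels=3, patch_size=15, seq_length=150)

z = torch.randn(16, 100)   # a batch of 16 latent vectors
fake = gen(z)              # (16, 3, 1, 150): (Batch Size, C, 1, W)
score = dis(fake)          # (16, 1): one real/fake logit per sequence
```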
26 |
27 |
28 | **The time series data processing step**
29 |
30 | 
31 |
32 | We view a time-series data sequence as an image with a height of 1. The number of time-steps is the width of the image, *W*. A time-series sequence can have a single channel or multiple channels, which correspond to the channels (RGB) of an image, *C*. So an input sequence can be represented as a matrix of size *(Batch Size, C, 1, W)*. We then choose a patch size *N* to divide a sequence into *W / N* patches, and append a soft positional encoding value to the end of each patch; the positional value is learned during model training. Each patch will then have the data shape *(Batch Size, C, 1, (W/N) + 1)*. This process is shown in the figure above.
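
For a shape-level walkthrough of this step, the sketch below pushes a toy batch through the `PatchEmbedding_Linear` module from `GANModels.py`; the sizes (W = 150, N = 15, embedding size 50) are simply the discriminator defaults in this repository:

```
import torch
from GANModels import PatchEmbedding_Linear

x = torch.randn(8, 3, 1, 150)   # (Batch Size, C, 1, W)
embed = PatchEmbedding_Linear(in_channels=3, patch_size=15,
                              emb_size=50, seq_length=150)
tokens = embed(x)               # (8, 11, 50): W/N = 10 patches plus a class token,
print(tokens.shape)             # each projected to 50 dims with a learned position added
```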
33 |
34 | ---
35 |
36 | **Repository structures:**
37 |
38 | > ./images
39 |
40 | Several images of the TTS-GAN project
41 |
42 |
43 | > ./pre-trained-models
44 |
45 | Saved pre-trained GAN model checkpoints
46 |
47 |
48 | > dataLoader.py
49 |
50 | The UniMiB dataset dataLoader used for loading GAN model training/testing data
51 |
52 |
53 | > LoadRealRunningJumping.py
54 |
55 | Load real running and jumping data from UniMiB dataset
56 |
57 |
58 | > LoadSyntheticRunningJumping.py
59 |
60 | Load Synthetic running and jumping data from the pre-trained GAN models
61 |
62 |
63 | > functions.py
64 |
65 | The GAN model training and evaluation functions
66 |
67 |
68 | > train_GAN.py
69 |
70 | The major GAN model training file
71 |
72 |
73 | > visualizationMetrics.py
74 |
75 | Helper functions for drawing t-SNE and PCA plots
76 |
77 |
78 | > adamw.py
79 |
80 | The AdamW optimizer implementation
81 |
82 |
83 | > cfg.py
84 |
85 | The argument-parsing function that supplies parameters to train_GAN.py
86 |
87 |
88 | > JumpingGAN_Train.py
89 |
90 | Run this file to start training the Jumping GAN model
91 |
92 |
93 | > RunningGAN_Train.py
94 |
95 | Run this file to start training the Running GAN model
96 |
97 |
98 | ---
99 |
100 | **Code Instructions:**
101 |
102 |
103 | To train the Running data GAN model:
104 | ```
105 | python RunningGAN_Train.py
106 | ```
107 |
108 | To train the Jumping data GAN model:
109 | ```
110 | python JumpingGAN_Train.py
111 | ```
112 |
113 | A simple example of visualizing the similarity between the synthetic running&jumping data and the real running&jumping data:
114 | ```
115 | Running&JumpingVisualization.ipynb
116 | ```
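
If you prefer to build the comparison outside the notebook, a minimal sketch using the two loaders in this repository looks like the following (the batch size is arbitrary, and the real loader downloads `UniMiB-SHAR.zip` on first use):

```
from torch.utils.data import DataLoader
from LoadRealRunningJumping import Running_Or_Jumping
from LoadSyntheticRunningJumping import Synthetic_Dataset

real = Running_Or_Jumping(data_mode='Train', is_normalize=True)   # real UniMiB running/jumping data
syn = Synthetic_Dataset(sample_size=1000)                         # sampled from the pre-trained GANs

real_loader = DataLoader(real, batch_size=32, shuffle=True)
syn_loader = DataLoader(syn, batch_size=32, shuffle=True)
```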
117 | ---
118 |
--------------------------------------------------------------------------------
/RunningGAN_Train.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
2 |
3 | import os
4 | import argparse
5 |
6 | def parse_args():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--rank', type=str, default="0")
9 | parser.add_argument('--node', type=str, default="0015")
10 | opt = parser.parse_args()
11 |
12 | return opt
13 | args = parse_args()
14 |
15 | os.system(f"CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_GAN.py \
16 | -gen_bs 16 \
17 | -dis_bs 16 \
18 | --dist-url 'tcp://localhost:4321' \
19 | --dist-backend 'nccl' \
20 | --world-size 1 \
21 | --rank {args.rank} \
22 | --dataset UniMiB \
23 | --bottom_width 8 \
24 | --max_iter 500000 \
25 | --img_size 32 \
26 | --gen_model my_gen \
27 | --dis_model my_dis \
28 | --df_dim 384 \
29 | --d_heads 4 \
30 | --d_depth 3 \
31 | --g_depth 5,4,2 \
32 | --dropout 0 \
33 | --latent_dim 100 \
34 | --gf_dim 1024 \
35 | --num_workers 16 \
36 | --g_lr 0.0001 \
37 | --d_lr 0.0003 \
38 | --optimizer adam \
39 | --loss lsgan \
40 | --wd 1e-3 \
41 | --beta1 0.9 \
42 | --beta2 0.999 \
43 | --phi 1 \
44 | --batch_size 16 \
45 | --num_eval_imgs 50000 \
46 | --init_type xavier_uniform \
47 | --n_critic 1 \
48 | --val_freq 20 \
49 | --print_freq 50 \
50 | --grow_steps 0 0 \
51 | --fade_in 0 \
52 | --patch_size 2 \
53 | --ema_kimg 500 \
54 | --ema_warmup 0.1 \
55 | --ema 0.9999 \
56 | --diff_aug translation,cutout,color \
57 | --class_name Running \
58 | --exp_name Running")
--------------------------------------------------------------------------------
/adamw.py:
--------------------------------------------------------------------------------
1 | """ AdamW Optimizer
2 | Impl copied from PyTorch master
3 | """
4 | import math
5 | import torch
6 | from torch.optim.optimizer import Optimizer
7 |
8 |
9 | class AdamW(Optimizer):
10 | r"""Implements AdamW algorithm.
11 |
12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
14 |
15 | Arguments:
16 | params (iterable): iterable of parameters to optimize or dicts defining
17 | parameter groups
18 | lr (float, optional): learning rate (default: 1e-3)
19 | betas (Tuple[float, float], optional): coefficients used for computing
20 | running averages of gradient and its square (default: (0.9, 0.999))
21 | eps (float, optional): term added to the denominator to improve
22 | numerical stability (default: 1e-8)
23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2)
24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
25 | algorithm from the paper `On the Convergence of Adam and Beyond`_
26 | (default: False)
27 |
28 | .. _Adam\: A Method for Stochastic Optimization:
29 | https://arxiv.org/abs/1412.6980
30 | .. _Decoupled Weight Decay Regularization:
31 | https://arxiv.org/abs/1711.05101
32 | .. _On the Convergence of Adam and Beyond:
33 | https://openreview.net/forum?id=ryQu7f-RZ
34 | """
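    # A usage sketch (hypothetical names; any model's parameters work):
    #   optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    #   loss = criterion(model(x), y)
    #   optimizer.zero_grad(); loss.backward(); optimizer.step()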
35 |
36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
37 | weight_decay=1e-2, amsgrad=False):
38 | if not 0.0 <= lr:
39 | raise ValueError("Invalid learning rate: {}".format(lr))
40 | if not 0.0 <= eps:
41 | raise ValueError("Invalid epsilon value: {}".format(eps))
42 | if not 0.0 <= betas[0] < 1.0:
43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
44 | if not 0.0 <= betas[1] < 1.0:
45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
46 | defaults = dict(lr=lr, betas=betas, eps=eps,
47 | weight_decay=weight_decay, amsgrad=amsgrad)
48 | super(AdamW, self).__init__(params, defaults)
49 |
50 | def __setstate__(self, state):
51 | super(AdamW, self).__setstate__(state)
52 | for group in self.param_groups:
53 | group.setdefault('amsgrad', False)
54 |
55 | def step(self, closure=None):
56 | """Performs a single optimization step.
57 |
58 | Arguments:
59 | closure (callable, optional): A closure that reevaluates the model
60 | and returns the loss.
61 | """
62 | loss = None
63 | if closure is not None:
64 | loss = closure()
65 |
66 | for group in self.param_groups:
67 | for p in group['params']:
68 | if p.grad is None:
69 | continue
70 |
71 | # Perform stepweight decay
72 | p.data.mul_(1 - group['lr'] * group['weight_decay'])
73 |
74 | # Perform optimization step
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
78 | amsgrad = group['amsgrad']
79 |
80 | state = self.state[p]
81 |
82 | # State initialization
83 | if len(state) == 0:
84 | state['step'] = 0
85 | # Exponential moving average of gradient values
86 | state['exp_avg'] = torch.zeros_like(p.data)
87 | # Exponential moving average of squared gradient values
88 | state['exp_avg_sq'] = torch.zeros_like(p.data)
89 | if amsgrad:
90 | # Maintains max of all exp. moving avg. of sq. grad. values
91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
92 |
93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
94 | if amsgrad:
95 | max_exp_avg_sq = state['max_exp_avg_sq']
96 | beta1, beta2 = group['betas']
97 |
98 | state['step'] += 1
99 | bias_correction1 = 1 - beta1 ** state['step']
100 | bias_correction2 = 1 - beta2 ** state['step']
101 |
102 | # Decay the first and second moment running average coefficient
103 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
104 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
105 | if amsgrad:
106 | # Maintains the maximum of all 2nd moment running avg. till now
107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
108 | # Use the max. for normalizing running avg. of gradient
109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
110 | else:
111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
112 |
113 | step_size = group['lr'] / bias_correction1
114 |
115 |                 p.data.addcdiv_(exp_avg, denom, value=-step_size)
116 |
117 | return loss
--------------------------------------------------------------------------------
/cfg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | import argparse
8 |
9 |
10 | def str2bool(v):
11 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
12 | return True
13 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
14 | return False
15 | else:
16 | raise argparse.ArgumentTypeError('Boolean value expected.')
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--world-size', default=-1, type=int,
22 | help='number of nodes for distributed training')
23 | parser.add_argument('--rank', default=-1, type=int,
24 | help='node rank for distributed training')
25 | parser.add_argument('--loca_rank', default=-1, type=int,
26 | help='node rank for distributed training')
27 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
28 | help='url used to set up distributed training')
29 | parser.add_argument('--dist-backend', default='nccl', type=str,
30 | help='distributed backend')
31 | parser.add_argument('--seed', default=12345, type=int,
32 | help='seed for initializing training. ')
33 | parser.add_argument('--gpu', default=None, type=int,
34 | help='GPU id to use.')
35 | parser.add_argument('--multiprocessing-distributed', action='store_true',
36 | help='Use multi-processing distributed training to launch '
37 | 'N processes per node, which has N GPUs. This is the '
38 | 'fastest way to use PyTorch for either single node or '
39 | 'multi node data parallel training')
40 | parser.add_argument(
41 | '--max_epoch',
42 | type=int,
43 | default=200,
44 | help='number of epochs of training')
45 | parser.add_argument(
46 | '--max_iter',
47 | type=int,
48 | default=None,
49 | help='set the max iteration number')
50 | parser.add_argument(
51 | '-gen_bs',
52 | '--gen_batch_size',
53 | type=int,
54 | default=64,
55 | help='size of the batches')
56 | parser.add_argument(
57 | '-dis_bs',
58 | '--dis_batch_size',
59 | type=int,
60 | default=64,
61 | help='size of the batches')
62 | parser.add_argument(
63 | '-bs',
64 | '--batch_size',
65 | type=int,
66 | default=64,
67 | help='size of the batches to load dataset')
68 | parser.add_argument(
69 | '--g_lr',
70 | type=float,
71 | default=0.0002,
72 | help='adam: gen learning rate')
73 | parser.add_argument(
74 | '--wd',
75 | type=float,
76 | default=0,
77 | help='adamw: gen weight decay')
78 | parser.add_argument(
79 | '--d_lr',
80 | type=float,
81 | default=0.0002,
82 | help='adam: disc learning rate')
83 | parser.add_argument(
84 | '--ctrl_lr',
85 | type=float,
86 | default=3.5e-4,
87 | help='adam: ctrl learning rate')
88 | parser.add_argument(
89 | '--lr_decay',
90 | action='store_true',
91 | help='learning rate decay or not')
92 | parser.add_argument(
93 | '--beta1',
94 | type=float,
95 | default=0.0,
96 | help='adam: decay of first order momentum of gradient')
97 | parser.add_argument(
98 | '--beta2',
99 | type=float,
100 | default=0.9,
101 |         help='adam: decay of second order momentum of gradient')
102 | parser.add_argument(
103 | '--num_workers',
104 | type=int,
105 | default=8,
106 | help='number of cpu threads to use during batch generation')
107 | parser.add_argument(
108 | '--latent_dim',
109 | type=int,
110 | default=128,
111 | help='dimensionality of the latent space')
112 | parser.add_argument(
113 | '--img_size',
114 | type=int,
115 | default=32,
116 | help='size of each image dimension')
117 | parser.add_argument(
118 | '--channels',
119 | type=int,
120 | default=3,
121 | help='number of image channels')
122 | parser.add_argument(
123 | '--n_critic',
124 | type=int,
125 | default=1,
126 | help='number of training steps for discriminator per iter')
127 | parser.add_argument(
128 | '--val_freq',
129 | type=int,
130 | default=20,
131 | help='interval between each validation')
132 | parser.add_argument(
133 | '--print_freq',
134 | type=int,
135 | default=100,
136 | help='interval between each verbose')
137 | parser.add_argument(
138 | '--load_path',
139 | type=str,
140 | help='The reload model path')
141 | parser.add_argument(
142 | '--class_name',
143 | type=str,
144 | help='The class name to load in UniMiB dataset')
145 | parser.add_argument(
146 | '--augment_times',
147 | type=int,
148 | default=None,
149 | help='The times of augment signals compare to original data')
150 | parser.add_argument(
151 | '--exp_name',
152 | type=str,
153 | help='The name of exp')
154 | parser.add_argument(
155 | '--d_spectral_norm',
156 | type=str2bool,
157 | default=False,
158 | help='add spectral_norm on discriminator?')
159 | parser.add_argument(
160 | '--g_spectral_norm',
161 | type=str2bool,
162 | default=False,
163 | help='add spectral_norm on generator?')
164 | parser.add_argument(
165 | '--dataset',
166 | type=str,
167 | default='cifar10',
168 | help='dataset type')
169 | parser.add_argument(
170 | '--data_path',
171 | type=str,
172 | default='./data',
173 | help='The path of data set')
174 | parser.add_argument('--init_type', type=str, default='normal',
175 | choices=['normal', 'orth', 'xavier_uniform', 'false'],
176 | help='The init type')
177 | parser.add_argument('--gf_dim', type=int, default=64,
178 | help='The base channel num of gen')
179 | parser.add_argument('--df_dim', type=int, default=64,
180 | help='The base channel num of disc')
181 | parser.add_argument(
182 | '--gen_model',
183 | type=str,
184 | help='path of gen model')
185 | parser.add_argument(
186 | '--dis_model',
187 | type=str,
188 | help='path of dis model')
189 | parser.add_argument(
190 | '--controller',
191 | type=str,
192 | default='controller',
193 | help='path of controller')
194 | parser.add_argument('--eval_batch_size', type=int, default=100)
195 | parser.add_argument('--num_eval_imgs', type=int, default=50000)
196 | parser.add_argument(
197 | '--bottom_width',
198 | type=int,
199 | default=4,
200 | help="the base resolution of the GAN")
201 | parser.add_argument('--random_seed', type=int, default=12345)
202 |
203 | # search
204 | parser.add_argument('--shared_epoch', type=int, default=15,
205 | help='the number of epoch to train the shared gan at each search iteration')
206 | parser.add_argument('--grow_step1', type=int, default=25,
207 | help='which iteration to grow the image size from 8 to 16')
208 | parser.add_argument('--grow_step2', type=int, default=55,
209 | help='which iteration to grow the image size from 16 to 32')
210 | parser.add_argument('--max_search_iter', type=int, default=90,
211 | help='max search iterations of this algorithm')
212 | parser.add_argument('--ctrl_step', type=int, default=30,
213 | help='number of steps to train the controller at each search iteration')
214 | parser.add_argument('--ctrl_sample_batch', type=int, default=1,
215 | help='sample size of controller of each step')
216 | parser.add_argument('--hid_size', type=int, default=100,
217 | help='the size of hidden vector')
218 | parser.add_argument('--baseline_decay', type=float, default=0.9,
219 | help='baseline decay rate in RL')
220 | parser.add_argument('--rl_num_eval_img', type=int, default=5000,
221 | help='number of images to be sampled in order to get the reward')
222 | parser.add_argument('--num_candidate', type=int, default=10,
223 | help='number of candidate architectures to be sampled')
224 | parser.add_argument('--topk', type=int, default=5,
225 | help='preserve topk models architectures after each stage' )
226 | parser.add_argument('--entropy_coeff', type=float, default=1e-3,
227 | help='to encourage the exploration')
228 | parser.add_argument('--dynamic_reset_threshold', type=float, default=1e-3,
229 | help='var threshold')
230 | parser.add_argument('--dynamic_reset_window', type=int, default=500,
231 | help='the window size')
232 | parser.add_argument('--arch', nargs='+', type=int,
233 | help='the vector of a discovered architecture')
234 | parser.add_argument('--optimizer', type=str, default="adam",
235 | help='optimizer')
236 | parser.add_argument('--loss', type=str, default="hinge",
237 | help='loss function')
238 | parser.add_argument('--n_classes', type=int, default=0,
239 | help='classes')
240 | parser.add_argument('--phi', type=float, default=1,
241 | help='wgan-gp phi')
242 | parser.add_argument('--grow_steps', nargs='+', type=int,
243 | help='the vector of a discovered architecture')
244 | parser.add_argument('--D_downsample', type=str, default="avg",
245 | help='downsampling type')
246 | parser.add_argument('--fade_in', type=float, default=1,
247 | help='fade in step')
248 | parser.add_argument('--d_depth', type=int, default=7,
249 | help='Discriminator Depth')
250 | parser.add_argument('--g_depth', type=str, default="5,4,2",
251 | help='Generator Depth')
252 | parser.add_argument('--g_norm', type=str, default="ln",
253 | help='Generator Normalization')
254 | parser.add_argument('--d_norm', type=str, default="ln",
255 | help='Discriminator Normalization')
256 | parser.add_argument('--g_act', type=str, default="gelu",
257 | help='Generator activation Layer')
258 | parser.add_argument('--d_act', type=str, default="gelu",
259 | help='Discriminator activation layer')
260 | parser.add_argument('--patch_size', type=int, default=4,
261 |                         help='patch size')
262 | parser.add_argument('--fid_stat', type=str, default="None",
263 |                         help='path to the FID statistics file')
264 | parser.add_argument('--diff_aug', type=str, default="None",
265 | help='differentiable augmentation type')
266 | parser.add_argument('--accumulated_times', type=int, default=1,
267 | help='gradient accumulation')
268 | parser.add_argument('--g_accumulated_times', type=int, default=1,
269 | help='gradient accumulation')
270 | parser.add_argument('--num_landmarks', type=int, default=64,
271 | help='number of landmarks')
272 | parser.add_argument('--d_heads', type=int, default=4,
273 | help='number of heads')
274 | parser.add_argument('--dropout', type=float, default=0.,
275 | help='dropout ratio')
276 | parser.add_argument('--ema', type=float, default=0.995,
277 | help='ema')
278 | parser.add_argument('--ema_warmup', type=float, default=0.,
279 | help='ema warm up')
280 | parser.add_argument('--ema_kimg', type=int, default=500,
281 | help='ema thousand images')
282 | parser.add_argument('--latent_norm',action='store_true',
283 | help='latent vector normalization')
284 | parser.add_argument('--ministd',action='store_true',
285 | help='mini batch std')
286 | parser.add_argument('--g_mlp', type=int, default=4,
287 | help='generator mlp ratio')
288 | parser.add_argument('--d_mlp', type=int, default=4,
289 | help='discriminator mlp ratio')
290 | parser.add_argument('--g_window_size', type=int, default=8,
291 |                         help='generator window size')
292 | parser.add_argument('--d_window_size', type=int, default=8,
293 |                         help='discriminator window size')
294 | parser.add_argument('--show', action='store_true',
295 | help='show')
296 |
297 | opt = parser.parse_args()
298 |
299 | return opt
--------------------------------------------------------------------------------
/dataLoader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """UniMiB_SHAR_ADL_load_dataset.ipynb
3 |
4 | Automatically generated by Colaboratory.
5 |
6 | Original file is located at
7 | https://colab.research.google.com/drive/1U1EY6cZsOFERD3Df1HRqjuTq5bDUGH03
8 |
9 | #UniMiB_SHAR_ADL_load_dataset.ipynb.
10 | Loads the A-9 (ADL) portion of the UniMiB dataset from the Internet repository and converts the data into numpy arrays while adhering to the general format of the [Keras MNIST load_data function](https://keras.io/api/datasets/mnist/#load_data-function).
11 |
12 | Arguments: tbd
13 | Returns: Tuple of Numpy arrays:
14 |     (x_train, y_train), (x_validation, y_validation) [optional], (x_test, y_test)
15 | 
16 |     * x_train/validation/test: float64 arrays with shape (num_samples, 151, {3,4,1})
17 |     * y_train/validation/test: int8 arrays with shape (num_samples,), class labels in 0-8
18 |
19 | The train/test split is by subject
20 |
21 | Example usage:
22 | x_train, y_train, x_test, y_test = unimib_load_dataset()
23 |
24 | Additional References
25 | If you use the dataset and/or code, please cite this paper (downloadable from [here](http://www.mdpi.com/2076-3417/7/10/1101/html))
26 |
27 | Developed and tested using colab.research.google.com
28 | To save as .py version use File > Download .py
29 |
30 | Author: Lee B. Hinkle, IMICS Lab, Texas State University, 2021
31 |
32 | 
This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.
33 |
34 |
35 | TODOs:
36 | * Fix document strings
37 | * Assign names to activities instead of numbers
38 | """
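# Minimal usage sketch of the Dataset class defined below. The constructor arguments mirror the
# call made in train_GAN.py; the batch size and class name here are illustrative assumptions.
#
#   from torch.utils.data import DataLoader
#   train_set = unimib_load_dataset(incl_xyz_accel=True, incl_rms_accel=False,
#                                   is_normalize=True, one_hot_encode=False,
#                                   data_mode='Train', single_class=True,
#                                   class_name='Running')
#   train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
#   for x, y in train_loader:
#       # x: (batch, channels, 1, timesteps) accelerometer windows; y: integer class labels
#       break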
39 |
40 | import os
41 | import shutil #https://docs.python.org/3/library/shutil.html
42 | from shutil import unpack_archive # to unzip
43 | #from shutil import make_archive # to create zip for storage
44 | import requests #for downloading zip file
45 | from scipy import io #for loadmat, matlab conversion
46 | import pandas as pd
47 | import numpy as np
48 | #import matplotlib.pyplot as plt # for plotting - pandas uses matplotlib
49 | from tabulate import tabulate # for verbose tables
50 | #from tensorflow.keras.utils import to_categorical # for one-hot encoding
51 |
52 | #credit https://stackoverflow.com/questions/9419162/download-returned-zip-file-from-url
53 | #many other methods I tried failed to download the file properly
54 | from torch.utils.data import Dataset, DataLoader
55 |
56 | #data augmentation
57 | import tsaug
58 |
59 | class_dict = {'StandingUpFS':0,'StandingUpFL':1,'Walking':2,'Running':3,'GoingUpS':4,'Jumping':5,'GoingDownS':6,'LyingDownFS':7,'SittingDown':8}
60 |
61 | class unimib_load_dataset(Dataset):
62 | def __init__(self,
63 | verbose = False,
64 | incl_xyz_accel = False, #include component accel_x/y/z in ____X data
65 | incl_rms_accel = True, #add rms value (total accel) of accel_x/y/z in ____X data
66 | incl_val_group = False, #True => returns x/y_test, x/y_validation, x/y_train
67 | #False => combine test & validation groups
68 | is_normalize = False,
69 |                  split_subj = dict(
70 |                      train_subj = [4,5,6,7,8,10,11,12,14,15,19,20,21,22,24,26,27,29],
71 |                      validation_subj = [1,9,16,23,25,28],
72 |                      test_subj = [2,3,13,17,18,30]),
73 | one_hot_encode = True, data_mode = 'Train', single_class = False, class_name= 'Walking', augment_times = None):
74 |
75 | self.verbose = verbose
76 | self.incl_xyz_accel = incl_xyz_accel
77 | self.incl_rms_accel = incl_rms_accel
78 | self.incl_val_group = incl_val_group
79 | self.split_subj = split_subj
80 | self.one_hot_encode = one_hot_encode
81 | self.data_mode = data_mode
82 | self.class_name = class_name
83 | self.single_class = single_class
84 | self.is_normalize = is_normalize
85 |
86 |
87 | #Download and unzip original dataset
88 | if (not os.path.isfile('./UniMiB-SHAR.zip')):
89 | print("Downloading UniMiB-SHAR.zip file")
90 | #invoking the shell command fails when exported to .py file
91 | #redirect link https://www.dropbox.com/s/raw/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip
92 | #!wget https://www.dropbox.com/s/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip
93 | self.download_url('https://www.dropbox.com/s/raw/x2fpfqj0bpf8ep6/UniMiB-SHAR.zip','./UniMiB-SHAR.zip')
94 | if (not os.path.isdir('./UniMiB-SHAR')):
95 | shutil.unpack_archive('./UniMiB-SHAR.zip','.','zip')
96 | #Convert .mat files to numpy ndarrays
97 | path_in = './UniMiB-SHAR/data'
98 | #loadmat loads matlab files as dictionary, keys: header, version, globals, data
99 | adl_data = io.loadmat(path_in + '/adl_data.mat')['adl_data']
100 | adl_names = io.loadmat(path_in + '/adl_names.mat', chars_as_strings=True)['adl_names']
101 | adl_labels = io.loadmat(path_in + '/adl_labels.mat')['adl_labels']
102 |
103 | if(self.verbose):
104 | headers = ("Raw data","shape", "object type", "data type")
105 | mydata = [("adl_data:", adl_data.shape, type(adl_data), adl_data.dtype),
106 | ("adl_labels:", adl_labels.shape ,type(adl_labels), adl_labels.dtype),
107 | ("adl_names:", adl_names.shape, type(adl_names), adl_names.dtype)]
108 | print(tabulate(mydata, headers=headers))
109 | #Reshape data and compute total (rms) acceleration
110 | num_samples = 151
111 | #UniMiB SHAR has fixed size of 453 which is 151 accelX, 151 accely, 151 accelz
112 | adl_data = np.reshape(adl_data,(-1,num_samples,3), order='F') #uses Fortran order
113 | if (self.incl_rms_accel):
114 | rms_accel = np.sqrt((adl_data[:,:,0]**2) + (adl_data[:,:,1]**2) + (adl_data[:,:,2]**2))
115 | adl_data = np.dstack((adl_data,rms_accel))
116 | #remove component accel if needed
117 | if (not self.incl_xyz_accel):
118 | adl_data = np.delete(adl_data, [0,1,2], 2)
119 | if(verbose):
120 | headers = ("Reshaped data","shape", "object type", "data type")
121 | mydata = [("adl_data:", adl_data.shape, type(adl_data), adl_data.dtype),
122 | ("adl_labels:", adl_labels.shape ,type(adl_labels), adl_labels.dtype),
123 | ("adl_names:", adl_names.shape, type(adl_names), adl_names.dtype)]
124 | print(tabulate(mydata, headers=headers))
125 | #Split train/test sets, combine or make separate validation set
126 | #ref for this numpy gymnastics - find index of matching subject to sub_train/sub_test/sub_validate
127 | #https://numpy.org/doc/stable/reference/generated/numpy.isin.html
128 |
129 |
130 | act_num = (adl_labels[:,0])-1 #matlab source was 1 indexed, change to 0 indexed
131 | sub_num = (adl_labels[:,1]) #subject numbers are in column 1 of labels
132 |
133 | if (not self.incl_val_group):
134 | train_index = np.nonzero(np.isin(sub_num, self.split_subj['train_subj'] +
135 | self.split_subj['validation_subj']))
136 | x_train = adl_data[train_index]
137 | y_train = act_num[train_index]
138 | else:
139 | train_index = np.nonzero(np.isin(sub_num, self.split_subj['train_subj']))
140 | x_train = adl_data[train_index]
141 | y_train = act_num[train_index]
142 |
143 | validation_index = np.nonzero(np.isin(sub_num, self.split_subj['validation_subj']))
144 | x_validation = adl_data[validation_index]
145 | y_validation = act_num[validation_index]
146 |
147 | test_index = np.nonzero(np.isin(sub_num, self.split_subj['test_subj']))
148 | x_test = adl_data[test_index]
149 | y_test = act_num[test_index]
150 |
151 | if (verbose):
152 | print("x/y_train shape ",x_train.shape,y_train.shape)
153 | if (self.incl_val_group):
154 | print("x/y_validation shape ",x_validation.shape,y_validation.shape)
155 | print("x/y_test shape ",x_test.shape,y_test.shape)
156 | #If selected one-hot encode y_* using keras to_categorical, reference:
157 | #https://keras.io/api/utils/python_utils/#to_categorical-function and
158 | #https://machinelearningmastery.com/how-to-one-hot-encode-sequence-data-in-python/
159 | if (self.one_hot_encode):
160 | y_train = self.to_categorical(y_train, num_classes=9)
161 | if (self.incl_val_group):
162 | y_validation = self.to_categorical(y_validation, num_classes=9)
163 | y_test = self.to_categorical(y_test, num_classes=9)
164 | if (verbose):
165 | print("After one-hot encoding")
166 | print("x/y_train shape ",x_train.shape,y_train.shape)
167 | if (self.incl_val_group):
168 | print("x/y_validation shape ",x_validation.shape,y_validation.shape)
169 | print("x/y_test shape ",x_test.shape,y_test.shape)
170 | # if (self.incl_val_group):
171 | # return x_train, y_train, x_validation, y_validation, x_test, y_test
172 | # else:
173 | # return x_train, y_train, x_test, y_test
174 |
175 | # reshape x_train, x_test data shape from (BH, length, channel) to (BH, channel, 1, length)
176 | self.x_train = np.transpose(x_train, (0, 2, 1))
177 | self.x_train = self.x_train.reshape(self.x_train.shape[0], self.x_train.shape[1], 1, self.x_train.shape[2])
178 | self.x_train = self.x_train[:,:,:,:-1]
179 | self.y_train = y_train
180 |
181 | self.x_test = np.transpose(x_test, (0, 2, 1))
182 | self.x_test = self.x_test.reshape(self.x_test.shape[0], self.x_test.shape[1], 1, self.x_test.shape[2])
183 | self.x_test = self.x_test[:,:,:,:-1]
184 | self.y_test = y_test
185 | print(f'x_train shape is {self.x_train.shape}, x_test shape is {self.x_test.shape}')
186 | print(f'y_train shape is {self.y_train.shape}, y_test shape is {self.y_test.shape}')
187 |
188 |
189 | if self.is_normalize:
190 | self.x_train = self.normalization(self.x_train)
191 | self.x_test = self.normalization(self.x_test)
192 |
193 |         #Return the given class train/test data & labels
194 | if self.single_class:
195 | one_class_train_data = []
196 | one_class_train_labels = []
197 | one_class_test_data = []
198 | one_class_test_labels = []
199 |
200 | for i, label in enumerate(y_train):
201 | if label == class_dict[self.class_name]:
202 | one_class_train_data.append(self.x_train[i])
203 | one_class_train_labels.append(label)
204 |
205 | for i, label in enumerate(y_test):
206 | if label == class_dict[self.class_name]:
207 | one_class_test_data.append(self.x_test[i])
208 | one_class_test_labels.append(label)
209 | self.one_class_train_data = np.array(one_class_train_data)
210 | self.one_class_train_labels = np.array(one_class_train_labels)
211 | self.one_class_test_data = np.array(one_class_test_data)
212 | self.one_class_test_labels = np.array(one_class_test_labels)
213 |
214 | if augment_times:
215 | augment_data = []
216 | augment_labels = []
217 | for data, label in zip(one_class_train_data, one_class_train_labels):
218 | # print(data.shape) # C, 1, T
219 | data = data.reshape(data.shape[0], data.shape[2]) # Channel, Timestep
220 | data = np.transpose(data, (1, 0)) # T, C
221 | data = np.asarray(data)
222 | for i in range(augment_times):
223 |
224 | aug_data = tsaug.Quantize(n_levels=[10, 20, 30]).augment(data)
225 | aug_data = tsaug.Drift(max_drift=(0.1, 0.5)).augment(aug_data)
226 | # aug_data = my_augmenter(data) # T, C
227 | aug_data = np.transpose(aug_data, (1, 0)) # C, T
228 | aug_data = aug_data.reshape(aug_data.shape[0], 1, aug_data.shape[1]) # C, 1, T
229 | augment_data.append(aug_data)
230 | augment_labels.append(label)
231 |
232 | augment_data = np.array(augment_data)
233 | augment_labels = np.array(augment_labels)
234 | print(f'augment_data shape is {augment_data.shape}')
235 | print(f'augment_labels shape is {augment_labels.shape}')
236 | self.one_class_train_data = np.concatenate((augment_data, self.one_class_train_data), axis = 0)
237 | self.one_class_train_labels = np.concatenate((augment_labels, self.one_class_train_labels), axis = 0)
238 |
239 | print(f'return single class data and labels, class is {self.class_name}')
240 | print(f'train_data shape is {self.one_class_train_data.shape}, test_data shape is {self.one_class_test_data.shape}')
241 |             print(f'train label shape is {self.one_class_train_labels.shape}, test label shape is {self.one_class_test_labels.shape}')
242 |
243 | def download_url(self, url, save_path, chunk_size=128):
244 | r = requests.get(url, stream=True)
245 | with open(save_path, 'wb') as fd:
246 | for chunk in r.iter_content(chunk_size=chunk_size):
247 | fd.write(chunk)
248 |
249 | def to_categorical(self, y, num_classes):
250 | """ 1-hot encodes a tensor """
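        # e.g. to_categorical(np.array([3]), num_classes=9) -> [[0, 0, 0, 1, 0, 0, 0, 0, 0]]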
251 | return np.eye(num_classes, dtype='uint8')[y]
252 |
253 |
254 | def _normalize(self, epoch):
255 | """ A helper method for the normalization method.
256 | Returns
257 | result: a normalized epoch
258 | """
259 | e = 1e-10
260 | result = (epoch - epoch.mean(axis=0)) / ((np.sqrt(epoch.var(axis=0)))+e)
261 | return result
262 |
263 | def _min_max_normalize(self, epoch):
264 |
265 | result = (epoch - min(epoch)) / (max(epoch) - min(epoch))
266 | return result
267 |
268 | def normalization(self, epochs):
269 | """ Normalizes each epoch e s.t mean(e) = 0 and var(e) = 1
270 | Args:
271 | epochs - Numpy structure of epochs
272 | Returns:
273 |             epochs_n - numpy structure of normalized epochs (mean=0, var=1)
274 | """
275 | for i in range(epochs.shape[0]):
276 | for j in range(epochs.shape[1]):
277 | epochs[i,j,0,:] = self._normalize(epochs[i,j,0,:])
278 | # epochs[i,j,0,:] = self._min_max_normalize(epochs[i,j,0,:])
279 |
280 | return epochs
281 |
282 | def __len__(self):
283 |
284 | if self.data_mode == 'Train':
285 | if self.single_class:
286 | return len(self.one_class_train_labels)
287 | else:
288 | return len(self.y_train)
289 | else:
290 | if self.single_class:
291 | return len(self.one_class_test_labels)
292 | else:
293 | return len(self.y_test)
294 |
295 | def __getitem__(self, idx):
296 | if self.data_mode == 'Train':
297 | if self.single_class:
298 | return self.one_class_train_data[idx], self.one_class_train_labels[idx]
299 | else:
300 | return self.x_train[idx], self.y_train[idx]
301 | else:
302 | if self.single_class:
303 | return self.one_class_test_data[idx], self.one_class_test_labels[idx]
304 | else:
305 | return self.x_test[idx], self.y_test[idx]
306 |
307 |
--------------------------------------------------------------------------------
/functions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | import logging
8 | import operator
9 | import os
10 | from copy import deepcopy
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | from imageio import imsave
16 | from utils.utils import make_grid, save_image
17 | from tqdm import tqdm
18 | import cv2
19 |
20 | # from utils.fid_score import calculate_fid_given_paths
21 | from utils.torch_fid_score import get_fid
22 | # from utils.inception_score import get_inception_score
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | def cur_stages(iter, args):
27 | """
28 | Return current stage.
29 |     :param iter: current iteration.
30 | :return: current stage
31 | """
32 | # if search_iter < self.grow_step1:
33 | # return 0
34 | # elif self.grow_step1 <= search_iter < self.grow_step2:
35 | # return 1
36 | # else:
37 | # return 2
38 | # for idx, grow_step in enumerate(args.grow_steps):
39 | # if iter < grow_step:
40 | # return idx
41 | # return len(args.grow_steps)
42 | idx = 0
43 | for i in range(len(args.grow_steps)):
44 | if iter >= args.grow_steps[i]:
45 | idx = i+1
46 |     return idx
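# Example (values illustrative): with args.grow_steps = [20, 40], cur_stages returns 0 for
# iter < 20, 1 for 20 <= iter < 40, and 2 for iter >= 40.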
47 |
48 | def compute_gradient_penalty(D, real_samples, fake_samples, phi):
49 | """Calculates the gradient penalty loss for WGAN GP"""
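    # The penalty implemented below is E[ (||grad_xhat D(xhat)||_2 - phi)^2 ], evaluated at
    # random interpolates xhat = alpha * x_real + (1 - alpha) * x_fake with alpha ~ U(0, 1).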
50 | # Random weight term for interpolation between real and fake samples
51 | alpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).to(real_samples.get_device())
52 | # Get random interpolation between real and fake samples
53 | interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
54 | d_interpolates = D(interpolates)
55 | fake = torch.ones([real_samples.shape[0], 1], requires_grad=False).to(real_samples.get_device())
56 | # Get gradient w.r.t. interpolates
57 | gradients = torch.autograd.grad(
58 | outputs=d_interpolates,
59 | inputs=interpolates,
60 | grad_outputs=fake,
61 | create_graph=True,
62 | retain_graph=True,
63 | only_inputs=True,
64 | )[0]
65 | gradients = gradients.reshape(gradients.size(0), -1)
66 | gradient_penalty = ((gradients.norm(2, dim=1) - phi) ** 2).mean()
67 | return gradient_penalty
68 |
69 |
70 | def train_d(args, gen_net: nn.Module, dis_net: nn.Module, dis_optimizer, train_loader, epoch, writer_dict,fixed_z, schedulers=None):
71 | writer = writer_dict['writer']
72 | # gen_step = 0
73 | # train mode
74 | dis_net.train()
75 |
76 | dis_optimizer.zero_grad()
77 |
78 | for iter_idx, (imgs, _) in enumerate(tqdm(train_loader)):
79 | global_steps = writer_dict['train_global_steps']
80 |
81 |
82 | # Adversarial ground truths
83 | real_imgs = imgs.type(torch.cuda.FloatTensor).cuda(args.gpu, non_blocking=True)
84 |
85 | # Sample noise as generator input
86 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], args.latent_dim))).cuda(args.gpu, non_blocking=True)
87 |
88 | # ---------------------
89 | # Train Discriminator
90 | # ---------------------
91 |
92 |
93 | real_validity = dis_net(real_imgs)
94 | fake_imgs = gen_net(z).detach()
95 |
96 | assert fake_imgs.size() == real_imgs.size(), f"fake_imgs.size(): {fake_imgs.size()} real_imgs.size(): {real_imgs.size()}"
97 |
98 | fake_validity = dis_net(fake_imgs)
99 |
100 | # cal loss
101 | if args.loss == 'hinge':
102 | d_loss = 0
103 | d_loss = torch.mean(nn.ReLU(inplace=True)(1.0 - real_validity)) + \
104 | torch.mean(nn.ReLU(inplace=True)(1 + fake_validity))
105 | elif args.loss == 'standard':
106 | #soft label
107 | real_label = torch.full((imgs.shape[0],), 0.9, dtype=torch.float, device=real_imgs.get_device())
108 | fake_label = torch.full((imgs.shape[0],), 0.1, dtype=torch.float, device=real_imgs.get_device())
109 | real_validity = nn.Sigmoid()(real_validity.view(-1))
110 | fake_validity = nn.Sigmoid()(fake_validity.view(-1))
111 | d_real_loss = nn.BCELoss()(real_validity, real_label)
112 | d_fake_loss = nn.BCELoss()(fake_validity, fake_label)
113 | d_loss = d_real_loss + d_fake_loss
114 | elif args.loss == 'lsgan':
115 | if isinstance(fake_validity, list):
116 | d_loss = 0
117 | for real_validity_item, fake_validity_item in zip(real_validity, fake_validity):
118 | real_label = torch.full((real_validity_item.shape[0],real_validity_item.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device())
119 | fake_label = torch.full((real_validity_item.shape[0],real_validity_item.shape[1]), 0., dtype=torch.float, device=real_imgs.get_device())
120 | d_real_loss = nn.MSELoss()(real_validity_item, real_label)
121 | d_fake_loss = nn.MSELoss()(fake_validity_item, fake_label)
122 | d_loss += d_real_loss + d_fake_loss
123 | else:
124 | real_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device())
125 | fake_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 0., dtype=torch.float, device=real_imgs.get_device())
126 | d_real_loss = nn.MSELoss()(real_validity, real_label)
127 | d_fake_loss = nn.MSELoss()(fake_validity, fake_label)
128 | d_loss = d_real_loss + d_fake_loss
129 | elif args.loss == 'wgangp':
130 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi)
131 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / (
132 | args.phi ** 2)
133 | elif args.loss == 'wgangp-mode':
134 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi)
135 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / (
136 | args.phi ** 2)
137 | elif args.loss == 'wgangp-eps':
138 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi)
139 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / (
140 | args.phi ** 2)
141 | d_loss += (torch.mean(real_validity) ** 2) * 1e-3
142 | else:
143 | raise NotImplementedError(args.loss)
144 | d_loss = d_loss/float(args.accumulated_times)
145 | d_loss.backward()
146 |
147 | if (iter_idx + 1) % args.accumulated_times == 0:
148 | torch.nn.utils.clip_grad_norm_(dis_net.parameters(), 5.)
149 | dis_optimizer.step()
150 | dis_optimizer.zero_grad()
151 |
152 | writer.add_scalar('d_loss', d_loss.item(), global_steps) if args.rank == 0 else 0
153 |
154 |
155 | # # adjust learning rate
156 | # if schedulers:
157 | # gen_scheduler, dis_scheduler = schedulers
158 | # # g_lr = gen_scheduler.step(global_steps)
159 | # d_lr = dis_scheduler.step(global_steps)
160 | # # writer.add_scalar('LR/g_lr', g_lr, global_steps)
161 | # writer.add_scalar('LR/d_lr', d_lr, global_steps)
162 |
163 | # # moving average weight
164 | # ema_nimg = args.ema_kimg * 1000
165 | # cur_nimg = args.dis_batch_size * args.world_size * global_steps
166 | # if args.ema_warmup != 0:
167 | # ema_nimg = min(ema_nimg, cur_nimg * args.ema_warmup)
168 | # ema_beta = 0.5 ** (float(args.dis_batch_size * args.world_size) / max(ema_nimg, 1e-8))
169 | # else:
170 | # ema_beta = args.ema
171 |
172 | # # moving average weight
173 | # for p, avg_p in zip(gen_net.parameters(), gen_avg_param):
174 | # cpu_p = deepcopy(p)
175 | # avg_p.mul_(ema_beta).add_(1. - ema_beta, cpu_p.cpu().data)
176 | # del cpu_p
177 |
178 | # # writer.add_scalar('g_loss', g_loss.item(), global_steps) if args.rank == 0 else 0
179 | # # gen_step += 1
180 |
181 | # # verbose
182 | # if gen_step and iter_idx % args.print_freq == 0 and args.rank == 0:
183 | # sample_imgs = torch.cat((gen_imgs[:16], real_imgs[:16]), dim=0)
184 | # # scale_factor = args.img_size // int(sample_imgs.size(3))
185 | # # sample_imgs = torch.nn.functional.interpolate(sample_imgs, scale_factor=2)
186 | # # img_grid = make_grid(sample_imgs, nrow=4, normalize=True, scale_each=True)
187 | # # save_image(sample_imgs, f'sampled_images_{args.exp_name}.jpg', nrow=4, normalize=True, scale_each=True)
188 | # # writer.add_image(f'sampled_images_{args.exp_name}', img_grid, global_steps)
189 | # tqdm.write(
190 | # "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [ema: %f] " %
191 | # (epoch, args.max_epoch, iter_idx % len(train_loader), len(train_loader), d_loss.item(), ema_beta))
192 | # del gen_imgs
193 | # del real_imgs
194 | # del fake_validity
195 | # del real_validity
196 | # del d_loss
197 | tqdm.write( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f]" %
198 | (epoch, args.max_epoch, iter_idx % len(train_loader), len(train_loader), d_loss.item()))
199 |
200 | writer_dict['train_global_steps'] = global_steps + 1
201 |
202 |
203 | def train(args, gen_net: nn.Module, dis_net: nn.Module, gen_optimizer, dis_optimizer, gen_avg_param, train_loader,
204 | epoch, writer_dict, fixed_z, schedulers=None):
205 | writer = writer_dict['writer']
206 | gen_step = 0
207 | # train mode
208 | gen_net.train()
209 | dis_net.train()
210 |
211 | dis_optimizer.zero_grad()
212 | gen_optimizer.zero_grad()
213 | for iter_idx, (imgs, _) in enumerate(tqdm(train_loader)):
214 | global_steps = writer_dict['train_global_steps']
215 |
216 |
217 | # Adversarial ground truths
218 | real_imgs = imgs.type(torch.cuda.FloatTensor).cuda(args.gpu, non_blocking=True)
219 |
220 | # Sample noise as generator input
221 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], args.latent_dim))).cuda(args.gpu, non_blocking=True)
222 |
223 | # ---------------------
224 | # Train Discriminator
225 | # ---------------------
226 |
227 |
228 | real_validity = dis_net(real_imgs)
229 | fake_imgs = gen_net(z).detach()
230 |
231 | assert fake_imgs.size() == real_imgs.size(), f"fake_imgs.size(): {fake_imgs.size()} real_imgs.size(): {real_imgs.size()}"
232 |
233 | fake_validity = dis_net(fake_imgs)
234 |
235 | # cal loss
236 | if args.loss == 'hinge':
237 | d_loss = 0
238 | d_loss = torch.mean(nn.ReLU(inplace=True)(1.0 - real_validity)) + \
239 | torch.mean(nn.ReLU(inplace=True)(1 + fake_validity))
240 | elif args.loss == 'standard':
241 | #soft label
242 | real_label = torch.full((imgs.shape[0],), 0.9, dtype=torch.float, device=real_imgs.get_device())
243 | fake_label = torch.full((imgs.shape[0],), 0.1, dtype=torch.float, device=real_imgs.get_device())
244 | real_validity = nn.Sigmoid()(real_validity.view(-1))
245 | fake_validity = nn.Sigmoid()(fake_validity.view(-1))
246 | d_real_loss = nn.BCELoss()(real_validity, real_label)
247 | d_fake_loss = nn.BCELoss()(fake_validity, fake_label)
248 | d_loss = d_real_loss + d_fake_loss
249 | elif args.loss == 'lsgan':
250 | if isinstance(fake_validity, list):
251 | d_loss = 0
252 | for real_validity_item, fake_validity_item in zip(real_validity, fake_validity):
253 | real_label = torch.full((real_validity_item.shape[0],real_validity_item.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device())
254 | fake_label = torch.full((real_validity_item.shape[0],real_validity_item.shape[1]), 0., dtype=torch.float, device=real_imgs.get_device())
255 | d_real_loss = nn.MSELoss()(real_validity_item, real_label)
256 | d_fake_loss = nn.MSELoss()(fake_validity_item, fake_label)
257 | d_loss += d_real_loss + d_fake_loss
258 | else:
259 | real_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device())
260 | fake_label = torch.full((real_validity.shape[0],real_validity.shape[1]), 0., dtype=torch.float, device=real_imgs.get_device())
261 | d_real_loss = nn.MSELoss()(real_validity, real_label)
262 | d_fake_loss = nn.MSELoss()(fake_validity, fake_label)
263 | d_loss = d_real_loss + d_fake_loss
264 | elif args.loss == 'wgangp':
265 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi)
266 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / (
267 | args.phi ** 2)
268 | elif args.loss == 'wgangp-mode':
269 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi)
270 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / (
271 | args.phi ** 2)
272 | elif args.loss == 'wgangp-eps':
273 | gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi)
274 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10 / (
275 | args.phi ** 2)
276 | d_loss += (torch.mean(real_validity) ** 2) * 1e-3
277 | else:
278 | raise NotImplementedError(args.loss)
279 | d_loss = d_loss/float(args.accumulated_times)
280 | d_loss.backward()
281 |
282 | if (iter_idx + 1) % args.accumulated_times == 0:
283 | torch.nn.utils.clip_grad_norm_(dis_net.parameters(), 5.)
284 | dis_optimizer.step()
285 | dis_optimizer.zero_grad()
286 |
287 | writer.add_scalar('d_loss', d_loss.item(), global_steps) if args.rank == 0 else 0
288 |
289 | # -----------------
290 | # Train Generator
291 | # -----------------
292 | if global_steps % (args.n_critic * args.accumulated_times) == 0:
293 |
294 | for accumulated_idx in range(args.g_accumulated_times):
295 | gen_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.gen_batch_size, args.latent_dim)))
296 | gen_imgs = gen_net(gen_z)
297 | fake_validity = dis_net(gen_imgs)
298 |
299 | # cal loss
300 | loss_lz = torch.tensor(0)
301 | if args.loss == "standard":
302 | real_label = torch.full((args.gen_batch_size,), 1., dtype=torch.float, device=real_imgs.get_device())
303 | fake_validity = nn.Sigmoid()(fake_validity.view(-1))
304 | g_loss = nn.BCELoss()(fake_validity.view(-1), real_label)
305 | if args.loss == "lsgan":
306 | if isinstance(fake_validity, list):
307 | g_loss = 0
308 | for fake_validity_item in fake_validity:
309 | real_label = torch.full((fake_validity_item.shape[0],fake_validity_item.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device())
310 | g_loss += nn.MSELoss()(fake_validity_item, real_label)
311 | else:
312 | real_label = torch.full((fake_validity.shape[0],fake_validity.shape[1]), 1., dtype=torch.float, device=real_imgs.get_device())
313 | # fake_validity = nn.Sigmoid()(fake_validity.view(-1))
314 | g_loss = nn.MSELoss()(fake_validity, real_label)
315 | elif args.loss == 'wgangp-mode':
316 | fake_image1, fake_image2 = gen_imgs[:args.gen_batch_size//2], gen_imgs[args.gen_batch_size//2:]
317 | z_random1, z_random2 = gen_z[:args.gen_batch_size//2], gen_z[args.gen_batch_size//2:]
318 | lz = torch.mean(torch.abs(fake_image2 - fake_image1)) / torch.mean(
319 | torch.abs(z_random2 - z_random1))
320 | eps = 1 * 1e-5
321 | loss_lz = 1 / (lz + eps)
322 |
323 | g_loss = -torch.mean(fake_validity) + loss_lz
324 | else:
325 | g_loss = -torch.mean(fake_validity)
326 | g_loss = g_loss/float(args.g_accumulated_times)
327 | g_loss.backward()
328 |
329 | torch.nn.utils.clip_grad_norm_(gen_net.parameters(), 5.)
330 | gen_optimizer.step()
331 | gen_optimizer.zero_grad()
332 |
333 | # adjust learning rate
334 | if schedulers:
335 | gen_scheduler, dis_scheduler = schedulers
336 | g_lr = gen_scheduler.step(global_steps)
337 | d_lr = dis_scheduler.step(global_steps)
338 | writer.add_scalar('LR/g_lr', g_lr, global_steps)
339 | writer.add_scalar('LR/d_lr', d_lr, global_steps)
340 |
341 | # moving average weight
342 | ema_nimg = args.ema_kimg * 1000
343 | cur_nimg = args.dis_batch_size * args.world_size * global_steps
344 | if args.ema_warmup != 0:
345 | ema_nimg = min(ema_nimg, cur_nimg * args.ema_warmup)
346 | ema_beta = 0.5 ** (float(args.dis_batch_size * args.world_size) / max(ema_nimg, 1e-8))
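                    # this choice makes the EMA half-life roughly ema_nimg images: after ema_nimg
                    # generated/seen images, the old weights contribute with weight 0.5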
347 | else:
348 | ema_beta = args.ema
349 |
350 | # moving average weight
351 | for p, avg_p in zip(gen_net.parameters(), gen_avg_param):
352 | cpu_p = deepcopy(p)
353 | avg_p.mul_(ema_beta).add_(1. - ema_beta, cpu_p.cpu().data)
354 | del cpu_p
355 |
356 | writer.add_scalar('g_loss', g_loss.item(), global_steps) if args.rank == 0 else 0
357 | gen_step += 1
358 |
359 | # verbose
360 | if gen_step and iter_idx % args.print_freq == 0 and args.rank == 0:
361 | sample_imgs = torch.cat((gen_imgs[:16], real_imgs[:16]), dim=0)
362 | # scale_factor = args.img_size // int(sample_imgs.size(3))
363 | # sample_imgs = torch.nn.functional.interpolate(sample_imgs, scale_factor=2)
364 | # img_grid = make_grid(sample_imgs, nrow=4, normalize=True, scale_each=True)
365 | # save_image(sample_imgs, f'sampled_images_{args.exp_name}.jpg', nrow=4, normalize=True, scale_each=True)
366 | # writer.add_image(f'sampled_images_{args.exp_name}', img_grid, global_steps)
367 | tqdm.write(
368 | "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [ema: %f] " %
369 | (epoch, args.max_epoch, iter_idx % len(train_loader), len(train_loader), d_loss.item(), g_loss.item(), ema_beta))
370 | del gen_imgs
371 | del real_imgs
372 | del fake_validity
373 | del real_validity
374 | del g_loss
375 | del d_loss
376 |
377 | writer_dict['train_global_steps'] = global_steps + 1
378 |
379 |
380 |
381 |
382 |
383 | def get_is(args, gen_net: nn.Module, num_img):
384 | """
385 | Get inception score.
386 | :param args:
387 | :param gen_net:
388 | :param num_img:
389 | :return: Inception score
390 | """
391 |
392 | # eval mode
393 | gen_net = gen_net.eval()
394 |
395 | eval_iter = num_img // args.eval_batch_size
396 | img_list = list()
397 | for _ in range(eval_iter):
398 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))
399 |
400 | # Generate a batch of images
401 | gen_imgs = gen_net(z).mul_(127.5).add_(127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to('cpu',
402 | torch.uint8).numpy()
403 | img_list.extend(list(gen_imgs))
404 |
405 | # get inception score
406 | logger.info('calculate Inception score...')
407 | mean, std = get_inception_score(img_list)
408 |
409 | return mean
410 |
411 |
412 | def validate(args, fixed_z, fid_stat, epoch, gen_net: nn.Module, writer_dict, clean_dir=True):
413 | writer = writer_dict['writer']
414 | global_steps = writer_dict['valid_global_steps']
415 |
416 | # eval mode
417 | gen_net.eval()
418 |
419 | # generate images
420 | # with torch.no_grad():
421 | # sample_imgs = gen_net(fixed_z, epoch)
422 | # img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)
423 |
424 | # get fid and inception score
425 | # if args.gpu == 0:
426 | # fid_buffer_dir = os.path.join(args.path_helper['sample_path'], 'fid_buffer')
427 | # os.makedirs(fid_buffer_dir, exist_ok=True) if args.gpu == 0 else 0
428 |
429 | # eval_iter = args.num_eval_imgs // args.eval_batch_size
430 | # img_list = list()
431 | # for iter_idx in tqdm(range(eval_iter), desc='sample images'):
432 | # z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))
433 |
434 | # # Generate a batch of images
435 | # gen_imgs = gen_net(z, epoch).mul_(127.5).add_(127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to('cpu',
436 | # torch.uint8).numpy()
437 | # for img_idx, img in enumerate(gen_imgs):
438 | # file_name = os.path.join(fid_buffer_dir, f'iter{iter_idx}_b{img_idx}.png')
439 | # imsave(file_name, img)
440 | # img_list.extend(list(gen_imgs))
441 |
442 | # get inception score
443 | logger.info('=> calculate inception score') if args.rank == 0 else 0
444 | if args.rank == 0:
445 | # mean, std = get_inception_score(img_list)
446 | mean, std = 0, 0
447 | else:
448 | mean, std = 0, 0
449 | print(f"Inception score: {mean}") if args.rank == 0 else 0
450 | # mean, std = 0, 0
451 | # get fid score
452 | print('=> calculate fid score') if args.rank == 0 else 0
453 | if args.rank == 0:
454 | fid_score = get_fid(args, fid_stat, epoch, gen_net, args.num_eval_imgs, args.gen_batch_size, args.eval_batch_size, writer_dict=writer_dict, cls_idx=None)
455 | else:
456 | fid_score = 10000
457 | # fid_score = 10000
458 | print(f"FID score: {fid_score}") if args.rank == 0 else 0
459 |
460 | # if args.gpu == 0:
461 | # if clean_dir:
462 | # os.system('rm -r {}'.format(fid_buffer_dir))
463 | # else:
464 | # logger.info(f'=> sampled images are saved to {fid_buffer_dir}')
465 |
466 | # writer.add_image('sampled_images', img_grid, global_steps)
467 | if args.rank == 0:
468 | writer.add_scalar('Inception_score/mean', mean, global_steps)
469 | writer.add_scalar('Inception_score/std', std, global_steps)
470 | writer.add_scalar('FID_score', fid_score, global_steps)
471 |
472 | writer_dict['valid_global_steps'] = global_steps + 1
473 |
474 | return mean, fid_score
475 |
476 |
477 | def save_samples(args, fixed_z, fid_stat, epoch, gen_net: nn.Module, writer_dict, clean_dir=True):
478 |
479 | # eval mode
480 | gen_net.eval()
481 | with torch.no_grad():
482 | # generate images
483 | batch_size = fixed_z.size(0)
484 | sample_imgs = []
485 | for i in range(fixed_z.size(0)):
486 | sample_img = gen_net(fixed_z[i:(i+1)], epoch)
487 | sample_imgs.append(sample_img)
488 | sample_imgs = torch.cat(sample_imgs, dim=0)
489 | os.makedirs(f"./samples/{args.exp_name}", exist_ok=True)
490 | save_image(sample_imgs, f'./samples/{args.exp_name}/sampled_images_{epoch}.png', nrow=10, normalize=True, scale_each=True)
491 | return 0
492 |
493 |
494 | def get_topk_arch_hidden(args, controller, gen_net, prev_archs, prev_hiddens):
495 | """
496 |     Select the top-k candidate architectures (and their controller hidden states) ranked by Inception score.
497 | :param args:
498 | :param controller:
499 | :param gen_net:
500 | :param prev_archs: previous architecture
501 | :param prev_hiddens: previous hidden vector
502 | :return: a list of topk archs and hiddens.
503 | """
504 | logger.info(f'=> get top{args.topk} archs out of {args.num_candidate} candidate archs...')
505 | assert args.num_candidate >= args.topk
506 | controller.eval()
507 | cur_stage = controller.cur_stage
508 | archs, _, _, hiddens = controller.sample(args.num_candidate, with_hidden=True, prev_archs=prev_archs,
509 | prev_hiddens=prev_hiddens)
510 | hxs, cxs = hiddens
511 | arch_idx_perf_table = {}
512 | for arch_idx in range(len(archs)):
513 | logger.info(f'arch: {archs[arch_idx]}')
514 | gen_net.set_arch(archs[arch_idx], cur_stage)
515 | is_score = get_is(args, gen_net, args.rl_num_eval_img)
516 | logger.info(f'get Inception score of {is_score}')
517 | arch_idx_perf_table[arch_idx] = is_score
518 | topk_arch_idx_perf = sorted(arch_idx_perf_table.items(), key=operator.itemgetter(1))[::-1][:args.topk]
519 | topk_archs = []
520 | topk_hxs = []
521 | topk_cxs = []
522 | logger.info(f'top{args.topk} archs:')
523 | for arch_idx_perf in topk_arch_idx_perf:
524 | logger.info(arch_idx_perf)
525 | arch_idx = arch_idx_perf[0]
526 | topk_archs.append(archs[arch_idx])
527 | topk_hxs.append(hxs[arch_idx].detach().requires_grad_(False))
528 | topk_cxs.append(cxs[arch_idx].detach().requires_grad_(False))
529 |
530 | return topk_archs, (topk_hxs, topk_cxs)
531 |
532 |
533 | class LinearLrDecay(object):
534 | def __init__(self, optimizer, start_lr, end_lr, decay_start_step, decay_end_step):
535 |
536 | assert start_lr > end_lr
537 | self.optimizer = optimizer
538 | self.delta = (start_lr - end_lr) / (decay_end_step - decay_start_step)
539 | self.decay_start_step = decay_start_step
540 | self.decay_end_step = decay_end_step
541 | self.start_lr = start_lr
542 | self.end_lr = end_lr
543 |
544 | def step(self, current_step):
545 | if current_step <= self.decay_start_step:
546 | lr = self.start_lr
547 | elif current_step >= self.decay_end_step:
548 | lr = self.end_lr
549 | else:
550 | lr = self.start_lr - self.delta * (current_step - self.decay_start_step)
551 | for param_group in self.optimizer.param_groups:
552 | param_group['lr'] = lr
553 | return lr
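# Usage sketch (mirrors how train_GAN.py and functions.train drive this scheduler; the exact
# hyperparameters are the caller's choice):
#   gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic)
#   g_lr = gen_scheduler.step(global_steps)   # linearly anneals the lr from g_lr down to 0.0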
554 |
555 | def load_params(model, new_param, args, mode="gpu"):
556 | if mode == "cpu":
557 | for p, new_p in zip(model.parameters(), new_param):
558 | cpu_p = deepcopy(new_p)
559 | # p.data.copy_(cpu_p.cuda().to(f"cuda:{args.gpu}"))
560 | p.data.copy_(cpu_p.cuda().to("cpu"))
561 | del cpu_p
562 |
563 | else:
564 | for p, new_p in zip(model.parameters(), new_param):
565 | p.data.copy_(new_p)
566 |
567 |
568 | def copy_params(model, mode='cpu'):
569 | if mode == 'gpu':
570 | flatten = []
571 | for p in model.parameters():
572 | cpu_p = deepcopy(p).cpu()
573 | flatten.append(cpu_p.data)
574 | else:
575 | flatten = deepcopy(list(p.data for p in model.parameters()))
576 | return flatten
--------------------------------------------------------------------------------
/images/PositionalEncoding.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/images/PositionalEncoding.pdf
--------------------------------------------------------------------------------
/images/PositionalEncoding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/images/PositionalEncoding.png
--------------------------------------------------------------------------------
/images/TTS-GAN.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/images/TTS-GAN.pdf
--------------------------------------------------------------------------------
/images/TTS-GAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/images/TTS-GAN.png
--------------------------------------------------------------------------------
/pre-trained-models/JumpingGAN_checkpoint:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/pre-trained-models/JumpingGAN_checkpoint
--------------------------------------------------------------------------------
/pre-trained-models/RunningGAN_checkpoint:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imics-lab/tts-gan/3f8b36ab84d1c00d48021d6e7c5dbd461686844e/pre-trained-models/RunningGAN_checkpoint
--------------------------------------------------------------------------------
/train_GAN.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import cfg
6 | # import models_search
7 | # import datasets
8 | from dataLoader import *
9 | from GANModels import *
10 | from functions import train, train_d, validate, save_samples, LinearLrDecay, load_params, copy_params, cur_stages
11 | from utils.utils import set_log_dir, save_checkpoint, create_logger
12 | # from utils.inception_score import _init_inception
13 | # from utils.fid_score import create_inception_graph, check_or_download_inception
14 |
15 | import torch
16 | import torch.multiprocessing as mp
17 | import torch.distributed as dist
18 | import torch.utils.data.distributed
19 | from torch.utils import data
20 | import os, warnings  # 'warnings' is used by the GPU warning in main()
21 | import numpy as np
22 | import torch.nn as nn
23 | # from tensorboardX import SummaryWriter
24 | from torch.utils.tensorboard import SummaryWriter
25 | from tqdm import tqdm
26 | from copy import deepcopy
27 | from adamw import AdamW
28 | import random
29 | import matplotlib.pyplot as plt
30 | import io
31 | import PIL.Image
32 | from torchvision.transforms import ToTensor
33 |
34 | # torch.backends.cudnn.enabled = True
35 | # torch.backends.cudnn.benchmark = True
36 |
37 |
38 | def main():
39 | args = cfg.parse_args()
40 |
41 | # _init_inception()
42 | # inception_path = check_or_download_inception(None)
43 | # create_inception_graph(inception_path)
44 |
45 | if args.seed is not None:
46 | torch.manual_seed(args.random_seed)
47 | torch.cuda.manual_seed(args.random_seed)
48 | torch.cuda.manual_seed_all(args.random_seed)
49 | np.random.seed(args.random_seed)
50 | random.seed(args.random_seed)
51 | torch.backends.cudnn.benchmark = False
52 | torch.backends.cudnn.deterministic = True
53 |
54 | if args.gpu is not None:
55 | warnings.warn('You have chosen a specific GPU. This will completely '
56 | 'disable data parallelism.')
57 |
58 | if args.dist_url == "env://" and args.world_size == -1:
59 | args.world_size = int(os.environ["WORLD_SIZE"])
60 |
61 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
62 |
63 | ngpus_per_node = torch.cuda.device_count()
64 | if args.multiprocessing_distributed:
65 | # Since we have ngpus_per_node processes per node, the total world_size
66 | # needs to be adjusted accordingly
67 | args.world_size = ngpus_per_node * args.world_size
68 | # Use torch.multiprocessing.spawn to launch distributed processes: the
69 | # main_worker process function
70 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
71 | else:
72 | # Simply call main_worker function
73 | main_worker(args.gpu, ngpus_per_node, args)
74 |
75 | def main_worker(gpu, ngpus_per_node, args):
76 | args.gpu = gpu
77 |
78 | if args.gpu is not None:
79 | print("Use GPU: {} for training".format(args.gpu))
80 |
81 | if args.distributed:
82 | if args.dist_url == "env://" and args.rank == -1:
83 | args.rank = int(os.environ["RANK"])
84 | if args.multiprocessing_distributed:
85 | # For multiprocessing distributed training, rank needs to be the
86 | # global rank among all the processes
87 | args.rank = args.rank * ngpus_per_node + gpu
88 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
89 | world_size=args.world_size, rank=args.rank)
90 | # weight init
91 | def weights_init(m):
92 | classname = m.__class__.__name__
93 | if classname.find('Conv2d') != -1:
94 | if args.init_type == 'normal':
95 | nn.init.normal_(m.weight.data, 0.0, 0.02)
96 | elif args.init_type == 'orth':
97 | nn.init.orthogonal_(m.weight.data)
98 | elif args.init_type == 'xavier_uniform':
99 |                 nn.init.xavier_uniform_(m.weight.data, 1.)
100 | else:
101 |                 raise NotImplementedError('{} unknown initial type'.format(args.init_type))
102 | # elif classname.find('Linear') != -1:
103 | # if args.init_type == 'normal':
104 | # nn.init.normal_(m.weight.data, 0.0, 0.02)
105 | # elif args.init_type == 'orth':
106 | # nn.init.orthogonal_(m.weight.data)
107 | # elif args.init_type == 'xavier_uniform':
108 | # nn.init.xavier_uniform(m.weight.data, 1.)
109 | # else:
110 | # raise NotImplementedError('{} unknown inital type'.format(args.init_type))
111 | elif classname.find('BatchNorm2d') != -1:
112 | nn.init.normal_(m.weight.data, 1.0, 0.02)
113 | nn.init.constant_(m.bias.data, 0.0)
114 |
115 | # import network
116 |
117 | gen_net = Generator()
118 | print(gen_net)
119 | dis_net = Discriminator()
120 | print(dis_net)
121 | if not torch.cuda.is_available():
122 | print('using CPU, this will be slow')
123 | elif args.distributed:
124 | # For multiprocessing distributed, DistributedDataParallel constructor
125 | # should always set the single device scope, otherwise,
126 | # DistributedDataParallel will use all available devices.
127 | if args.gpu is not None:
128 | torch.cuda.set_device(args.gpu)
129 | # gen_net = eval('models_search.'+args.gen_model+'.Generator')(args=args)
130 | # dis_net = eval('models_search.'+args.dis_model+'.Discriminator')(args=args)
131 |
132 | gen_net.apply(weights_init)
133 | dis_net.apply(weights_init)
134 | gen_net.cuda(args.gpu)
135 | dis_net.cuda(args.gpu)
136 | # When using a single GPU per process and per
137 | # DistributedDataParallel, we need to divide the batch size
138 | # ourselves based on the total number of GPUs we have
139 | args.dis_batch_size = int(args.dis_batch_size / ngpus_per_node)
140 | args.gen_batch_size = int(args.gen_batch_size / ngpus_per_node)
141 | args.batch_size = args.dis_batch_size
142 |
143 | args.num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node)
144 | gen_net = torch.nn.parallel.DistributedDataParallel(gen_net, device_ids=[args.gpu], find_unused_parameters=True)
145 | dis_net = torch.nn.parallel.DistributedDataParallel(dis_net, device_ids=[args.gpu], find_unused_parameters=True)
146 | else:
147 | gen_net.cuda()
148 | dis_net.cuda()
149 | # DistributedDataParallel will divide and allocate batch_size to all
150 | # available GPUs if device_ids are not set
151 | gen_net = torch.nn.parallel.DistributedDataParallel(gen_net)
152 | dis_net = torch.nn.parallel.DistributedDataParallel(dis_net)
153 | elif args.gpu is not None:
154 | torch.cuda.set_device(args.gpu)
155 | gen_net.cuda(args.gpu)
156 | dis_net.cuda(args.gpu)
157 | else:
158 | gen_net = torch.nn.DataParallel(gen_net).cuda()
159 | dis_net = torch.nn.DataParallel(dis_net).cuda()
160 | print(dis_net) if args.rank == 0 else 0
161 |
162 |
163 | # set optimizer
164 | if args.optimizer == "adam":
165 | gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gen_net.parameters()),
166 | args.g_lr, (args.beta1, args.beta2))
167 | dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, dis_net.parameters()),
168 | args.d_lr, (args.beta1, args.beta2))
169 | elif args.optimizer == "adamw":
170 | gen_optimizer = AdamW(filter(lambda p: p.requires_grad, gen_net.parameters()),
171 | args.g_lr, weight_decay=args.wd)
172 | dis_optimizer = AdamW(filter(lambda p: p.requires_grad, dis_net.parameters()),
173 |                                 args.d_lr, weight_decay=args.wd)
174 |
175 | gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic)
176 | dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic)
177 |
178 | # fid stat
179 | # if args.dataset.lower() == 'cifar10':
180 | # fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
181 | # elif args.dataset.lower() == 'stl10':
182 | # fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
183 | # elif args.fid_stat is not None:
184 | # fid_stat = args.fid_stat
185 | # else:
186 | # raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
187 | # assert os.path.exists(fid_stat)
188 |
189 |
190 | # epoch number for dis_net
191 | args.max_epoch = args.max_epoch * args.n_critic
192 | # dataset = datasets.ImageDataset(args, cur_img_size=8)
193 | # train_loader = dataset.train
194 | # train_sampler = dataset.train_sampler
195 |
196 | # train_set = unimib_load_dataset(incl_xyz_accel = True, incl_rms_accel = False, incl_val_group = False, one_hot_encode = False, data_mode = 'Train')
197 | # test_set = unimib_load_dataset(incl_xyz_accel = True, incl_rms_accel = False, incl_val_group = False, one_hot_encode = False, data_mode = 'Test')
198 | # train_loader = data.DataLoader(train_set, batch_size=args.dis_batch_size, num_workers=args.num_workers, shuffle=True)
199 | # test_loader = data.DataLoader(test_set, batch_size=args.dis_batch_size, num_workers=args.num_workers, shuffle=True)
200 |
201 | train_set = unimib_load_dataset(incl_xyz_accel = True, incl_rms_accel = False, incl_val_group = False, is_normalize = True, one_hot_encode = False, data_mode = 'Train', single_class = True, class_name = args.class_name, augment_times=args.augment_times)
202 | train_loader = data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.num_workers, shuffle = True)
203 | test_set = unimib_load_dataset(incl_xyz_accel = True, incl_rms_accel = False, incl_val_group = False, is_normalize = True, one_hot_encode = False, data_mode = 'Test', single_class = True, class_name = args.class_name)
204 | test_loader = data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.num_workers, shuffle = True)
205 |
206 | print(len(train_loader))
207 |
208 | if args.max_iter:
209 | args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))
210 |
211 | # initial
212 | fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (100, args.latent_dim)))
213 | avg_gen_net = deepcopy(gen_net).cpu()
214 | gen_avg_param = copy_params(avg_gen_net)
215 | del avg_gen_net
216 | start_epoch = 0
217 | best_fid = 1e4
218 |
219 | # set writer
220 | writer = None
221 | if args.load_path:
222 | print(f'=> resuming from {args.load_path}')
223 | assert os.path.exists(args.load_path)
224 | checkpoint_file = os.path.join(args.load_path)
225 | assert os.path.exists(checkpoint_file)
226 | loc = 'cuda:{}'.format(args.gpu)
227 | checkpoint = torch.load(checkpoint_file, map_location=loc)
228 | start_epoch = checkpoint['epoch']
229 | best_fid = checkpoint['best_fid']
230 |
231 |
232 | dis_net.load_state_dict(checkpoint['dis_state_dict'])
233 | gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
234 | dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
235 |
236 | # avg_gen_net = deepcopy(gen_net)
237 | gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
238 | gen_avg_param = copy_params(gen_net, mode='gpu')
239 | gen_net.load_state_dict(checkpoint['gen_state_dict'])
240 | fixed_z = checkpoint['fixed_z']
241 | # del avg_gen_net
242 | # gen_avg_param = list(p.cuda().to(f"cuda:{args.gpu}") for p in gen_avg_param)
243 |
244 |
245 |
246 | args.path_helper = checkpoint['path_helper']
247 | logger = create_logger(args.path_helper['log_path']) if args.rank == 0 else None
248 | print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
249 | writer = SummaryWriter(args.path_helper['log_path']) if args.rank == 0 else None
250 | del checkpoint
251 | else:
252 | # create new log dir
253 | assert args.exp_name
254 | if args.rank == 0:
255 | args.path_helper = set_log_dir('logs', args.exp_name)
256 | logger = create_logger(args.path_helper['log_path'])
257 | writer = SummaryWriter(args.path_helper['log_path'])
258 |
259 | if args.rank == 0:
260 | logger.info(args)
261 | writer_dict = {
262 | 'writer': writer,
263 | 'train_global_steps': start_epoch * len(train_loader),
264 | 'valid_global_steps': start_epoch // args.val_freq,
265 | }
266 |
267 | # train loop
268 | for epoch in range(int(start_epoch), int(args.max_epoch)):
269 | # train_sampler.set_epoch(epoch)
270 | lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
271 | cur_stage = cur_stages(epoch, args)
272 | print("cur_stage " + str(cur_stage)) if args.rank==0 else 0
273 | print(f"path: {args.path_helper['prefix']}") if args.rank==0 else 0
274 |
275 | # if (epoch+1) % 3 == 0:
276 | # # train discriminator and generator both
277 | # train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch, writer_dict,fixed_z, lr_schedulers)
278 | # else:
279 | # #only train discriminator
280 | # train_d(args, gen_net, dis_net, dis_optimizer, train_loader, epoch, writer_dict,fixed_z, lr_schedulers)
281 | train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch, writer_dict,fixed_z, lr_schedulers)
282 |
283 | if args.rank == 0 and args.show:
284 | backup_param = copy_params(gen_net)
285 | load_params(gen_net, gen_avg_param, args, mode="cpu")
286 |             save_samples(args, fixed_z, None, epoch, gen_net, writer_dict)  # fid_stat is never defined in this scope; save_samples does not use it
287 | load_params(gen_net, backup_param, args)
288 |
289 |         # fid_stat is not defined here; image evaluation metrics do not make sense for time-series data
290 | # if epoch and epoch % args.val_freq == 0 or epoch == int(args.max_epoch)-1:
291 | # backup_param = copy_params(gen_net)
292 | # load_params(gen_net, gen_avg_param, args, mode="cpu")
293 | # inception_score, fid_score = validate(args, fixed_z, fid_stat, epoch, gen_net, writer_dict)
294 | # if args.rank==0:
295 | # logger.info(f'Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}.')
296 | # load_params(gen_net, backup_param, args)
297 | # if fid_score < best_fid:
298 | # best_fid = fid_score
299 | # is_best = True
300 | # else:
301 | # is_best = False
302 | # else:
303 | # is_best = False
304 |
305 |         # TODO: add a validation step and plot synthetic data in TensorBoard
306 |         # Plot synthetic data every epoch
307 | # if epoch and epoch % 1 == 0:
308 | gen_net.eval()
309 | plot_buf = gen_plot(gen_net, epoch, args.class_name)
310 | image = PIL.Image.open(plot_buf)
311 | image = ToTensor()(image).unsqueeze(0)
312 | #writer = SummaryWriter(comment='synthetic signals')
313 | writer.add_image('Image', image[0], epoch)
314 |
315 | is_best = False
316 | avg_gen_net = deepcopy(gen_net)
317 | load_params(avg_gen_net, gen_avg_param, args)
318 | # if not args.multiprocessing_distributed or (args.multiprocessing_distributed
319 | # and args.rank == 0):
320 |         # Save with 'gen_net.module.state_dict()' so checkpoint keys match at load time (avoids the mismatched 'module.' prefix problem)
321 | save_checkpoint({
322 | 'epoch': epoch + 1,
323 | 'gen_model': args.gen_model,
324 | 'dis_model': args.dis_model,
325 | 'gen_state_dict': gen_net.module.state_dict(),
326 | 'dis_state_dict': dis_net.module.state_dict(),
327 | 'avg_gen_state_dict': avg_gen_net.module.state_dict(),
328 | 'gen_optimizer': gen_optimizer.state_dict(),
329 | 'dis_optimizer': dis_optimizer.state_dict(),
330 | 'best_fid': best_fid,
331 | 'path_helper': args.path_helper,
332 | 'fixed_z': fixed_z
333 | }, is_best, args.path_helper['ckpt_path'], filename="checkpoint")
334 | del avg_gen_net
335 |
336 | def gen_plot(gen_net, epoch, class_name):
337 | """Create a pyplot plot and save to buffer."""
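    # Each generated sample below has shape (1, channels, 1, seq_len); the three curves drawn per
    # panel are read as the three acceleration channels (an assumption based on the dataLoader output).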
338 | synthetic_data = []
339 |
340 | for i in range(10):
341 | fake_noise = torch.FloatTensor(np.random.normal(0, 1, (1, 100)))
342 | fake_sigs = gen_net(fake_noise).to('cpu').detach().numpy()
343 | synthetic_data.append(fake_sigs)
344 |
345 | fig, axs = plt.subplots(2, 5, figsize=(20,5))
346 | fig.suptitle(f'Synthetic {class_name} at epoch {epoch}', fontsize=30)
347 | for i in range(2):
348 | for j in range(5):
349 | axs[i, j].plot(synthetic_data[i*5+j][0][0][0][:])
350 | axs[i, j].plot(synthetic_data[i*5+j][0][1][0][:])
351 | axs[i, j].plot(synthetic_data[i*5+j][0][2][0][:])
352 | buf = io.BytesIO()
353 | plt.savefig(buf, format='jpeg')
354 | buf.seek(0)
355 | return buf
356 |
357 | if __name__ == '__main__':
358 | main()
359 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | from utils import utils
12 |
--------------------------------------------------------------------------------
/utils/cal_fid_stat.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-26
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 |
8 | import os
9 | import glob
10 | import argparse
11 | import numpy as np
12 | from imageio import imread
13 | import tensorflow as tf
14 |
15 | import utils.fid_score as fid
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument(
21 | '--data_path',
22 | type=str,
23 | required=True,
24 | help='set path to training set jpg images dir')
25 | parser.add_argument(
26 | '--output_file',
27 | type=str,
28 | default='fid_stat/fid_stats_cifar10_train.npz',
29 | help='path for where to store the statistics')
30 |
31 | opt = parser.parse_args()
32 | print(opt)
33 | return opt
34 |
35 |
36 | def main():
37 | args = parse_args()
38 |
39 | ########
40 | # PATHS
41 | ########
42 | data_path = args.data_path
43 | output_path = args.output_file
44 | # if you have downloaded and extracted
45 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
46 | # set this path to the directory where the extracted files are, otherwise
47 | # just set it to None and the script will later download the files for you
48 | inception_path = None
49 | print("check for inception model..", end=" ", flush=True)
50 | inception_path = fid.check_or_download_inception(inception_path) # download inception if necessary
51 | print("ok")
52 |
53 | # loads all images into memory (this might require a lot of RAM!)
54 | print("load images..", end=" ", flush=True)
55 | image_list = glob.glob(os.path.join(data_path, '*.jpg'))
56 | images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])
57 | print("%d images found and loaded" % len(images))
58 |
59 | print("create inception graph..", end=" ", flush=True)
60 | fid.create_inception_graph(inception_path) # load the graph into the current TF graph
61 | print("ok")
62 |
63 | print("calculte FID stats..", end=" ", flush=True)
64 | config = tf.ConfigProto()
65 | config.gpu_options.allow_growth = True
66 | with tf.Session(config=config) as sess:
67 | sess.run(tf.global_variables_initializer())
68 | mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
69 | np.savez_compressed(output_path, mu=mu, sigma=sigma)
70 | print("finished")
71 |
72 |
73 | if __name__ == '__main__':
74 | main()
75 |
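# A hedged usage sketch (not part of the file): precomputing FID statistics for a
# directory of .jpg training images; both paths below are placeholders.
#
#   python utils/cal_fid_stat.py \
#       --data_path data/my_train_jpgs \
#       --output_file fid_stat/fid_stats_my_train.npz
#
# The resulting .npz stores 'mu' and 'sigma', which _handle_path in utils/fid_score.py
# reads back directly when computing FID against generated samples.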
--------------------------------------------------------------------------------
/utils/fid_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """ Calculates the Frechet Inception Distance (FID) to evaluate GANs.
3 |
4 | The FID metric calculates the distance between two distributions of images.
5 | Typically, we have summary statistics (mean & covariance matrix) of one
6 | of these distributions, while the 2nd distribution is given by a GAN.
7 |
8 | When run as a stand-alone program, it compares the distribution of
9 | images that are stored as PNG/JPEG at a specified location with a
10 | distribution given by summary statistics (in pickle format).
11 |
12 | The FID is calculated by assuming that X_1 and X_2 are the activations of
13 | the pool_3 layer of the inception net for generated samples and real world
14 | samples respectively.
15 |
16 | See --help to see further details.
17 | """
18 |
19 | from __future__ import absolute_import, division, print_function
20 |
21 | import os
22 | import pathlib
23 | import warnings
24 |
25 | import numpy as np
26 | import tensorflow.compat.v1 as tf
27 | tf.disable_v2_behavior()
28 |
29 | from scipy import linalg
30 | from imageio import imread
31 | from tqdm import tqdm
32 |
33 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
34 |
35 |
36 | class InvalidFIDException(Exception):
37 | pass
38 |
39 |
40 | def create_inception_graph(pth):
41 | """Creates a graph from saved GraphDef file."""
42 | # Creates graph from saved graph_def.pb.
43 | with tf.gfile.FastGFile(pth, 'rb') as f:
44 | graph_def = tf.GraphDef()
45 | graph_def.ParseFromString(f.read())
46 | _ = tf.import_graph_def(graph_def, name='FID_Inception_Net')
47 |
48 |
49 | # -------------------------------------------------------------------------------
50 |
51 |
52 | # code for handling inception net derived from
53 | # https://github.com/openai/improved-gan/blob/master/inception_score/model.py
54 | def _get_inception_layer(sess):
55 | """Prepares inception net for batched usage and returns pool_3 layer. """
56 | layername = 'FID_Inception_Net/pool_3:0'
57 | pool3 = sess.graph.get_tensor_by_name(layername)
58 | ops = pool3.graph.get_operations()
59 | for op_idx, op in enumerate(ops):
60 | for o in op.outputs:
61 | shape = o.get_shape()
62 | if shape._dims != []:
63 | shape = [s.value for s in shape]
64 | new_shape = []
65 | for j, s in enumerate(shape):
66 | if s == 1 and j == 0:
67 | new_shape.append(None)
68 | else:
69 | new_shape.append(s)
70 | o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
71 | return pool3
72 |
73 |
74 | # -------------------------------------------------------------------------------
75 |
76 |
77 | def get_activations(images, sess, batch_size=16, verbose=False):
78 | """Calculates the activations of the pool_3 layer for all images.
79 |
80 | Params:
81 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values
82 | must lie between 0 and 255.
83 | -- sess : current session
84 | -- batch_size : the images numpy array is split into batches with batch size
85 | batch_size. A reasonable batch size depends on the available hardware.
86 | -- verbose : If set to True and parameter out_step is given, the number of calculated
87 | batches is reported.
88 | Returns:
89 | -- A numpy array of dimension (num images, 2048) that contains the
90 | activations of the given tensor when feeding inception with the query tensor.
91 | """
92 | inception_layer = _get_inception_layer(sess)
93 | d0 = len(images)
94 | if batch_size > d0:
95 | print("warning: batch size is bigger than the data size. setting batch size to data size")
96 | batch_size = d0
97 | n_batches = d0 // batch_size
98 | n_used_imgs = n_batches * batch_size
99 | pred_arr = np.empty((n_used_imgs, 2048))
100 | for i in tqdm(range(n_batches)):
101 | if verbose:
102 | print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
103 | start = i * batch_size
104 | end = start + batch_size
105 | batch = images[start:end]
106 | pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
107 | pred_arr[start:end] = pred.reshape(batch_size, -1)
108 | if verbose:
109 | print(" done")
110 | return pred_arr
111 |
112 |
113 | # -------------------------------------------------------------------------------
114 |
115 |
116 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
117 | """Numpy implementation of the Frechet Distance.
118 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
119 | and X_2 ~ N(mu_2, C_2) is
120 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
121 |
122 | Stable version by Dougal J. Sutherland.
123 |
124 | Params:
125 | -- mu1 : Numpy array containing the activations of the pool_3 layer of the
126 | inception net (like the values returned by 'get_activations')
127 | for generated samples.
128 | -- mu2 : The sample mean over activations of the pool_3 layer, precalculated
129 | on a representative data set.
130 | -- sigma1: The covariance matrix over activations of the pool_3 layer for
131 | generated samples.
132 | -- sigma2: The covariance matrix over activations of the pool_3 layer,
133 | precalculated on a representative data set.
134 |
135 | Returns:
136 | -- : The Frechet Distance.
137 | """
138 |
139 | mu1 = np.atleast_1d(mu1)
140 | mu2 = np.atleast_1d(mu2)
141 |
142 | sigma1 = np.atleast_2d(sigma1)
143 | sigma2 = np.atleast_2d(sigma2)
144 |
145 | assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
146 | assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
147 |
148 | diff = mu1 - mu2
149 |
150 | # product might be almost singular
151 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
152 | if not np.isfinite(covmean).all():
153 | msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
154 | warnings.warn(msg)
155 | offset = np.eye(sigma1.shape[0]) * eps
156 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
157 |
158 | # numerical error might give slight imaginary component
159 | if np.iscomplexobj(covmean):
160 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
161 | m = np.max(np.abs(covmean.imag))
162 | raise ValueError("Imaginary component {}".format(m))
163 | covmean = covmean.real
164 |
165 | tr_covmean = np.trace(covmean)
166 |
167 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
168 |
169 |
170 | # -------------------------------------------------------------------------------
171 |
172 |
173 | def calculate_activation_statistics(images, sess, batch_size=16, verbose=False):
174 | """Calculation of the statistics used by the FID.
175 | Params:
176 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values
177 | must lie between 0 and 255.
178 | -- sess : current session
179 | -- batch_size : the images numpy array is split into batches with batch size
180 | batch_size. A reasonable batch size depends on the available hardware.
181 | -- verbose : If set to True and parameter out_step is given, the number of calculated
182 | batches is reported.
183 | Returns:
184 | -- mu : The mean over samples of the activations of the pool_3 layer of
185 | the inception model.
186 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
187 | the inception model.
188 | """
189 | act = get_activations(images, sess, batch_size, verbose)
190 | mu = np.mean(act, axis=0)
191 | sigma = np.cov(act, rowvar=False)
192 | return mu, sigma
193 |
194 |
195 | # ------------------
196 | # The following methods are implemented to obtain a batched version of the activations.
197 | # This has the advantage to reduce memory requirements, at the cost of slightly reduced efficiency.
198 | # - Pyrestone
199 | # ------------------
200 |
201 |
202 | def load_image_batch(files):
203 | """Convenience method for batch-loading images
204 | Params:
205 | -- files : list of paths to image files. Images need to have same dimensions for all files.
206 | Returns:
207 | -- A numpy array of dimensions (num_images,hi, wi, 3) representing the image pixel values.
208 | """
209 | return np.array([imread(str(fn)).astype(np.float32) for fn in files])
210 |
211 |
212 | def get_activations_from_files(files, sess, batch_size=16, verbose=False):
213 | """Calculates the activations of the pool_3 layer for all images.
214 |
215 | Params:
216 | -- files : list of paths to image files. Images need to have same dimensions for all files.
217 | -- sess : current session
218 | -- batch_size : the images numpy array is split into batches with batch size
219 | batch_size. A reasonable batch size depends on the available hardware.
220 | -- verbose : If set to True and parameter out_step is given, the number of calculated
221 | batches is reported.
222 | Returns:
223 | -- A numpy array of dimension (num images, 2048) that contains the
224 | activations of the given tensor when feeding inception with the query tensor.
225 | """
226 | inception_layer = _get_inception_layer(sess)
227 | d0 = len(files)
228 | if batch_size > d0:
229 | print("warning: batch size is bigger than the data size. setting batch size to data size")
230 | batch_size = d0
231 | n_batches = d0 // batch_size
232 | n_used_imgs = n_batches * batch_size
233 | pred_arr = np.empty((n_used_imgs, 2048))
234 | for i in range(n_batches):
235 | if verbose:
236 | print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
237 | start = i * batch_size
238 | end = start + batch_size
239 | batch = load_image_batch(files[start:end])
240 | pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
241 | pred_arr[start:end] = pred.reshape(batch_size, -1)
242 | del batch # clean up memory
243 | if verbose:
244 | print(" done")
245 | return pred_arr
246 |
247 |
248 | def calculate_activation_statistics_from_files(files, sess, batch_size=1, verbose=False):
249 | """Calculation of the statistics used by the FID.
250 | Params:
251 | -- files : list of paths to image files. Images need to have same dimensions for all files.
252 | -- sess : current session
253 | -- batch_size : the images numpy array is split into batches with batch size
254 | batch_size. A reasonable batch size depends on the available hardware.
255 | -- verbose : If set to True and parameter out_step is given, the number of calculated
256 | batches is reported.
257 | Returns:
258 | -- mu : The mean over samples of the activations of the pool_3 layer of
259 | the inception model.
260 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
261 | the inception model.
262 | """
263 | act = get_activations_from_files(files, sess, batch_size, verbose)
264 | mu = np.mean(act, axis=0)
265 | sigma = np.cov(act, rowvar=False)
266 | return mu, sigma
267 |
268 |
269 | # -------------------------------------------------------------------------------
270 |
271 |
272 | # -------------------------------------------------------------------------------
273 | # The following functions aren't needed for calculating the FID
274 | # they're just here to make this module work as a stand-alone script
275 | # for calculating FID scores
276 | # -------------------------------------------------------------------------------
277 | def check_or_download_inception(inception_path):
278 | """ Checks if the path to the inception file is valid, or downloads
279 | the file if it is not present. """
280 | INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
281 | if inception_path is None:
282 | inception_path = '/tmp'
283 | inception_path = pathlib.Path(inception_path)
284 | model_file = inception_path / 'classify_image_graph_def.pb'
285 | if not model_file.exists():
286 | print("Downloading Inception model")
287 | from urllib import request
288 | import tarfile
289 | fn, _ = request.urlretrieve(INCEPTION_URL)
290 | with tarfile.open(fn, mode='r') as f:
291 | f.extract('classify_image_graph_def.pb', str(model_file.parent))
292 | return str(model_file)
293 |
294 |
295 | def _handle_path(path, sess, low_profile=False):
296 | if isinstance(path, str):
297 | f = np.load(path)
298 | m, s = f['mu'][:], f['sigma'][:]
299 | f.close()
300 | else:
301 | # path = pathlib.Path(path)
302 | files = path
303 | if low_profile:
304 | m, s = calculate_activation_statistics_from_files(files, sess)
305 | else:
306 | # x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
307 | x = path
308 | m, s = calculate_activation_statistics(x, sess)
309 | del x # clean up memory
310 | return m, s
311 |
312 |
313 | def calculate_fid_given_paths(paths, inception_path, low_profile=False):
314 | """ Calculates the FID of two paths. """
315 | # inception_path = check_or_download_inception(inception_path)
316 |
317 | # for p in paths:
318 | # if not os.path.exists(p):
319 | # raise RuntimeError("Invalid path: %s" % p)
320 |
321 | config = tf.ConfigProto()
322 | config.gpu_options.allow_growth = True
323 | with tf.Session(config=config) as sess:
324 | sess.run(tf.global_variables_initializer())
325 | m1, s1 = _handle_path(paths[0], sess, low_profile=low_profile)
326 | m2, s2 = _handle_path(paths[1], sess, low_profile=low_profile)
327 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
328 | sess.close()
329 |
330 | return fid_value
331 |
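# A small sanity check for calculate_frechet_distance (a sketch, not part of the file;
# importing this module needs a TensorFlow 1.x-compatible installation because TF is
# imported at module load time). Identical Gaussians give FID 0; shifting the mean by a
# vector d adds ||d||^2.
import numpy as np
from utils.fid_score import calculate_frechet_distance

mu, sigma = np.zeros(4), np.eye(4)
print(calculate_frechet_distance(mu, sigma, mu, sigma))        # ~0.0
print(calculate_frechet_distance(mu, sigma, mu + 3.0, sigma))  # ~36.0 = 4 * 3**2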
--------------------------------------------------------------------------------
/utils/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 |
6 | try:
7 | from torchvision.models.utils import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 | # Inception weights ported to Pytorch from
12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14 |
15 |
16 | class InceptionV3(nn.Module):
17 | """Pretrained InceptionV3 network returning feature maps"""
18 |
19 | # Index of default block of inception to return,
20 | # corresponds to output of final average pooling
21 | DEFAULT_BLOCK_INDEX = 3
22 |
23 | # Maps feature dimensionality to their output blocks indices
24 | BLOCK_INDEX_BY_DIM = {
25 | 64: 0, # First max pooling features
26 | 192: 1, # Second max pooling features
27 | 768: 2, # Pre-aux classifier features
28 | 2048: 3 # Final average pooling features
29 | }
30 |
31 | def __init__(self,
32 | output_blocks=[DEFAULT_BLOCK_INDEX],
33 | resize_input=True,
34 | normalize_input=True,
35 | requires_grad=False,
36 | use_fid_inception=True):
37 | """Build pretrained InceptionV3
38 | Parameters
39 | ----------
40 | output_blocks : list of int
41 | Indices of blocks to return features of. Possible values are:
42 | - 0: corresponds to output of first max pooling
43 | - 1: corresponds to output of second max pooling
44 | - 2: corresponds to output which is fed to aux classifier
45 | - 3: corresponds to output of final average pooling
46 | resize_input : bool
47 | If true, bilinearly resizes input to width and height 299 before
48 | feeding input to model. As the network without fully connected
49 | layers is fully convolutional, it should be able to handle inputs
50 | of arbitrary size, so resizing might not be strictly needed
51 | normalize_input : bool
52 | If true, scales the input from range (0, 1) to the range the
53 | pretrained Inception network expects, namely (-1, 1)
54 | requires_grad : bool
55 | If true, parameters of the model require gradients. Possibly useful
56 | for finetuning the network
57 | use_fid_inception : bool
58 | If true, uses the pretrained Inception model used in Tensorflow's
59 | FID implementation. If false, uses the pretrained Inception model
60 | available in torchvision. The FID Inception model has different
61 | weights and a slightly different structure from torchvision's
62 | Inception model. If you want to compute FID scores, you are
63 | strongly advised to set this parameter to true to get comparable
64 | results.
65 | """
66 | super(InceptionV3, self).__init__()
67 |
68 | self.resize_input = resize_input
69 | self.normalize_input = normalize_input
70 | self.output_blocks = sorted(output_blocks)
71 | self.last_needed_block = max(output_blocks)
72 |
73 | assert self.last_needed_block <= 3, \
74 | 'Last possible output block index is 3'
75 |
76 | self.blocks = nn.ModuleList()
77 |
78 | if use_fid_inception:
79 | inception = fid_inception_v3()
80 | else:
81 | inception = models.inception_v3(pretrained=True)
82 |
83 | # Block 0: input to maxpool1
84 | block0 = [
85 | inception.Conv2d_1a_3x3,
86 | inception.Conv2d_2a_3x3,
87 | inception.Conv2d_2b_3x3,
88 | nn.MaxPool2d(kernel_size=3, stride=2)
89 | ]
90 | self.blocks.append(nn.Sequential(*block0))
91 |
92 | # Block 1: maxpool1 to maxpool2
93 | if self.last_needed_block >= 1:
94 | block1 = [
95 | inception.Conv2d_3b_1x1,
96 | inception.Conv2d_4a_3x3,
97 | nn.MaxPool2d(kernel_size=3, stride=2)
98 | ]
99 | self.blocks.append(nn.Sequential(*block1))
100 |
101 | # Block 2: maxpool2 to aux classifier
102 | if self.last_needed_block >= 2:
103 | block2 = [
104 | inception.Mixed_5b,
105 | inception.Mixed_5c,
106 | inception.Mixed_5d,
107 | inception.Mixed_6a,
108 | inception.Mixed_6b,
109 | inception.Mixed_6c,
110 | inception.Mixed_6d,
111 | inception.Mixed_6e,
112 | ]
113 | self.blocks.append(nn.Sequential(*block2))
114 |
115 | # Block 3: aux classifier to final avgpool
116 | if self.last_needed_block >= 3:
117 | block3 = [
118 | inception.Mixed_7a,
119 | inception.Mixed_7b,
120 | inception.Mixed_7c,
121 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
122 | ]
123 | self.blocks.append(nn.Sequential(*block3))
124 |
125 | for param in self.parameters():
126 | param.requires_grad = requires_grad
127 |
128 | def forward(self, inp):
129 | """Get Inception feature maps
130 | Parameters
131 | ----------
132 | inp : torch.autograd.Variable
133 | Input tensor of shape Bx3xHxW. Values are expected to be in
134 | range (0, 1)
135 | Returns
136 | -------
137 | List of torch.autograd.Variable, corresponding to the selected output
138 | block, sorted ascending by index
139 | """
140 | outp = []
141 | x = inp
142 |
143 | if self.resize_input:
144 | x = F.interpolate(x,
145 | size=(299, 299),
146 | mode='bilinear',
147 | align_corners=False)
148 |
149 | if self.normalize_input:
150 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
151 |
152 | for idx, block in enumerate(self.blocks):
153 | x = block(x)
154 | if idx in self.output_blocks:
155 | outp.append(x)
156 |
157 | if idx == self.last_needed_block:
158 | break
159 |
160 | return outp
161 |
162 |
163 | def fid_inception_v3():
164 | """Build pretrained Inception model for FID computation
165 | The Inception model for FID computation uses a different set of weights
166 | and has a slightly different structure than torchvision's Inception.
167 | This method first constructs torchvision's Inception and then patches the
168 | necessary parts that are different in the FID Inception model.
169 | """
170 | inception = models.inception_v3(num_classes=1008,
171 | aux_logits=False,
172 | pretrained=False,
173 | init_weights=False)
174 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
175 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
176 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
177 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
178 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
179 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
180 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
181 | inception.Mixed_7b = FIDInceptionE_1(1280)
182 | inception.Mixed_7c = FIDInceptionE_2(2048)
183 |
184 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
185 | inception.load_state_dict(state_dict)
186 | return inception
187 |
188 |
189 | class FIDInceptionA(models.inception.InceptionA):
190 | """InceptionA block patched for FID computation"""
191 |
192 | def __init__(self, in_channels, pool_features):
193 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
194 |
195 | def forward(self, x):
196 | branch1x1 = self.branch1x1(x)
197 |
198 | branch5x5 = self.branch5x5_1(x)
199 | branch5x5 = self.branch5x5_2(branch5x5)
200 |
201 | branch3x3dbl = self.branch3x3dbl_1(x)
202 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
203 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
204 |
205 | # Patch: Tensorflow's average pool does not use the padded zeros in
206 | # its average calculation
207 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
208 | count_include_pad=False)
209 | branch_pool = self.branch_pool(branch_pool)
210 |
211 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
212 | return torch.cat(outputs, 1)
213 |
214 |
215 | class FIDInceptionC(models.inception.InceptionC):
216 | """InceptionC block patched for FID computation"""
217 |
218 | def __init__(self, in_channels, channels_7x7):
219 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
220 |
221 | def forward(self, x):
222 | branch1x1 = self.branch1x1(x)
223 |
224 | branch7x7 = self.branch7x7_1(x)
225 | branch7x7 = self.branch7x7_2(branch7x7)
226 | branch7x7 = self.branch7x7_3(branch7x7)
227 |
228 | branch7x7dbl = self.branch7x7dbl_1(x)
229 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
230 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
231 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
232 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
233 |
234 | # Patch: Tensorflow's average pool does not use the padded zeros in
235 | # its average calculation
236 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
237 | count_include_pad=False)
238 | branch_pool = self.branch_pool(branch_pool)
239 |
240 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
241 | return torch.cat(outputs, 1)
242 |
243 |
244 | class FIDInceptionE_1(models.inception.InceptionE):
245 | """First InceptionE block patched for FID computation"""
246 |
247 | def __init__(self, in_channels):
248 | super(FIDInceptionE_1, self).__init__(in_channels)
249 |
250 | def forward(self, x):
251 | branch1x1 = self.branch1x1(x)
252 |
253 | branch3x3 = self.branch3x3_1(x)
254 | branch3x3 = [
255 | self.branch3x3_2a(branch3x3),
256 | self.branch3x3_2b(branch3x3),
257 | ]
258 | branch3x3 = torch.cat(branch3x3, 1)
259 |
260 | branch3x3dbl = self.branch3x3dbl_1(x)
261 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
262 | branch3x3dbl = [
263 | self.branch3x3dbl_3a(branch3x3dbl),
264 | self.branch3x3dbl_3b(branch3x3dbl),
265 | ]
266 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
267 |
268 | # Patch: Tensorflow's average pool does not use the padded zeros in
269 | # its average calculation
270 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
271 | count_include_pad=False)
272 | branch_pool = self.branch_pool(branch_pool)
273 |
274 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
275 | return torch.cat(outputs, 1)
276 |
277 |
278 | class FIDInceptionE_2(models.inception.InceptionE):
279 | """Second InceptionE block patched for FID computation"""
280 |
281 | def __init__(self, in_channels):
282 | super(FIDInceptionE_2, self).__init__(in_channels)
283 |
284 | def forward(self, x):
285 | branch1x1 = self.branch1x1(x)
286 |
287 | branch3x3 = self.branch3x3_1(x)
288 | branch3x3 = [
289 | self.branch3x3_2a(branch3x3),
290 | self.branch3x3_2b(branch3x3),
291 | ]
292 | branch3x3 = torch.cat(branch3x3, 1)
293 |
294 | branch3x3dbl = self.branch3x3dbl_1(x)
295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
296 | branch3x3dbl = [
297 | self.branch3x3dbl_3a(branch3x3dbl),
298 | self.branch3x3dbl_3b(branch3x3dbl),
299 | ]
300 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
301 |
302 | # Patch: The FID Inception model uses max pooling instead of average
303 | # pooling. This is likely an error in this specific Inception
304 | # implementation, as other Inception models use average pooling here
305 | # (which matches the description in the paper).
306 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
307 | branch_pool = self.branch_pool(branch_pool)
308 |
309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
310 | return torch.cat(outputs, 1)
311 |
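# A minimal feature-extraction sketch (an addition, not part of the file). It assumes the
# FID Inception weights can be downloaded from FID_WEIGHTS_URL; the batch is random data.
import torch
from utils.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # final average-pooling features
model = InceptionV3([block_idx]).eval()
imgs = torch.rand(8, 3, 64, 64)                    # values in (0, 1); resized to 299x299 internally
with torch.no_grad():
    feats = model(imgs)[0]                         # shape (8, 2048, 1, 1)
feats = feats.view(feats.size(0), -1)              # flatten to (8, 2048) for FID statistics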
--------------------------------------------------------------------------------
/utils/inception_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 |
6 | try:
7 | from torchvision.models.utils import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 | # Inception weights ported to Pytorch from
12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14 |
15 |
16 | class InceptionV3(nn.Module):
17 | """Pretrained InceptionV3 network returning feature maps"""
18 |
19 | # Index of default block of inception to return,
20 | # corresponds to output of final average pooling
21 | DEFAULT_BLOCK_INDEX = 3
22 |
23 | # Maps feature dimensionality to their output blocks indices
24 | BLOCK_INDEX_BY_DIM = {
25 | 64: 0, # First max pooling features
26 | 192: 1, # Second max pooling features
27 | 768: 2, # Pre-aux classifier features
28 | 2048: 3 # Final average pooling features
29 | }
30 |
31 | def __init__(self,
32 | output_blocks=[DEFAULT_BLOCK_INDEX],
33 | resize_input=True,
34 | normalize_input=True,
35 | requires_grad=False,
36 | use_fid_inception=True):
37 | """Build pretrained InceptionV3
38 | Parameters
39 | ----------
40 | output_blocks : list of int
41 | Indices of blocks to return features of. Possible values are:
42 | - 0: corresponds to output of first max pooling
43 | - 1: corresponds to output of second max pooling
44 | - 2: corresponds to output which is fed to aux classifier
45 | - 3: corresponds to output of final average pooling
46 | resize_input : bool
47 | If true, bilinearly resizes input to width and height 299 before
48 | feeding input to model. As the network without fully connected
49 | layers is fully convolutional, it should be able to handle inputs
50 | of arbitrary size, so resizing might not be strictly needed
51 | normalize_input : bool
52 | If true, scales the input from range (0, 1) to the range the
53 | pretrained Inception network expects, namely (-1, 1)
54 | requires_grad : bool
55 | If true, parameters of the model require gradients. Possibly useful
56 | for finetuning the network
57 | use_fid_inception : bool
58 | If true, uses the pretrained Inception model used in Tensorflow's
59 | FID implementation. If false, uses the pretrained Inception model
60 | available in torchvision. The FID Inception model has different
61 | weights and a slightly different structure from torchvision's
62 | Inception model. If you want to compute FID scores, you are
63 | strongly advised to set this parameter to true to get comparable
64 | results.
65 | """
66 | super(InceptionV3, self).__init__()
67 |
68 | self.resize_input = resize_input
69 | self.normalize_input = normalize_input
70 | self.output_blocks = sorted(output_blocks)
71 | self.last_needed_block = max(output_blocks)
72 |
73 | assert self.last_needed_block <= 3, \
74 | 'Last possible output block index is 3'
75 |
76 | self.blocks = nn.ModuleList()
77 |
78 | if use_fid_inception:
79 | inception = fid_inception_v3()
80 | else:
81 | inception = models.inception_v3(pretrained=True)
82 |
83 | # Block 0: input to maxpool1
84 | block0 = [
85 | inception.Conv2d_1a_3x3,
86 | inception.Conv2d_2a_3x3,
87 | inception.Conv2d_2b_3x3,
88 | nn.MaxPool2d(kernel_size=3, stride=2)
89 | ]
90 | self.blocks.append(nn.Sequential(*block0))
91 |
92 | # Block 1: maxpool1 to maxpool2
93 | if self.last_needed_block >= 1:
94 | block1 = [
95 | inception.Conv2d_3b_1x1,
96 | inception.Conv2d_4a_3x3,
97 | nn.MaxPool2d(kernel_size=3, stride=2)
98 | ]
99 | self.blocks.append(nn.Sequential(*block1))
100 |
101 | # Block 2: maxpool2 to aux classifier
102 | if self.last_needed_block >= 2:
103 | block2 = [
104 | inception.Mixed_5b,
105 | inception.Mixed_5c,
106 | inception.Mixed_5d,
107 | inception.Mixed_6a,
108 | inception.Mixed_6b,
109 | inception.Mixed_6c,
110 | inception.Mixed_6d,
111 | inception.Mixed_6e,
112 | ]
113 | self.blocks.append(nn.Sequential(*block2))
114 |
115 | # Block 3: aux classifier to final avgpool
116 | if self.last_needed_block >= 3:
117 | block3 = [
118 | inception.Mixed_7a,
119 | inception.Mixed_7b,
120 | inception.Mixed_7c,
121 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
122 | ]
123 | self.blocks.append(nn.Sequential(*block3))
124 |
125 | for param in self.parameters():
126 | param.requires_grad = requires_grad
127 |
128 | def forward(self, inp):
129 | """Get Inception feature maps
130 | Parameters
131 | ----------
132 | inp : torch.autograd.Variable
133 | Input tensor of shape Bx3xHxW. Values are expected to be in
134 | range (0, 1)
135 | Returns
136 | -------
137 | List of torch.autograd.Variable, corresponding to the selected output
138 | block, sorted ascending by index
139 | """
140 | outp = []
141 | x = inp
142 |
143 | if self.resize_input:
144 | x = F.interpolate(x,
145 | size=(299, 299),
146 | mode='bilinear',
147 | align_corners=False)
148 |
149 | if self.normalize_input:
150 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
151 |
152 | for idx, block in enumerate(self.blocks):
153 | x = block(x)
154 | if idx in self.output_blocks:
155 | outp.append(x)
156 |
157 | if idx == self.last_needed_block:
158 | break
159 |
160 | return outp
161 |
162 |
163 | def fid_inception_v3():
164 | """Build pretrained Inception model for FID computation
165 | The Inception model for FID computation uses a different set of weights
166 | and has a slightly different structure than torchvision's Inception.
167 | This method first constructs torchvision's Inception and then patches the
168 | necessary parts that are different in the FID Inception model.
169 | """
170 | inception = models.inception_v3(num_classes=1008,
171 | aux_logits=False,
172 | pretrained=False)
173 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
174 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
175 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
176 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
177 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
178 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
179 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
180 | inception.Mixed_7b = FIDInceptionE_1(1280)
181 | inception.Mixed_7c = FIDInceptionE_2(2048)
182 |
183 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
184 | inception.load_state_dict(state_dict)
185 | return inception
186 |
187 |
188 | class FIDInceptionA(models.inception.InceptionA):
189 | """InceptionA block patched for FID computation"""
190 |
191 | def __init__(self, in_channels, pool_features):
192 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
193 |
194 | def forward(self, x):
195 | branch1x1 = self.branch1x1(x)
196 |
197 | branch5x5 = self.branch5x5_1(x)
198 | branch5x5 = self.branch5x5_2(branch5x5)
199 |
200 | branch3x3dbl = self.branch3x3dbl_1(x)
201 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
202 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
203 |
204 | # Patch: Tensorflow's average pool does not use the padded zeros in
205 | # its average calculation
206 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
207 | count_include_pad=False)
208 | branch_pool = self.branch_pool(branch_pool)
209 |
210 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
211 | return torch.cat(outputs, 1)
212 |
213 |
214 | class FIDInceptionC(models.inception.InceptionC):
215 | """InceptionC block patched for FID computation"""
216 |
217 | def __init__(self, in_channels, channels_7x7):
218 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
219 |
220 | def forward(self, x):
221 | branch1x1 = self.branch1x1(x)
222 |
223 | branch7x7 = self.branch7x7_1(x)
224 | branch7x7 = self.branch7x7_2(branch7x7)
225 | branch7x7 = self.branch7x7_3(branch7x7)
226 |
227 | branch7x7dbl = self.branch7x7dbl_1(x)
228 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
229 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
230 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
231 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
232 |
233 | # Patch: Tensorflow's average pool does not use the padded zeros in
234 | # its average calculation
235 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
236 | count_include_pad=False)
237 | branch_pool = self.branch_pool(branch_pool)
238 |
239 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
240 | return torch.cat(outputs, 1)
241 |
242 |
243 | class FIDInceptionE_1(models.inception.InceptionE):
244 | """First InceptionE block patched for FID computation"""
245 |
246 | def __init__(self, in_channels):
247 | super(FIDInceptionE_1, self).__init__(in_channels)
248 |
249 | def forward(self, x):
250 | branch1x1 = self.branch1x1(x)
251 |
252 | branch3x3 = self.branch3x3_1(x)
253 | branch3x3 = [
254 | self.branch3x3_2a(branch3x3),
255 | self.branch3x3_2b(branch3x3),
256 | ]
257 | branch3x3 = torch.cat(branch3x3, 1)
258 |
259 | branch3x3dbl = self.branch3x3dbl_1(x)
260 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
261 | branch3x3dbl = [
262 | self.branch3x3dbl_3a(branch3x3dbl),
263 | self.branch3x3dbl_3b(branch3x3dbl),
264 | ]
265 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
266 |
267 | # Patch: Tensorflow's average pool does not use the padded zeros in
268 | # its average calculation
269 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
270 | count_include_pad=False)
271 | branch_pool = self.branch_pool(branch_pool)
272 |
273 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
274 | return torch.cat(outputs, 1)
275 |
276 |
277 | class FIDInceptionE_2(models.inception.InceptionE):
278 | """Second InceptionE block patched for FID computation"""
279 |
280 | def __init__(self, in_channels):
281 | super(FIDInceptionE_2, self).__init__(in_channels)
282 |
283 | def forward(self, x):
284 | branch1x1 = self.branch1x1(x)
285 |
286 | branch3x3 = self.branch3x3_1(x)
287 | branch3x3 = [
288 | self.branch3x3_2a(branch3x3),
289 | self.branch3x3_2b(branch3x3),
290 | ]
291 | branch3x3 = torch.cat(branch3x3, 1)
292 |
293 | branch3x3dbl = self.branch3x3dbl_1(x)
294 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
295 | branch3x3dbl = [
296 | self.branch3x3dbl_3a(branch3x3dbl),
297 | self.branch3x3dbl_3b(branch3x3dbl),
298 | ]
299 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
300 |
301 | # Patch: The FID Inception model uses max pooling instead of average
302 | # pooling. This is likely an error in this specific Inception
303 | # implementation, as other Inception models use average pooling here
304 | # (which matches the description in the paper).
305 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
306 | branch_pool = self.branch_pool(branch_pool)
307 |
308 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
309 | return torch.cat(outputs, 1)
--------------------------------------------------------------------------------
/utils/inception_score.py:
--------------------------------------------------------------------------------
1 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import math
7 | import os
8 | import os.path
9 | import sys
10 | import tarfile
11 |
12 | import numpy as np
13 | import tensorflow.compat.v1 as tf
14 | tf.disable_v2_behavior()
15 | from six.moves import urllib
16 | from tqdm import tqdm
17 |
18 |
19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
20 | MODEL_DIR = '/tmp/imagenet'
21 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
22 | softmax = None
23 | config = tf.ConfigProto()
24 | # config = tf.ConfigProto(device_count = {'GPU': 0})
25 | config.gpu_options.visible_device_list= '0'
26 | config.gpu_options.allow_growth = True
27 |
28 |
29 | # Call this function with a list of images. Each element should be a
30 | # numpy array with values ranging from 0 to 255.
31 | def get_inception_score(images, splits=10):
32 | assert (type(images) == list)
33 | assert (type(images[0]) == np.ndarray)
34 | assert (len(images[0].shape) == 3)
35 | assert (np.max(images[0]) > 10)
36 | assert (np.min(images[0]) >= 0.0)
37 | inps = []
38 | for img in images:
39 | img = img.astype(np.float32)
40 | inps.append(np.expand_dims(img, 0))
41 | bs = 128
42 | with tf.Session(config=config) as sess:
43 | preds = []
44 | n_batches = int(math.ceil(float(len(inps)) / float(bs)))
45 | for i in tqdm(range(n_batches), desc="Calculate inception score"):
46 | sys.stdout.flush()
47 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
48 | inp = np.concatenate(inp, 0)
49 | pred = sess.run(softmax, {'ExpandDims:0': inp})
50 | preds.append(pred)
51 | preds = np.concatenate(preds, 0)
52 | scores = []
53 | for i in range(splits):
54 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
55 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
56 | kl = np.mean(np.sum(kl, 1))
57 | scores.append(np.exp(kl))
58 |
59 | sess.close()
60 | return np.mean(scores), np.std(scores)
61 |
62 |
63 | # Build the Inception softmax graph; call this once before get_inception_score().
64 | def _init_inception():
65 | global softmax
66 | if not os.path.exists(MODEL_DIR):
67 | os.makedirs(MODEL_DIR)
68 | filename = DATA_URL.split('/')[-1]
69 | filepath = os.path.join(MODEL_DIR, filename)
70 | if not os.path.exists(filepath):
71 | def _progress(count, block_size, total_size):
72 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (
73 | filename, float(count * block_size) / float(total_size) * 100.0))
74 | sys.stdout.flush()
75 |
76 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
77 | print()
78 | statinfo = os.stat(filepath)
79 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
80 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
81 | with tf.gfile.FastGFile(os.path.join(
82 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
83 | graph_def = tf.GraphDef()
84 | graph_def.ParseFromString(f.read())
85 | _ = tf.import_graph_def(graph_def, name='')
86 | # Works with an arbitrary minibatch size.
87 | with tf.Session(config=config) as sess:
88 | pool3 = sess.graph.get_tensor_by_name('pool_3:0')
89 | ops = pool3.graph.get_operations()
90 | for op_idx, op in enumerate(ops):
91 | for o in op.outputs:
92 | shape = o.get_shape()
93 | if shape._dims != []:
94 | shape = [s.value for s in shape]
95 | new_shape = []
96 | for j, s in enumerate(shape):
97 | if s == 1 and j == 0:
98 | new_shape.append(None)
99 | else:
100 | new_shape.append(s)
101 | o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
102 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
103 | logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
104 | softmax = tf.nn.softmax(logits)
105 | sess.close()
106 |
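# A hedged usage sketch (not part of the file; requires the TF1-compat session configured
# above, which pins GPU '0'). The random images stand in for real generated samples.
import numpy as np
import utils.inception_score as inception_score

inception_score._init_inception()   # builds the softmax graph; downloads the model on first use
fake_images = [np.random.randint(0, 256, (64, 64, 3)).astype(np.float32) for _ in range(100)]
mean_is, std_is = inception_score.get_inception_score(fake_images, splits=10)
print(mean_is, std_is)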
--------------------------------------------------------------------------------
/utils/torch_fid_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 | When run as a stand-alone program, it compares the distribution of
7 | images that are stored as PNG/JPEG at a specified location with a
8 | distribution given by summary statistics (in pickle format).
9 | The FID is calculated by assuming that X_1 and X_2 are the activations of
10 | the pool_3 layer of the inception net for generated samples and real world
11 | samples respectively.
12 | See --help to see further details.
13 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
14 | of Tensorflow
15 | Copyright 2018 Institute of Bioinformatics, JKU Linz
16 | Licensed under the Apache License, Version 2.0 (the "License");
17 | you may not use this file except in compliance with the License.
18 | You may obtain a copy of the License at
19 | http://www.apache.org/licenses/LICENSE-2.0
20 | Unless required by applicable law or agreed to in writing, software
21 | distributed under the License is distributed on an "AS IS" BASIS,
22 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | See the License for the specific language governing permissions and
24 | limitations under the License.
25 | """
26 | import contextlib, os  # contextlib is needed by _get_no_grad_ctx_mgr below
27 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
28 |
29 | import numpy as np
30 | import torch
31 | from utils.inception import InceptionV3
32 | from torch.nn.functional import adaptive_avg_pool2d
33 |
34 | try:
35 | from tqdm import tqdm
36 | except ImportError:
37 | # If tqdm is not available, provide a mock version of it
38 | def tqdm(x):
39 | return x
40 |
41 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
42 | parser.add_argument('path', type=str, nargs=2,
43 | help=('Path to the generated images or '
44 | 'to .npz statistic files'))
45 | parser.add_argument('--batch-size', type=int, default=50,
46 | help='Batch size to use')
47 | parser.add_argument('--dims', type=int, default=2048,
48 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
49 | help=('Dimensionality of Inception features to use. '
50 | 'By default, uses pool3 features'))
51 | parser.add_argument('-c', '--gpu', default='1', type=str,
52 | help='GPU to use (leave blank for CPU only)')
53 |
54 | def _get_no_grad_ctx_mgr(require_grad):
55 | """Returns a the `torch.no_grad` context manager for PyTorch version >=
56 | 0.4, or a no-op context manager otherwise.
57 | """
58 | if not require_grad and float(torch.__version__[0:3]) >= 0.4:
59 | return torch.no_grad()
60 |
61 | return contextlib.suppress()
62 |
63 | # Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji
64 | # https://github.com/msubhransu/matrix-sqrt
65 | def sqrt_newton_schulz(A, numIters, dtype=None):
66 | if dtype is None:
67 | dtype = A.type()
68 | batchSize = A.shape[0]
69 | dim = A.shape[1]
70 | normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
71 | Y = A.div(normA.view(batchSize, 1, 1).expand_as(A)).to("cuda:0")
72 | I = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype).to("cuda:0")
73 | Z = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype).to("cuda:0")
74 | for i in range(numIters):
75 | T = 0.5 * (3.0 * I - Z.bmm(Y))
76 | Y = Y.bmm(T)
77 | Z = T.bmm(Z)
78 | sA = Y * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
79 | return sA
80 |
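# A quick correctness sketch for the iteration above (assumes a CUDA device, since the
# intermediates are moved to "cuda:0"): the square root of 4*I is 2*I, so squaring the
# returned matrix should recover the input.
#
#   A = (4.0 * torch.eye(8)).unsqueeze(0).cuda()    # batch of one SPD matrix
#   S = sqrt_newton_schulz(A, numIters=50)
#   torch.allclose(S.bmm(S), A, atol=1e-4)          # expected: True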
81 |
82 | # A pytorch implementation of cov, from Modar M. Alfadly
83 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
84 | def torch_cov(m, rowvar=False):
85 | '''Estimate a covariance matrix given data.
86 | Covariance indicates the level to which two variables vary together.
87 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
88 | then the covariance matrix element `C_{ij}` is the covariance of
89 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
90 | Args:
91 | m: A 1-D or 2-D array containing multiple variables and observations.
92 | Each row of `m` represents a variable, and each column a single
93 | observation of all those variables.
94 | rowvar: If `rowvar` is True, then each row represents a
95 | variable, with observations in the columns. Otherwise, the
96 | relationship is transposed: each column represents a variable,
97 | while the rows contain observations.
98 | Returns:
99 | The covariance matrix of the variables.
100 | '''
101 | if m.dim() > 2:
102 | raise ValueError('m has more than 2 dimensions')
103 | if m.dim() < 2:
104 | m = m.view(1, -1)
105 | if not rowvar and m.size(0) != 1:
106 | m = m.t()
107 | # m = m.type(torch.double) # uncomment this line if desired
108 | fact = 1.0 / (m.size(1) - 1)
109 | m -= torch.mean(m, dim=1, keepdim=True)
110 | mt = m.t() # if complex: mt = m.t().conj()
111 | return fact * m.matmul(mt).squeeze()
112 |
113 |
114 | def get_activations(args, gen_net, model, batch_size=50, dims=2048,
115 | cuda=False, verbose=False):
116 | """Calculates the activations of the pool_3 layer for all images.
117 | Params:
118 | -- args, gen_net : runtime config (num_eval_imgs, latent_dim) and the generator used to sample images
119 | -- model : Instance of inception model
120 | -- batch_size : Batch size of images for the model to process at once.
121 | Make sure that the number of samples is a multiple of
122 | the batch size, otherwise some samples are ignored. This
123 | behavior is retained to match the original FID score
124 | implementation.
125 | -- dims : Dimensionality of features returned by Inception
126 | -- cuda : If set to True, use GPU
127 | -- verbose : If set to True and parameter out_step is given, the number
128 | of calculated batches is reported.
129 | Returns:
130 | -- A numpy array of dimension (num images, dims) that contains the
131 | activations of the given tensor when feeding inception with the
132 | query tensor.
133 | """
134 | with torch.no_grad():
135 | gen_net.eval()
136 | model.eval()
137 |
138 | # if gen_imgs.shape[0] % batch_size != 0:
139 | # print(('Warning: number of images is not a multiple of the '
140 | # 'batch size. Some samples are going to be ignored.'))
141 | # if batch_size > gen_imgs.shape[0]:
142 | # print(('Warning: batch size is bigger than the data size. '
143 | # 'Setting batch size to data size'))
144 | # batch_size = gen_imgs.shape[0]
145 |
146 | n_batches = args.num_eval_imgs // batch_size
147 |
148 | # normalize
149 |
150 | pred_arr = []
151 | for i in tqdm(range(n_batches)):
152 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, args.latent_dim)))
153 | gen_imgs = gen_net(z, 200)
154 |
155 | if verbose:
156 | print('\rPropagating batch %d/%d' % (i + 1, n_batches),
157 | end='', flush=True)
158 | start = i * batch_size
159 | end = start + batch_size
160 |
161 | images = (gen_imgs + 1.0) / 2.0
162 | model.to("cuda:0")
163 | pred = model(images.to("cuda:0"))[0]
164 |
165 | # If model output is not scalar, apply global spatial average pooling.
166 | # This happens if you choose a dimensionality not equal 2048.
167 | if pred.shape[2] != 1 or pred.shape[3] != 1:
168 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
169 |
170 | pred_arr += [pred.view(batch_size, -1)]
171 |
172 | if verbose:
173 | print('done')
174 | del images
175 |
176 | return torch.cat(pred_arr, dim=0)
177 |
178 |
179 | def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
180 | """Pytorch implementation of the Frechet Distance.
181 | Taken from https://github.com/bioinf-jku/TTUR
182 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
183 | and X_2 ~ N(mu_2, C_2) is
184 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
185 | Stable version by Dougal J. Sutherland.
186 | Params:
187 | -- mu1 : Numpy array containing the activations of a layer of the
188 | inception net (like returned by the function 'get_predictions')
189 | for generated samples.
190 | -- mu2 : The sample mean over activations, precalculated on a
191 | representative data set.
192 | -- sigma1: The covariance matrix over activations for generated samples.
193 | -- sigma2: The covariance matrix over activations, precalculated on a
194 | representative data set.
195 | Returns:
196 | -- : The Frechet Distance.
197 | """
198 |
199 | assert mu1.shape == mu2.shape, \
200 | 'Training and test mean vectors have different lengths'
201 | assert sigma1.shape == sigma2.shape, \
202 | 'Training and test covariances have different dimensions'
203 |
204 | diff = mu1 - mu2
205 | # Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2
206 | covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze()
207 | out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2)
208 | - 2 * torch.trace(covmean))
209 | return out
210 |
211 |
212 | def calculate_activation_statistics(args, gen_net, model, batch_size=50,
213 | dims=2048, cuda=False, verbose=False):
214 | """Calculation of the statistics used by the FID.
215 | Params:
216 | -- args, gen_net : runtime config and the generator used to sample images
217 | -- model : Instance of inception model
218 | -- batch_size : The images numpy array is split into batches with
219 | batch size batch_size. A reasonable batch size
220 | depends on the hardware.
221 | -- dims : Dimensionality of features returned by Inception
222 | -- cuda : If set to True, use GPU
223 | -- verbose : If set to True and parameter out_step is given, the
224 | number of calculated batches is reported.
225 | Returns:
226 | -- mu : The mean over samples of the activations of the pool_3 layer of
227 | the inception model.
228 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
229 | the inception model.
230 | """
231 | act = get_activations(args, gen_net, model, batch_size, dims, cuda, verbose)
232 | mu = torch.mean(act, dim=0)
233 | sigma = torch_cov(act, rowvar=False)
234 | return mu, sigma
235 |
236 |
237 | def _compute_statistics_of_path(args, path, model, batch_size, dims, cuda):
238 | if isinstance(path, str):
239 | assert path.endswith('.npz')
240 | f = np.load(path)
241 | if 'mean' in f:
242 | m, s = f['mean'][:], f['cov'][:]
243 | else:
244 | m, s = f['mu'][:], f['sigma'][:]
245 | f.close()
246 | else:
247 | # a tensor
248 | gen_net = path
249 | m, s = calculate_activation_statistics(args, gen_net, model, batch_size,
250 | dims, cuda)
251 |
252 | return m, s
253 |
254 |
255 | def calculate_fid_given_paths_torch(args, gen_net, path, require_grad=False, gen_batch_size=1, batch_size=1, cuda=True, dims=2048):
256 | """
257 | Calculates the FID of two paths
258 | :param gen_net: generator whose samples lie in (-1, 1) (i.e. the output of tanh).
259 | :param path: fid file path. *.npz.
260 | :param batch_size:
261 | :param cuda:
262 | :param dims:
263 | :return:
264 | """
265 | if not os.path.exists(path):
266 | raise RuntimeError('Invalid path: %s' % path)
267 |
268 | assert args.num_eval_imgs >= dims, f'gen_imgs size: {args.num_eval_imgs}'  # fewer samples than dims makes the covariance estimate degenerate and the FID NaN
269 |
270 | with _get_no_grad_ctx_mgr(require_grad=require_grad):
271 |
272 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
273 |
274 | model = InceptionV3([block_idx])
275 | if cuda:
276 | model.cuda()
277 |
278 | m1, s1 = _compute_statistics_of_path(args, gen_net, model, batch_size,
279 | dims, cuda)
280 | # print(f'generated stat: {m1}, {s1}')
281 | m2, s2 = _compute_statistics_of_path(args, path, model, batch_size,
282 | dims, cuda)
283 | # print(f'GT stat: {m2}, {s2}')
284 | fid_value = torch_calculate_frechet_distance(m1.to("cuda:0"), s1.to("cuda:0"), torch.tensor(m2).float().cuda().to("cuda:0"),
285 | torch.tensor(s2).float().cuda().to("cuda:0"))
286 | del model
287 |
288 | return fid_value
289 |
290 |
291 | def get_fid(args, fid_stat, epoch, gen_net, num_img, gen_batch_size, val_batch_size, writer_dict=None, cls_idx=None):
292 | gen_net.eval()
293 | with torch.no_grad():
294 | # eval mode
295 | gen_net.eval()
296 |
297 | # eval_iter = num_img // gen_batch_size
298 | # img_list = []
299 | # for _ in tqdm(range(eval_iter), desc='sample images'):
300 | # z = torch.cuda.FloatTensor(np.random.normal(0, 1, (gen_batch_size, args.latent_dim)))
301 |
302 | # # Generate a batch of images
303 | # if args.n_classes > 0:
304 | # if cls_idx is not None:
305 | # label = torch.ones(z.shape[0]) * cls_idx
306 | # label = label.type(torch.cuda.LongTensor)
307 | # else:
308 | # label = torch.randint(low=0, high=args.n_classes, size=(z.shape[0],), device='cuda')
309 | # gen_imgs = gen_net(z, epoch)
310 | # else:
311 | # gen_imgs = gen_net(z, epoch)
312 | # if isinstance(gen_imgs, tuple):
313 | # gen_imgs = gen_imgs[0]
314 | # img_list += [gen_imgs]
315 |
316 | # img_list = torch.cat(img_list, 0)
317 | fid_score = calculate_fid_given_paths_torch(args, gen_net, fid_stat, gen_batch_size=gen_batch_size, batch_size=val_batch_size)
318 |
319 | if writer_dict:
320 | writer = writer_dict['writer']
321 | global_steps = writer_dict['valid_global_steps']
322 | writer.add_scalar('FID_score', fid_score, global_steps)
323 | writer_dict['valid_global_steps'] = global_steps + 1
324 |
325 | return fid_score
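# A small sanity check for torch_cov above (a sketch, not part of the file). torch_cov
# centers its input in place, so a clone is passed to keep x intact.
import numpy as np
import torch
from utils.torch_fid_score import torch_cov

x = torch.randn(500, 5)
cov_torch = torch_cov(x.clone(), rowvar=False)
cov_numpy = np.cov(x.numpy(), rowvar=False)
print(np.allclose(cov_torch.numpy(), cov_numpy, atol=1e-5))    # expected: True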
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | import collections
8 | import logging
9 | import math
10 | import os
11 | import time
12 | from datetime import datetime
13 |
14 | import dateutil.tz
15 | import torch
16 |
17 | from typing import Union, Optional, List, Tuple, Text, BinaryIO
18 | import pathlib
19 | import torch
20 | import math
21 | import warnings
22 | import numpy as np
23 | from PIL import Image, ImageDraw, ImageFont, ImageColor
24 |
25 | @torch.no_grad()
26 | def make_grid(
27 | tensor: Union[torch.Tensor, List[torch.Tensor]],
28 | nrow: int = 8,
29 | padding: int = 2,
30 | normalize: bool = False,
31 | value_range: Optional[Tuple[int, int]] = None,
32 | scale_each: bool = False,
33 | pad_value: int = 0,
34 | **kwargs
35 | ) -> torch.Tensor:
36 | """
37 | Make a grid of images.
38 | Args:
39 | tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
40 | or a list of images all of the same size.
41 | nrow (int, optional): Number of images displayed in each row of the grid.
42 | The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
43 | padding (int, optional): amount of padding. Default: ``2``.
44 | normalize (bool, optional): If True, shift the image to the range (0, 1),
45 | by the min and max values specified by :attr:`value_range`. Default: ``False``.
46 | value_range (tuple, optional): tuple (min, max) where min and max are numbers,
47 | then these numbers are used to normalize the image. By default, min and max
48 | are computed from the tensor.
49 | scale_each (bool, optional): If ``True``, scale each image in the batch of
50 | images separately rather than the (min, max) over all images. Default: ``False``.
51 | pad_value (float, optional): Value for the padded pixels. Default: ``0``.
52 | Returns:
53 | grid (Tensor): the tensor containing grid of images.
54 |     Example:
55 |         See the usage sketch at the end of this module.
56 | 
57 |     """
58 | if not (torch.is_tensor(tensor) or
59 | (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
61 |
62 | if "range" in kwargs.keys():
63 | warning = "range will be deprecated, please use value_range instead."
64 | warnings.warn(warning)
65 | value_range = kwargs["range"]
66 |
67 | # if list of tensors, convert to a 4D mini-batch Tensor
68 | if isinstance(tensor, list):
69 | tensor = torch.stack(tensor, dim=0)
70 |
71 | if tensor.dim() == 2: # single image H x W
72 | tensor = tensor.unsqueeze(0)
73 | if tensor.dim() == 3: # single image
74 | if tensor.size(0) == 1: # if single-channel, convert to 3-channel
75 | tensor = torch.cat((tensor, tensor, tensor), 0)
76 | tensor = tensor.unsqueeze(0)
77 |
78 | if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
79 | tensor = torch.cat((tensor, tensor, tensor), 1)
80 |
81 | if normalize is True:
82 | tensor = tensor.clone() # avoid modifying tensor in-place
83 | if value_range is not None:
84 | assert isinstance(value_range, tuple), \
85 | "value_range has to be a tuple (min, max) if specified. min and max are numbers"
86 |
87 | def norm_ip(img, low, high):
88 |         img.clamp_(min=low, max=high)
89 | img.sub_(low).div_(max(high - low, 1e-5))
90 |
91 | def norm_range(t, value_range):
92 | if value_range is not None:
93 | norm_ip(t, value_range[0], value_range[1])
94 | else:
95 | norm_ip(t, float(t.min()), float(t.max()))
96 |
97 | if scale_each is True:
98 | for t in tensor: # loop over mini-batch dimension
99 | norm_range(t, value_range)
100 | else:
101 | norm_range(tensor, value_range)
102 |
103 | if tensor.size(0) == 1:
104 | return tensor.squeeze(0)
105 |
106 | # make the mini-batch of images into a grid
107 | nmaps = tensor.size(0)
108 | xmaps = min(nrow, nmaps)
109 | ymaps = int(math.ceil(float(nmaps) / xmaps))
110 | height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
111 | num_channels = tensor.size(1)
112 | grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
113 | k = 0
114 | for y in range(ymaps):
115 | for x in range(xmaps):
116 | if k >= nmaps:
117 | break
118 | # Tensor.copy_() is a valid method but seems to be missing from the stubs
119 | # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
120 | grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
121 | 2, x * width + padding, width - padding
122 | ).copy_(tensor[k])
123 | k = k + 1
124 | return grid
125 |
126 |
127 | @torch.no_grad()
128 | def save_image(
129 | tensor: Union[torch.Tensor, List[torch.Tensor]],
130 | fp: Union[Text, pathlib.Path, BinaryIO],
131 | format: Optional[str] = None,
132 | **kwargs
133 | ) -> None:
134 | """
135 | Save a given Tensor into an image file.
136 | Args:
137 | tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
138 | saves the tensor as a grid of images by calling ``make_grid``.
139 | fp (string or file object): A filename or a file object
140 | format(Optional): If omitted, the format to use is determined from the filename extension.
141 | If a file object was used instead of a filename, this parameter should always be used.
142 | **kwargs: Other arguments are documented in ``make_grid``.
143 | """
144 |
145 | grid = make_grid(tensor, **kwargs)
146 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
147 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
148 | im = Image.fromarray(ndarr)
149 | im.save(fp, format=format)
150 |
151 |
152 | def create_logger(log_dir, phase='train'):
153 | time_str = time.strftime('%Y-%m-%d-%H-%M')
154 | log_file = '{}_{}.log'.format(time_str, phase)
155 | final_log_file = os.path.join(log_dir, log_file)
156 | head = '%(asctime)-15s %(message)s'
157 | logging.basicConfig(filename=str(final_log_file),
158 | format=head)
159 | logger = logging.getLogger()
160 | logger.setLevel(logging.INFO)
161 | console = logging.StreamHandler()
162 | logging.getLogger('').addHandler(console)
163 |
164 | return logger
165 |
166 |
167 | def set_log_dir(root_dir, exp_name):
168 | path_dict = {}
169 | os.makedirs(root_dir, exist_ok=True)
170 |
171 | # set log path
172 | exp_path = os.path.join(root_dir, exp_name)
173 | now = datetime.now(dateutil.tz.tzlocal())
174 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
175 | prefix = exp_path + '_' + timestamp
176 | os.makedirs(prefix)
177 | path_dict['prefix'] = prefix
178 |
179 | # set checkpoint path
180 | ckpt_path = os.path.join(prefix, 'Model')
181 | os.makedirs(ckpt_path)
182 | path_dict['ckpt_path'] = ckpt_path
183 |
184 | log_path = os.path.join(prefix, 'Log')
185 | os.makedirs(log_path)
186 | path_dict['log_path'] = log_path
187 |
188 | # set sample image path for fid calculation
189 | sample_path = os.path.join(prefix, 'Samples')
190 | os.makedirs(sample_path)
191 | path_dict['sample_path'] = sample_path
192 |
193 | return path_dict
194 |
195 |
196 | def save_checkpoint(states, is_best, output_dir,
197 | filename='checkpoint.pth'):
198 | torch.save(states, os.path.join(output_dir, filename))
199 | if is_best:
200 | torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))
201 |
202 |
203 | class RunningStats:
204 | def __init__(self, WIN_SIZE):
205 | self.mean = 0
206 | self.run_var = 0
207 | self.WIN_SIZE = WIN_SIZE
208 |
209 | self.window = collections.deque(maxlen=WIN_SIZE)
210 |
211 | def clear(self):
212 | self.window.clear()
213 | self.mean = 0
214 | self.run_var = 0
215 |
216 | def is_full(self):
217 | return len(self.window) == self.WIN_SIZE
218 |
219 | def push(self, x):
220 |
221 | if len(self.window) == self.WIN_SIZE:
222 | # Adjusting variance
223 | x_removed = self.window.popleft()
224 | self.window.append(x)
225 | old_m = self.mean
226 | self.mean += (x - x_removed) / self.WIN_SIZE
227 | self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed)
228 | else:
229 | # Calculating first variance
230 | self.window.append(x)
231 | delta = x - self.mean
232 | self.mean += delta / len(self.window)
233 | self.run_var += delta * (x - self.mean)
234 |
235 | def get_mean(self):
236 | return self.mean if len(self.window) else 0.0
237 |
238 | def get_var(self):
239 | return self.run_var / len(self.window) if len(self.window) > 1 else 0.0
240 |
241 | def get_std(self):
242 | return math.sqrt(self.get_var())
243 |
244 | def get_all(self):
245 | return list(self.window)
246 |
247 | def __str__(self):
248 | return "Current window values: {}".format(list(self.window))
249 |
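250 | 
251 | 
252 | # A minimal usage sketch of `make_grid`, `save_image`, and `RunningStats`; the
253 | # random batch and loss values below are placeholders, not produced in this repo.
254 | if __name__ == '__main__':
255 |     fake_imgs = torch.rand(16, 3, 64, 64)                  # (B, C, H, W) batch
256 |     grid = make_grid(fake_imgs, nrow=4, normalize=True)    # single (3, H, W) image
257 |     print(grid.shape)
258 |     save_image(fake_imgs, 'samples.png', nrow=4, normalize=True)
259 | 
260 |     stats = RunningStats(WIN_SIZE=100)
261 |     for _ in range(500):
262 |         stats.push(float(np.random.rand()))                # e.g. per-step loss values
263 |     print(stats.get_mean(), stats.get_std())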
--------------------------------------------------------------------------------
/visualizationMetrics.py:
--------------------------------------------------------------------------------
1 | """Time-series Generative Adversarial Networks (TimeGAN) Codebase.
2 | Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar,
3 | "Time-series Generative Adversarial Networks,"
4 | Neural Information Processing Systems (NeurIPS), 2019.
5 | Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks
6 | Last updated Date: April 24th 2020
7 | Code author: Jinsung Yoon (jsyoon0823@gmail.com)
8 | -----------------------------
9 | visualization_metrics.py
10 | Note: Use PCA or tSNE for generated and original data visualization
11 | """
12 |
13 | # Necessary packages
14 | from sklearn.manifold import TSNE
15 | from sklearn.decomposition import PCA
16 | import matplotlib.pyplot as plt
17 | import numpy as np
18 |
19 |
20 | def visualization(ori_data, generated_data, analysis, save_name):
21 |   """Using PCA or tSNE for generated and original data visualization.
22 |   Args:
23 |     - ori_data: original data
24 |     - generated_data: generated synthetic data
25 |     - analysis: 'tsne' or 'pca'
26 |     - save_name: file name (without extension) for the figure saved to ./images/
27 |   """
28 | # Analysis sample size (for faster computation)
29 | anal_sample_no = min([1000, len(ori_data)])
30 | idx = np.random.permutation(len(ori_data))[:anal_sample_no]
31 |
32 | # Data preprocessing
33 | ori_data = np.asarray(ori_data)
34 | generated_data = np.asarray(generated_data)
35 |
36 | ori_data = ori_data[idx]
37 | generated_data = generated_data[idx]
38 |
39 | no, seq_len, dim = ori_data.shape
40 |
41 | for i in range(anal_sample_no):
42 | if (i == 0):
43 | prep_data = np.reshape(np.mean(ori_data[0,:,:], 1), [1,seq_len])
44 | prep_data_hat = np.reshape(np.mean(generated_data[0,:,:],1), [1,seq_len])
45 | else:
46 | prep_data = np.concatenate((prep_data,
47 | np.reshape(np.mean(ori_data[i,:,:],1), [1,seq_len])))
48 | prep_data_hat = np.concatenate((prep_data_hat,
49 | np.reshape(np.mean(generated_data[i,:,:],1), [1,seq_len])))
50 |
51 | # Visualization parameter
52 | colors = ["red" for i in range(anal_sample_no)] + ["blue" for i in range(anal_sample_no)]
53 |
54 | if analysis == 'pca':
55 | # PCA Analysis
56 | pca = PCA(n_components = 2)
57 | pca.fit(prep_data)
58 | pca_results = pca.transform(prep_data)
59 | pca_hat_results = pca.transform(prep_data_hat)
60 |
61 | # Plotting
62 | f, ax = plt.subplots(1)
63 | plt.scatter(pca_results[:,0], pca_results[:,1],
64 | c = colors[:anal_sample_no], alpha = 0.2, label = "Original")
65 | plt.scatter(pca_hat_results[:,0], pca_hat_results[:,1],
66 | c = colors[anal_sample_no:], alpha = 0.2, label = "Synthetic")
67 |
68 | ax.legend()
69 | plt.title('PCA plot')
70 | plt.xlabel('x-pca')
71 | plt.ylabel('y_pca')
72 | # plt.show()
73 |
74 | elif analysis == 'tsne':
75 |
76 | # Do t-SNE Analysis together
77 | prep_data_final = np.concatenate((prep_data, prep_data_hat), axis = 0)
78 |
79 |     # t-SNE analysis
80 | tsne = TSNE(n_components = 2, verbose = 1, perplexity = 40, n_iter = 300)
81 | tsne_results = tsne.fit_transform(prep_data_final)
82 |
83 | # Plotting
84 | f, ax = plt.subplots(1)
85 |
86 | plt.scatter(tsne_results[:anal_sample_no,0], tsne_results[:anal_sample_no,1],
87 | c = colors[:anal_sample_no], alpha = 0.2, label = "Original")
88 | plt.scatter(tsne_results[anal_sample_no:,0], tsne_results[anal_sample_no:,1],
89 | c = colors[anal_sample_no:], alpha = 0.2, label = "Synthetic")
90 |
91 | ax.legend()
92 |
93 | plt.title('t-SNE plot')
94 | plt.xlabel('x-tsne')
95 | plt.ylabel('y_tsne')
96 | # plt.show()
97 |
98 | plt.savefig(f'./images/{save_name}.pdf', format="pdf")
99 | plt.show()
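100 | 
101 | 
102 | # A minimal usage sketch; the random arrays below are placeholders with the
103 | # expected (N, seq_len, dim) shape, not data produced by this repo's loaders.
104 | if __name__ == '__main__':
105 |     ori = np.random.rand(500, 150, 3)
106 |     syn = np.random.rand(500, 150, 3)
107 |     visualization(ori, syn, 'pca', 'demo_pca')     # writes ./images/demo_pca.pdf
108 |     visualization(ori, syn, 'tsne', 'demo_tsne')   # writes ./images/demo_tsne.pdf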
--------------------------------------------------------------------------------