├── models ├── __init__.py ├── discriminators │ ├── __init__.py │ ├── resblocks.py │ ├── snresnet64.py │ └── snresnet.py ├── generators │ ├── __init__.py │ ├── resnet64.py │ ├── resnet.py │ └── resblocks.py └── inception.py ├── metrics ├── __init__.py └── fid.py ├── links ├── __init__.py └── conditional_batchnorm.py ├── Dockerfile.tensorboard ├── docker-compose.yml ├── LICENSE ├── Dockerfile ├── .gitignore ├── evaluation.py ├── losses.py ├── README.md ├── utils.py └── train_64.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/discriminators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/generators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /links/__init__.py: -------------------------------------------------------------------------------- 1 | from links.conditional_batchnorm import CategoricalConditionalBatchNorm2d # NOQA 2 | from links.conditional_batchnorm import ConditionalBatchNorm2d # NOQA 3 | -------------------------------------------------------------------------------- /Dockerfile.tensorboard: -------------------------------------------------------------------------------- 1 | FROM python:3.5 2 | 3 | RUN pip install --no-cache-dir tensorflow 4 | 5 | WORKDIR /logs 6 | 7 | ENTRYPOINT ["tensorboard", "--logdir", "/logs"] 8 | CMD [] 9 | -------------------------------------------------------------------------------- /docker-compose.yml: 
-------------------------------------------------------------------------------- 1 | version: "2.3" 2 | 3 | services: 4 | base: &default 5 | image: pytorch.sngan_projection 6 | build: 7 | context: . 8 | dockerfile: Dockerfile 9 | args: 10 | - user_id=1000 11 | - group_id=1000 12 | - user_name=crcrpar 13 | - group_name=sngan_projection 14 | - PYTHON_VERSION=3.6 15 | volumes: 16 | - "${PWD}:/src" 17 | - "${DATA}:/data" 18 | - "${RESULTS}:/results" 19 | ports: 20 | - ${PORT}:8888 21 | ipc: host 22 | tensorboard: 23 | build: 24 | context: . 25 | dockerfile: Dockerfile.tensorboard 26 | volumes: 27 | - "${RESULTS}:/logs" 28 | working_dir: '/logs' 29 | ports: 30 | - "6006:6006" 31 | environment: 32 | reload_interval: 2 33 | log_dir: /logs 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Masaki Kozuki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.2-cudnn7-devel-ubuntu16.04 2 | 3 | # create user 4 | ARG user_id 5 | ARG user_name 6 | ARG group_id 7 | ARG group_name 8 | RUN groupadd -g ${group_id} ${group_name} && \ 9 | useradd -u ${user_id} -g ${group_name} -s /bin/bash -m ${user_name} && \ 10 | echo "${user_name} ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers && \ 11 | chown -R ${user_name}:${group_name} /home/${user_name} 12 | 13 | # Default 14 | RUN apt-get update && apt-get install -y --no-install-recommends \ 15 | build-essential \ 16 | cmake \ 17 | git \ 18 | curl \ 19 | wget \ 20 | vim \ 21 | ca-certificates \ 22 | libjpeg-dev \ 23 | libgl1-mesa-dev \ 24 | libpng-dev \ 25 | build-essential \ 26 | zip \ 27 | unzip \ 28 | libpng-dev &&\ 29 | rm -rf /var/lib/apt/lists/* 30 | 31 | # Python Anaconda default. 32 | RUN wget -q https://repo.anaconda.com/archive/Anaconda3-5.3.1-Linux-x86_64.sh -O ~/anaconda.sh && \ 33 | /bin/bash ~/anaconda.sh -b -p /opt/conda && \ 34 | rm ~/anaconda.sh 35 | # Install PyTorch V1. 
36 | ENV PATH /opt/conda/bin:$PATH 37 | ARG PYTHON_VERSION 38 | RUN conda install -y python=$PYTHON_VERSION && \ 39 | conda install -y -c conda-forge feather-format && \ 40 | conda install -y -c conda-forge jupyterlab && \ 41 | conda install -y -c conda-forge jupyter_contrib_nbextensions && \ 42 | jupyter contrib nbextension install --system 43 | RUN conda install -y pytorch torchvision -c pytorch 44 | RUN conda install -y -c conda-forge tensorflow && \ 45 | conda clean -y --all && \ 46 | pip install --no-cache-dir tensorboardX 47 | 48 | ENV PATH /opt/conda/bin:$PATH 49 | WORKDIR /src 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Repository specific 2 | tiny-imagenet-200 3 | results 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /models/generators/resnet64.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | 6 | from models.generators.resblocks import Block 7 | 8 | 9 | class ResNetGenerator(nn.Module): 10 | """Generator generates 64x64.""" 11 | 12 | def __init__(self, num_features=64, dim_z=128, bottom_width=4, 13 | activation=F.relu, num_classes=0, distribution='normal'): 14 | super(ResNetGenerator, self).__init__() 15 | self.num_features = num_features 16 | self.dim_z = dim_z 17 | self.bottom_width = bottom_width 18 | self.activation = activation 19 | self.num_classes = num_classes 20 | self.distribution = distribution 21 | 22 | self.l1 = 
nn.Linear(dim_z, 16 * num_features * bottom_width ** 2) 23 | 24 | self.block2 = Block(num_features * 16, num_features * 8, 25 | activation=activation, upsample=True, 26 | num_classes=num_classes) 27 | self.block3 = Block(num_features * 8, num_features * 4, 28 | activation=activation, upsample=True, 29 | num_classes=num_classes) 30 | self.block4 = Block(num_features * 4, num_features * 2, 31 | activation=activation, upsample=True, 32 | num_classes=num_classes) 33 | self.block5 = Block(num_features * 2, num_features, 34 | activation=activation, upsample=True, 35 | num_classes=num_classes) 36 | self.b6 = nn.BatchNorm2d(num_features) 37 | self.conv6 = nn.Conv2d(num_features, 3, 1, 1) 38 | 39 | def _initialize(self): 40 | init.xavier_uniform_(self.l1.weight.tensor) 41 | init.xavier_uniform_(self.conv7.weight.tensor) 42 | 43 | def forward(self, z, y=None, **kwargs): 44 | h = self.l1(z).view(z.size(0), -1, self.bottom_width, self.bottom_width) 45 | for i in range(2, 6): 46 | h = getattr(self, 'block{}'.format(i))(h, y, **kwargs) 47 | h = self.activation(self.b6(h)) 48 | return torch.tanh(self.conv6(h)) 49 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torchvision 5 | 6 | import metrics.fid 7 | import utils 8 | 9 | 10 | def evaluate(args, current_iter, gen, device, 11 | inception_model=None, eval_iter=None): 12 | """Evaluate model using 100 mini-batches.""" 13 | calc_fid = (inception_model is not None) and (eval_iter is not None) 14 | num_batches = args.n_eval_batches 15 | gen.eval() 16 | fake_list, real_list = [], [] 17 | conditional = args.cGAN 18 | for i in range(1, num_batches + 1): 19 | if conditional: 20 | class_id = i % args.num_classes 21 | else: 22 | class_id = None 23 | fake = utils.generate_images( 24 | gen, device, args.batch_size, args.gen_dim_z, 25 | args.gen_distribution, 
class_id=class_id 26 | ) 27 | if calc_fid and i <= args.n_fid_batches: 28 | fake_list.append((fake.cpu().numpy() + 1.0) / 2.0) 29 | real_list.append((next(eval_iter)[0].cpu().numpy() + 1.0) / 2.0) 30 | # Save generated images. 31 | root = args.eval_image_root 32 | if conditional: 33 | root = os.path.join(root, "class_id_{:04d}".format(i)) 34 | if not os.path.isdir(root): 35 | os.makedirs(root) 36 | fn = "image_iter_{:07d}_batch_{:04d}.png".format(current_iter, i) 37 | torchvision.utils.save_image( 38 | fake, os.path.join(root, fn), nrow=4, normalize=True, scale_each=True 39 | ) 40 | # Calculate FID scores 41 | if calc_fid: 42 | fake_images = np.concatenate(fake_list) 43 | real_images = np.concatenate(real_list) 44 | mu_fake, sigma_fake = metrics.fid.calculate_activation_statistics( 45 | fake_images, inception_model, args.batch_size, device=device 46 | ) 47 | mu_real, sigma_real = metrics.fid.calculate_activation_statistics( 48 | real_images, inception_model, args.batch_size, device=device 49 | ) 50 | fid_score = metrics.fid.calculate_frechet_distance( 51 | mu_fake, sigma_fake, mu_real, sigma_real 52 | ) 53 | else: 54 | fid_score = -1000 55 | gen.train() 56 | return fid_score 57 | -------------------------------------------------------------------------------- /models/generators/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | 6 | from models.generators.resblocks import Block 7 | 8 | 9 | class ResNetGenerator(nn.Module): 10 | """Generator generates 128x128.""" 11 | 12 | def __init__(self, num_features=64, dim_z=128, bottom_width=4, 13 | activation=F.relu, num_classes=0, distribution='normal'): 14 | super(ResNetGenerator, self).__init__() 15 | self.num_features = num_features 16 | self.dim_z = dim_z 17 | self.bottom_width = bottom_width 18 | self.activation = activation 19 | self.num_classes = num_classes 20 | 
self.distribution = distribution 21 | 22 | self.l1 = nn.Linear(dim_z, 16 * num_features * bottom_width ** 2) 23 | 24 | self.block2 = Block(num_features * 16, num_features * 16, 25 | activation=activation, upsample=True, 26 | num_classes=num_classes) 27 | self.block3 = Block(num_features * 16, num_features * 8, 28 | activation=activation, upsample=True, 29 | num_classes=num_classes) 30 | self.block4 = Block(num_features * 8, num_features * 4, 31 | activation=activation, upsample=True, 32 | num_classes=num_classes) 33 | self.block5 = Block(num_features * 4, num_features * 2, 34 | activation=activation, upsample=True, 35 | num_classes=num_classes) 36 | self.block6 = Block(num_features * 2, num_features, 37 | activation=activation, upsample=True, 38 | num_classes=num_classes) 39 | self.b7 = nn.BatchNorm2d(num_features) 40 | self.conv7 = nn.Conv2d(num_features, 3, 1, 1) 41 | 42 | def _initialize(self): 43 | init.xavier_uniform_(self.l1.weight.tensor) 44 | init.xavier_uniform_(self.conv7.weight.tensor) 45 | 46 | def forward(self, z, y=None, **kwargs): 47 | h = self.l1(z).view(z.size(0), -1, self.bottom_width, self.bottom_width) 48 | for i in [2, 3, 4, 5, 6]: 49 | h = getattr(self, 'block{}'.format(i))(h, y, **kwargs) 50 | h = self.activation(self.b7(h)) 51 | return torch.tanh(self.conv7(h)) 52 | -------------------------------------------------------------------------------- /models/generators/resblocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import init 6 | 7 | from links import CategoricalConditionalBatchNorm2d 8 | 9 | 10 | def _upsample(x): 11 | h, w = x.size()[2:] 12 | return F.interpolate(x, size=(h * 2, w * 2), mode='bilinear') 13 | 14 | 15 | class Block(nn.Module): 16 | 17 | def __init__(self, in_ch, out_ch, h_ch=None, ksize=3, pad=1, 18 | activation=F.relu, upsample=False, num_classes=0): 19 | super(Block, 
self).__init__() 20 | 21 | self.activation = activation 22 | self.upsample = upsample 23 | self.learnable_sc = in_ch != out_ch or upsample 24 | if h_ch is None: 25 | h_ch = out_ch 26 | self.num_classes = num_classes 27 | 28 | # Register layrs 29 | self.c1 = nn.Conv2d(in_ch, h_ch, ksize, 1, pad) 30 | self.c2 = nn.Conv2d(h_ch, out_ch, ksize, 1, pad) 31 | if self.num_classes > 0: 32 | self.b1 = CategoricalConditionalBatchNorm2d( 33 | num_classes, in_ch) 34 | self.b2 = CategoricalConditionalBatchNorm2d( 35 | num_classes, h_ch) 36 | else: 37 | self.b1 = nn.BatchNorm2d(in_ch) 38 | self.b2 = nn.BatchNorm2d(h_ch) 39 | if self.learnable_sc: 40 | self.c_sc = nn.Conv2d(in_ch, out_ch, 1) 41 | 42 | def _initialize(self): 43 | init.xavier_uniform_(self.c1.weight.tensor, gain=math.sqrt(2)) 44 | init.xavier_uniform_(self.c2.weight.tensor, gain=math.sqrt(2)) 45 | if self.learnable_sc: 46 | init.xavier_uniform_(self.c_sc.weight.tensor, gain=1) 47 | 48 | def forward(self, x, y=None, z=None, **kwargs): 49 | return self.shortcut(x) + self.residual(x, y, z) 50 | 51 | def shortcut(self, x, **kwargs): 52 | if self.learnable_sc: 53 | if self.upsample: 54 | h = _upsample(x) 55 | h = self.c_sc(h) 56 | return h 57 | else: 58 | return x 59 | 60 | def residual(self, x, y=None, z=None, **kwargs): 61 | if y is not None: 62 | h = self.b1(x, y, **kwargs) 63 | else: 64 | h = self.b1(x) 65 | h = self.activation(h) 66 | if self.upsample: 67 | h = _upsample(h) 68 | h = self.c1(h) 69 | if y is not None: 70 | h = self.b2(h, y, **kwargs) 71 | else: 72 | h = self.b2(h) 73 | return self.c2(self.activation(h)) 74 | -------------------------------------------------------------------------------- /models/discriminators/resblocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | from torch.nn import utils 8 | 9 | 10 | class Block(nn.Module): 11 | 
12 | def __init__(self, in_ch, out_ch, h_ch=None, ksize=3, pad=1, 13 | activation=F.relu, downsample=False): 14 | super(Block, self).__init__() 15 | 16 | self.activation = activation 17 | self.downsample = downsample 18 | 19 | self.learnable_sc = (in_ch != out_ch) or downsample 20 | if h_ch is None: 21 | h_ch = in_ch 22 | else: 23 | h_ch = out_ch 24 | 25 | self.c1 = utils.spectral_norm(nn.Conv2d(in_ch, h_ch, ksize, 1, pad)) 26 | self.c2 = utils.spectral_norm(nn.Conv2d(h_ch, out_ch, ksize, 1, pad)) 27 | if self.learnable_sc: 28 | self.c_sc = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 1, 1, 0)) 29 | 30 | self._initialize() 31 | 32 | def _initialize(self): 33 | init.xavier_uniform_(self.c1.weight.data, math.sqrt(2)) 34 | init.xavier_uniform_(self.c2.weight.data, math.sqrt(2)) 35 | if self.learnable_sc: 36 | init.xavier_uniform_(self.c_sc.weight.data) 37 | 38 | def forward(self, x): 39 | return self.shortcut(x) + self.residual(x) 40 | 41 | def shortcut(self, x): 42 | if self.learnable_sc: 43 | x = self.c_sc(x) 44 | if self.downsample: 45 | return F.avg_pool2d(x, 2) 46 | return x 47 | 48 | def residual(self, x): 49 | h = self.c1(self.activation(x)) 50 | h = self.c2(self.activation(h)) 51 | if self.downsample: 52 | h = F.avg_pool2d(h, 2) 53 | return h 54 | 55 | 56 | class OptimizedBlock(nn.Module): 57 | 58 | def __init__(self, in_ch, out_ch, ksize=3, pad=1, activation=F.relu): 59 | super(OptimizedBlock, self).__init__() 60 | self.activation = activation 61 | 62 | self.c1 = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, ksize, 1, pad)) 63 | self.c2 = utils.spectral_norm(nn.Conv2d(out_ch, out_ch, ksize, 1, pad)) 64 | self.c_sc = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 1, 1, 0)) 65 | 66 | self._initialize() 67 | 68 | def _initialize(self): 69 | init.xavier_uniform_(self.c1.weight.data, math.sqrt(2)) 70 | init.xavier_uniform_(self.c2.weight.data, math.sqrt(2)) 71 | init.xavier_uniform_(self.c_sc.weight.data) 72 | 73 | def forward(self, x): 74 | return self.shortcut(x) + 
self.residual(x) 75 | 76 | def shortcut(self, x): 77 | return self.c_sc(F.avg_pool2d(x, 2)) 78 | 79 | def residual(self, x): 80 | h = self.activation(self.c1(x)) 81 | return F.avg_pool2d(self.c2(h), 2) 82 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | AVAILABLE_LOSSES = ["hinge", "dcgan"] 6 | 7 | 8 | def dis_hinge(dis_fake, dis_real): 9 | loss = torch.mean(torch.relu(1. - dis_real)) +\ 10 | torch.mean(torch.relu(1. + dis_fake)) 11 | return loss 12 | 13 | 14 | def gen_hinge(dis_fake, dis_real=None): 15 | return -torch.mean(dis_fake) 16 | 17 | 18 | def dis_dcgan(dis_fake, dis_real): 19 | loss = torch.mean(F.softplus(-dis_real)) + torch.mean(F.softplus(dis_fake)) 20 | return loss 21 | 22 | 23 | def gen_dcgan(dis_fake, dis_real=None): 24 | return torch.mean(F.softplus(-dis_fake)) 25 | 26 | 27 | class _Loss(object): 28 | 29 | """GAN Loss base class. 30 | 31 | Args: 32 | loss_type (str) 33 | is_relativistic (bool) 34 | 35 | """ 36 | 37 | def __init__(self, loss_type, is_relativistic=False): 38 | assert loss_type in AVAILABLE_LOSSES, "Invalid loss. 
Choose from {}".format(AVAILABLE_LOSSES) 39 | self.loss_type = loss_type 40 | self.is_relativistic = is_relativistic 41 | 42 | def _preprocess(self, dis_fake, dis_real): 43 | C_xf_tilde = torch.mean(dis_fake, dim=0, keepdim=True).expand_as(dis_fake) 44 | C_xr_tilde = torch.mean(dis_real, dim=0, keepdim=True).expand_as(dis_real) 45 | return dis_fake - C_xr_tilde, dis_real - C_xf_tilde 46 | 47 | 48 | class DisLoss(_Loss): 49 | 50 | """Discriminator Loss.""" 51 | 52 | def __call__(self, dis_fake, dis_real, **kwargs): 53 | if not self.is_relativistic: 54 | if self.loss_type == "hinge": 55 | return dis_hinge(dis_fake, dis_real) 56 | elif self.loss_type == "dcgan": 57 | return dis_dcgan(dis_fake, dis_real) 58 | else: 59 | d_xf, d_xr = self._preprocess(dis_fake, dis_real) 60 | if self.loss_type == "hinge": 61 | return dis_hinge(d_xf, d_xr) 62 | elif self.loss_type == "dcgan": 63 | D_xf = torch.sigmoid(d_xf) 64 | D_xr = torch.sigmoid(d_xr) 65 | return -torch.log(D_xr) - torch.log(1.0 - D_xf) 66 | else: 67 | raise NotImplementedError 68 | 69 | 70 | class GenLoss(_Loss): 71 | 72 | """Generator Loss.""" 73 | 74 | def __call__(self, dis_fake, dis_real=None, **kwargs): 75 | if not self.is_relativistic: 76 | if self.loss_type == "hinge": 77 | return gen_hinge(dis_fake, dis_real) 78 | elif self.loss_type == "dcgan": 79 | return gen_dcgan(dis_fake, dis_real) 80 | else: 81 | assert dis_real is not None, "Relativistic Generator loss requires `dis_real`." 
82 | d_xf, d_xr = self._preprocess(dis_fake, dis_real) 83 | if self.loss_type == "hinge": 84 | return dis_hinge(d_xr, d_xf) 85 | elif self.loss_type == "dcgan": 86 | D_xf = torch.sigmoid(d_xf) 87 | D_xr = torch.sigmoid(d_xr) 88 | return -torch.log(D_xf) - torch.log(1.0 - D_xr) 89 | else: 90 | raise NotImplementedError 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **UPDATE 2019/07/14: C-GAN has a bug in network definition!** 2 | Thank you @UdonDa for reporting and pointing out ([issue#28](https://github.com/crcrpar/pytorch.sngan_projection/issues/28)) 3 | 4 | **UPDATE 2019/05/24: Current implementation of FID score is incorrect!** 5 | Thank you @youngjung for the report & fix suggestions ([issue#25](https://github.com/crcrpar/pytorch.sngan_projection/issues/25))! 6 | 7 | --- 8 | 9 | The original is available at https://github.com/pfnet-research/sngan_projection. 10 | 11 | # SNGAN and cGANs with projection discriminator 12 | _**This is unofficial PyTorch implementation of sngan_projection.**_ 13 | _**This does not reproduce the experiments and results reported in the paper due to the lack of GPUs.**_ 14 | **This repository does some experiments on images of size 64x64.** 15 | 16 | Some results are on issues with _results_ label. 
17 | 18 | ## SNGAN 19 | > Spectral Normalization for Generative Adversarial Networks 20 | > Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida 21 | > OpenReview: https://openreview.net/forum?id=B1QRgziT- 22 | > arXiv: https://arxiv.org/abs/1802.05957 23 | 24 | ## cGANs with projection discriminator 25 | > cGANs with Projection Discriminator 26 | > Takeru Miyato, Masanori Koyama 27 | > OpenReview: https://openreview.net/forum?id=ByS1VpgRZ 28 | > arXiv: https://arxiv.org/abs/1802.05637 29 | 30 | ## Requirements 31 | - Python 3.6.4 32 | - PyTorch 0.4.1 33 | - torchvision 0.2.1 34 | - NumPy: used for FID score calculation and the data loader 35 | - SciPy: used for FID score calculation 36 | - tensorflow (optional) 37 | - tensorboardX (optional) 38 | - tqdm: progress bar and logging 39 | 40 | To visualize training progress with **tensorboard**, install tensorflow and tensorboardX. 41 | If you only need tensorboard itself, the CPU build of tensorflow is sufficient. 42 | 43 | ### Docker environment 44 | Dockerfiles for a PyTorch 1.0 environment and for tensorboard are provided. The PyTorch 1.0 Dockerfile requires an NVIDIA driver that supports CUDA 9.2. 45 | The dockerized environment also requires the following environment variables: 46 | - `DATA`: Path to dataset 47 | - `RESULTS`: Path to save results 48 | - `PORT`: Port number for jupyter notebook. 49 | 50 | ## Dataset 51 | - Tiny ImageNet[^1]. 52 | 53 | > Tiny Imagenet has 200 classes. Each class has 500 training images, 50 validation images, and 50 test images. 54 | 55 | [^1]: https://tiny-imagenet.herokuapp.com/ 56 | 57 | ## Training configuration 58 | Default parameters are the same as the original Chainer implementation. 59 | 60 | - to train cGAN with projection discriminator: run `train_64.py` with `--cGAN` option. 61 | - to train cGAN with concat discriminator: run `train_64.py` with both `--cGAN` and `--dis_arch_concat`. 62 | - to run without `tensorboard`, please add `--no_tensorboard`.
63 | - to calculate FID, add `--calc_FID` (not tested) 64 | - to use make discriminator relativistic, add `--relativistic_loss` or `-relloss` (not tested) 65 | 66 | To see all the available arguments, run `python train_64.py --help`. 67 | 68 | ## TODO 69 | - [ ] implement super-resolution (cGAN) 70 | 71 | # Acknowledgement 72 | 1. https://github.com/pfnet-research/sngan_projection 73 | 2. https://github.com/mseitzer/pytorch-fid: FID score 74 | 3. https://github.com/naoto0804/pytorch-AdaIN: Infinite Sampler of DataLoader 75 | -------------------------------------------------------------------------------- /models/discriminators/snresnet64.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | from torch.nn import utils 6 | 7 | from models.discriminators.resblocks import Block 8 | from models.discriminators.resblocks import OptimizedBlock 9 | 10 | 11 | class SNResNetProjectionDiscriminator(nn.Module): 12 | 13 | def __init__(self, num_features=64, num_classes=0, activation=F.relu): 14 | super(SNResNetProjectionDiscriminator, self).__init__() 15 | self.num_features = num_features 16 | self.num_classes = num_classes 17 | self.activation = activation 18 | 19 | self.block1 = OptimizedBlock(3, num_features) 20 | self.block2 = Block(num_features, num_features * 2, 21 | activation=activation, downsample=True) 22 | self.block3 = Block(num_features * 2, num_features * 4, 23 | activation=activation, downsample=True) 24 | self.block4 = Block(num_features * 4, num_features * 8, 25 | activation=activation, downsample=True) 26 | self.block5 = Block(num_features * 8, num_features * 16, 27 | activation=activation, downsample=True) 28 | self.l6 = utils.spectral_norm(nn.Linear(num_features * 16, 1)) 29 | if num_classes > 0: 30 | self.l_y = utils.spectral_norm( 31 | nn.Embedding(num_classes, num_features * 16)) 32 | 33 | self._initialize() 34 | 
35 | def _initialize(self): 36 | init.xavier_uniform_(self.l6.weight.data) 37 | optional_l_y = getattr(self, 'l_y', None) 38 | if optional_l_y is not None: 39 | init.xavier_uniform_(optional_l_y.weight.data) 40 | 41 | def forward(self, x, y=None): 42 | h = x 43 | h = self.block1(h) 44 | h = self.block2(h) 45 | h = self.block3(h) 46 | h = self.block4(h) 47 | h = self.block5(h) 48 | h = self.activation(h) 49 | # Global pooling 50 | h = torch.sum(h, dim=(2, 3)) 51 | output = self.l6(h) 52 | if y is not None: 53 | output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True) 54 | return output 55 | 56 | 57 | class SNResNetConcatDiscriminator(nn.Module): 58 | 59 | def __init__(self, num_features, num_classes, activation=F.relu, 60 | dim_emb=128): 61 | super(SNResNetConcatDiscriminator, self).__init__() 62 | self.num_features = num_features 63 | self.num_classes = num_classes 64 | self.dim_emb = dim_emb 65 | self.activation = activation 66 | 67 | self.block1 = OptimizedBlock(3, num_features) 68 | self.block2 = Block(num_features, num_features * 2, 69 | activation=activation, downsample=True) 70 | self.block3 = Block(num_features * 2, num_features * 4, 71 | activation=activation, downsample=True) 72 | if num_classes > 0: 73 | self.l_y = utils.spectral_norm(nn.Embedding(num_classes, dim_emb)) 74 | self.block4 = Block(num_features * 4 + dim_emb, num_features * 8, 75 | activation=activation, downsample=True) 76 | self.block5 = Block(num_features * 8, num_features * 16, 77 | activation=activation, downsample=True) 78 | self.l6 = utils.spectral_norm(nn.Linear(num_features * 16, 1)) 79 | 80 | self._initialize() 81 | 82 | def _initialize(self): 83 | init.xavier_uniform_(self.l6.weight.data) 84 | if hasattr(self, 'l_y'): 85 | init.xavier_uniform_(self.l_y.weight.data) 86 | 87 | def forward(self, x, y=None): 88 | h = x 89 | h = self.block1(h) 90 | h = self.block2(h) 91 | h = self.block3(h) 92 | if y is not None: 93 | emb = self.l_y(y).unsqueeze(-1).unsqueeze(-1) 94 | emb = 
emb.expand(emb.size(0), emb.size(1), h.size(2), h.size(3)) 95 | h = torch.cat((h, emb), dim=1) 96 | h = self.block4(h) 97 | h = self.block5(h) 98 | h = torch.sum(self.activation(h), dim=(2, 3)) 99 | return self.l6(h) 100 | -------------------------------------------------------------------------------- /models/discriminators/snresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | from torch.nn import utils 6 | 7 | from models.discriminators.resblocks import Block 8 | from models.discriminators.resblocks import OptimizedBlock 9 | 10 | 11 | class SNResNetProjectionDiscriminator(nn.Module): 12 | 13 | def __init__(self, num_features, num_classes=0, activation=F.relu): 14 | super(SNResNetProjectionDiscriminator, self).__init__() 15 | self.num_features = num_features 16 | self.num_classes = num_classes 17 | self.activation = activation 18 | 19 | self.block1 = OptimizedBlock(3, num_features) 20 | self.block2 = Block(num_features, num_features * 2, 21 | activation=activation, downsample=True) 22 | self.block3 = Block(num_features * 2, num_features * 4, 23 | activation=activation, downsample=True) 24 | self.block4 = Block(num_features * 4, num_features * 8, 25 | activation=activation, downsample=True) 26 | self.block5 = Block(num_features * 8, num_features * 16, 27 | activation=activation, downsample=True) 28 | self.block6 = Block(num_features * 16, num_features * 16, 29 | activation=activation, downsample=True) 30 | self.l7 = utils.spectral_norm(nn.Linear(num_features * 16, 1)) 31 | if num_classes > 0: 32 | self.l_y = utils.spectral_norm( 33 | nn.Embedding(num_classes, num_features * 16)) 34 | 35 | self._initialize() 36 | 37 | def _initialize(self): 38 | init.xavier_uniform_(self.l7.weight.data) 39 | optional_l_y = getattr(self, 'l_y', None) 40 | if optional_l_y is not None: 41 | 
init.xavier_uniform_(optional_l_y.weight.data) 42 | 43 | def forward(self, x, y=None): 44 | h = x 45 | for i in range(1, 7): 46 | h = getattr(self, 'block{}'.format(i))(h) 47 | h = self.activation(h) 48 | # Global pooling 49 | h = torch.sum(h, dim=(2, 3)) 50 | output = self.l7(h) 51 | if y is not None: 52 | output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True) 53 | return output 54 | 55 | 56 | class SNResNetConcatDiscriminator(nn.Module): 57 | 58 | def __init__(self, num_features, num_classes, activation=F.relu, 59 | dim_emb=128): 60 | super(SNResNetConcatDiscriminator, self).__init__() 61 | self.num_features = num_features 62 | self.num_classes = num_classes 63 | self.dim_emb = dim_emb 64 | self.activation = activation 65 | 66 | self.block1 = OptimizedBlock(3, num_features) 67 | self.block2 = Block(num_features, num_features * 2, 68 | activation=activation, downsample=True) 69 | self.block3 = Block(num_features * 2, num_features * 4, 70 | activation=activation, downsample=True) 71 | if num_classes > 0: 72 | self.l_y = utils.spectral_norm(nn.Embedding(num_classes, dim_emb)) 73 | self.block4 = Block(num_features * 4 + dim_emb, num_features * 8, 74 | activation=activation, downsample=True) 75 | self.block5 = Block(num_features * 8, num_features * 16, 76 | activation=activation, downsample=True) 77 | self.block6 = Block(num_features * 16, num_features * 16, 78 | activation=activation, downsample=False) 79 | self.l7 = utils.spectral_norm(nn.Linear(num_features * 16, 1)) 80 | 81 | self._initialize() 82 | 83 | def _initialize(self): 84 | init.xavier_uniform_(self.l7.weight.data) 85 | if hasattr(self, 'l_y'): 86 | init.xavier_uniform_(self.l_y.weight.data) 87 | 88 | def forward(self, x, y=None): 89 | h = x 90 | for i in range(1, 4): 91 | h = getattr(self, 'block{}'.format(i))(h) 92 | if y is not None: 93 | emb = self.l_y(y).unsqueeze(-1).unsqueeze(-1) 94 | emb = emb.expand(emb.size(0), emb.size(1), h.size(2), h.size(3)) 95 | h = torch.cat((h, emb), dim=1) 96 | for i 
in range(4, 7): 97 | h = getattr(self, 'block{}'.format(i))(h) 98 | h = torch.sum(self.activation(h), dim=(2, 3)) 99 | return self.l7(h) 100 | -------------------------------------------------------------------------------- /links/conditional_batchnorm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch.nn import init 4 | 5 | 6 | class ConditionalBatchNorm2d(nn.BatchNorm2d): 7 | 8 | """Conditional Batch Normalization""" 9 | 10 | def __init__(self, num_features, eps=1e-05, momentum=0.1, 11 | affine=False, track_running_stats=True): 12 | super(ConditionalBatchNorm2d, self).__init__( 13 | num_features, eps, momentum, affine, track_running_stats 14 | ) 15 | 16 | def forward(self, input, weight, bias, **kwargs): 17 | self._check_input_dim(input) 18 | 19 | exponential_average_factor = 0.0 20 | 21 | if self.training and self.track_running_stats: 22 | self.num_batches_tracked += 1 23 | if self.momentum is None: # use cumulative moving average 24 | exponential_average_factor = 1.0 / self.num_batches_tracked.item() 25 | else: # use exponential moving average 26 | exponential_average_factor = self.momentum 27 | 28 | output = F.batch_norm(input, self.running_mean, self.running_var, 29 | self.weight, self.bias, 30 | self.training or not self.track_running_stats, 31 | exponential_average_factor, self.eps) 32 | if weight.dim() == 1: 33 | weight = weight.unsqueeze(0) 34 | if bias.dim() == 1: 35 | bias = bias.unsqueeze(0) 36 | size = output.size() 37 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size) 38 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size) 39 | return weight * output + bias 40 | 41 | 42 | class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d): 43 | 44 | def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1, 45 | affine=False, track_running_stats=True): 46 | super(CategoricalConditionalBatchNorm2d, self).__init__( 47 | 
num_features, eps, momentum, affine, track_running_stats 48 | ) 49 | self.weights = nn.Embedding(num_classes, num_features) 50 | self.biases = nn.Embedding(num_classes, num_features) 51 | 52 | self._initialize() 53 | 54 | def _initialize(self): 55 | init.ones_(self.weights.weight.data) 56 | init.zeros_(self.biases.weight.data) 57 | 58 | def forward(self, input, c, **kwargs): 59 | weight = self.weights(c) 60 | bias = self.biases(c) 61 | 62 | return super(CategoricalConditionalBatchNorm2d, self).forward(input, weight, bias) 63 | 64 | 65 | if __name__ == '__main__': 66 | """Forward computation check.""" 67 | import torch 68 | size = (3, 3, 12, 12) 69 | batch_size, num_features = size[:2] 70 | print('# Affirm embedding output') 71 | naive_bn = nn.BatchNorm2d(3) 72 | idx_input = torch.tensor([1, 2, 0], dtype=torch.long) 73 | embedding = nn.Embedding(3, 3) 74 | weights = embedding(idx_input) 75 | print('# weights size', weights.size()) 76 | empty = torch.tensor((), dtype=torch.float) 77 | running_mean = empty.new_zeros((3,)) 78 | running_var = empty.new_ones((3,)) 79 | 80 | naive_bn_W = naive_bn.weight 81 | # print('# weights from embedding | type {}\n'.format(type(weights)), weights) 82 | # print('# naive_bn_W | type {}\n'.format(type(naive_bn_W)), naive_bn_W) 83 | input = torch.rand(*size, dtype=torch.float32) 84 | print('input size', input.size()) 85 | print('input ndim ', input.dim()) 86 | 87 | _ = naive_bn(input) 88 | 89 | print('# batch_norm with given weights') 90 | 91 | try: 92 | with torch.no_grad(): 93 | output = F.batch_norm(input, running_mean, running_var, 94 | weights, naive_bn.bias, False, 0.0, 1e-05) 95 | except Exception as e: 96 | print("\tFailed to use given weights") 97 | print('# Error msg:', e) 98 | print() 99 | else: 100 | print("Succeeded to use given weights") 101 | 102 | print('\n# Batch norm before use given weights') 103 | with torch.no_grad(): 104 | tmp_out = F.batch_norm(input, running_mean, running_var, 105 | naive_bn_W, naive_bn.bias, 
False, .0, 1e-05) 106 | weights_cast = weights.unsqueeze(-1).unsqueeze(-1) 107 | weights_cast = weights_cast.expand(tmp_out.size()) 108 | try: 109 | out = weights_cast * tmp_out 110 | except Exception: 111 | print("Failed") 112 | else: 113 | print("Succeeded!") 114 | print('\t {}'.format(out.size())) 115 | print(type(tuple(out.size()))) 116 | 117 | print('--- condBN and catCondBN ---') 118 | 119 | catCondBN = CategoricalConditionalBatchNorm2d(3, 3) 120 | output = catCondBN(input, idx_input) 121 | 122 | assert tuple(output.size()) == size 123 | 124 | condBN = ConditionalBatchNorm2d(3) 125 | 126 | idx = torch.tensor([1], dtype=torch.long) 127 | out = catCondBN(input, idx) 128 | 129 | print('cat cond BN weights\n', catCondBN.weights.weight.data) 130 | print('cat cond BN biases\n', catCondBN.biases.weight.data) 131 | -------------------------------------------------------------------------------- /models/inception.py: -------------------------------------------------------------------------------- 1 | """https://raw.githubusercontent.com/mseitzer/pytorch-fid""" 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | 7 | class InceptionV3(nn.Module): 8 | """Pretrained InceptionV3 network returning feature maps""" 9 | 10 | # Index of default block of inception to return, 11 | # corresponds to output of final average pooling 12 | DEFAULT_BLOCK_INDEX = 3 13 | 14 | # Maps feature dimensionality to their output blocks indices 15 | BLOCK_INDEX_BY_DIM = { 16 | 64: 0, # First max pooling features 17 | 192: 1, # Second max pooling featurs 18 | 768: 2, # Pre-aux classifier features 19 | 2048: 3 # Final average pooling features 20 | } 21 | 22 | def __init__(self, 23 | output_blocks=[DEFAULT_BLOCK_INDEX], 24 | resize_input=True, 25 | normalize_input=True, 26 | requires_grad=False): 27 | """Build pretrained InceptionV3 28 | 29 | Parameters 30 | ---------- 31 | output_blocks : list of int 32 | Indices of blocks to return features of. 
Possible values are: 33 | - 0: corresponds to output of first max pooling 34 | - 1: corresponds to output of second max pooling 35 | - 2: corresponds to output which is fed to aux classifier 36 | - 3: corresponds to output of final average pooling 37 | resize_input : bool 38 | If true, bilinearly resizes input to width and height 299 before 39 | feeding input to model. As the network without fully connected 40 | layers is fully convolutional, it should be able to handle inputs 41 | of arbitrary size, so resizing might not be strictly needed 42 | normalize_input : bool 43 | If true, normalizes the input to the statistics the pretrained 44 | Inception network expects 45 | requires_grad : bool 46 | If true, parameters of the model require gradient. Possibly useful 47 | for finetuning the network 48 | """ 49 | super(InceptionV3, self).__init__() 50 | 51 | self.resize_input = resize_input 52 | self.normalize_input = normalize_input 53 | self.output_blocks = sorted(output_blocks) 54 | self.last_needed_block = max(output_blocks) 55 | 56 | assert self.last_needed_block <= 3, \ 57 | 'Last possible output block index is 3' 58 | 59 | self.blocks = nn.ModuleList() 60 | 61 | inception = models.inception_v3(pretrained=True) 62 | 63 | # Block 0: input to maxpool1 64 | block0 = [ 65 | inception.Conv2d_1a_3x3, 66 | inception.Conv2d_2a_3x3, 67 | inception.Conv2d_2b_3x3, 68 | nn.MaxPool2d(kernel_size=3, stride=2) 69 | ] 70 | self.blocks.append(nn.Sequential(*block0)) 71 | 72 | # Block 1: maxpool1 to maxpool2 73 | if self.last_needed_block >= 1: 74 | block1 = [ 75 | inception.Conv2d_3b_1x1, 76 | inception.Conv2d_4a_3x3, 77 | nn.MaxPool2d(kernel_size=3, stride=2) 78 | ] 79 | self.blocks.append(nn.Sequential(*block1)) 80 | 81 | # Block 2: maxpool2 to aux classifier 82 | if self.last_needed_block >= 2: 83 | block2 = [ 84 | inception.Mixed_5b, 85 | inception.Mixed_5c, 86 | inception.Mixed_5d, 87 | inception.Mixed_6a, 88 | inception.Mixed_6b, 89 | inception.Mixed_6c, 90 | 
inception.Mixed_6d, 91 | inception.Mixed_6e, 92 | ] 93 | self.blocks.append(nn.Sequential(*block2)) 94 | 95 | # Block 3: aux classifier to final avgpool 96 | if self.last_needed_block >= 3: 97 | block3 = [ 98 | inception.Mixed_7a, 99 | inception.Mixed_7b, 100 | inception.Mixed_7c, 101 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 102 | ] 103 | self.blocks.append(nn.Sequential(*block3)) 104 | 105 | for param in self.parameters(): 106 | param.requires_grad = requires_grad 107 | 108 | def forward(self, inp): 109 | """Get Inception feature maps 110 | 111 | Parameters 112 | ---------- 113 | inp : torch.autograd.Variable 114 | Input tensor of shape Bx3xHxW. Values are expected to be in 115 | range (0, 1) 116 | 117 | Returns 118 | ------- 119 | List of torch.autograd.Variable, corresponding to the selected output 120 | block, sorted ascending by index 121 | """ 122 | outp = [] 123 | x = inp 124 | 125 | if self.resize_input: 126 | x = F.interpolate(x, size=(299, 299), mode='bilinear') 127 | 128 | if self.normalize_input: 129 | x = x.clone() 130 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 131 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 132 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 133 | 134 | for idx, block in enumerate(self.blocks): 135 | x = block(x) 136 | if idx in self.output_blocks: 137 | outp.append(x) 138 | 139 | if idx == self.last_needed_block: 140 | break 141 | 142 | return outp 143 | -------------------------------------------------------------------------------- /metrics/fid.py: -------------------------------------------------------------------------------- 1 | """Derived from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py""" # NOQA 2 | import numpy as np 3 | from scipy import linalg 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def get_activations(images, model, batch_size=64, dims=2048, device=None): 9 | """Calculates the activations of the pool_3 layer for all images. 
10 | 11 | Params: 12 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 13 | must lie between 0 and 1. 14 | -- model : Instance of inception model 15 | -- batch_size : the images numpy array is split into batches with 16 | batch size batch_size. A reasonable batch size depends 17 | on the hardware. 18 | -- dims : Dimensionality of features returned by Inception 19 | -- device : torch.Device 20 | 21 | Returns: 22 | -- A numpy array of dimension (num images, dims) that contains the 23 | activations of the given tensor when feeding inception with the 24 | query tensor. 25 | """ 26 | model.eval() 27 | 28 | d0 = images.shape[0] 29 | if batch_size > d0: 30 | print(('Warning: batch size is bigger than the data size. ' 31 | 'Setting batch size to data size')) 32 | batch_size = d0 33 | 34 | n_batches = d0 // batch_size 35 | n_used_imgs = n_batches * batch_size 36 | 37 | pred_arr = np.empty((n_used_imgs, dims)) 38 | for i in range(n_batches): 39 | start = i * batch_size 40 | end = start + batch_size 41 | 42 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 43 | if device is not None: 44 | batch = batch.to(device) 45 | 46 | with torch.no_grad(): 47 | pred = model(batch)[0] 48 | 49 | # If model output is not scalar, apply global spatial average pooling. 50 | # This happens if you choose a dimensionality not equal 2048. 51 | if pred.shape[2] != 1 or pred.shape[3] != 1: 52 | pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1)) 53 | 54 | pred_arr[start:end] = pred.cpu().numpy().reshape(batch_size, -1) 55 | 56 | return pred_arr 57 | 58 | 59 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 60 | """Numpy implementation of the Frechet Distance. 61 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 62 | and X_2 ~ N(mu_2, C_2) is 63 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 64 | Stable version by Dougal J. Sutherland. 
65 | Params: 66 | -- mu1 : Numpy array containing the activations of a layer of the 67 | inception net (like returned by the function 'get_predictions') 68 | for generated samples. 69 | -- mu2 : The sample mean over activations, precalculated on an 70 | representive data set. 71 | -- sigma1: The covariance matrix over activations for generated samples. 72 | -- sigma2: The covariance matrix over activations, precalculated on an 73 | representive data set. 74 | Returns: 75 | -- : The Frechet Distance. 76 | """ 77 | 78 | mu1 = np.atleast_1d(mu1) 79 | mu2 = np.atleast_1d(mu2) 80 | 81 | sigma1 = np.atleast_2d(sigma1) 82 | sigma2 = np.atleast_2d(sigma2) 83 | 84 | assert mu1.shape == mu2.shape, \ 85 | 'Training and test mean vectors have different lengths' 86 | assert sigma1.shape == sigma2.shape, \ 87 | 'Training and test covariances have different dimensions' 88 | 89 | diff = mu1 - mu2 90 | 91 | # Product might be almost singular 92 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 93 | if not np.isfinite(covmean).all(): 94 | msg = ('fid calculation produces singular product; ' 95 | 'adding %s to diagonal of cov estimates') % eps 96 | print(msg) 97 | offset = np.eye(sigma1.shape[0]) * eps 98 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 99 | 100 | # Numerical error might give slight imaginary component 101 | if np.iscomplexobj(covmean): 102 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 103 | m = np.max(np.abs(covmean.imag)) 104 | raise ValueError('Imaginary component {}'.format(m)) 105 | covmean = covmean.real 106 | 107 | tr_covmean = np.trace(covmean) 108 | 109 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 110 | 111 | 112 | def calculate_activation_statistics(images, model, batch_size=64, dims=2048, device=None): 113 | """Calculation of the statistics used by the FID. 114 | Params: 115 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 116 | must lie between 0 and 1. 
117 | -- model : Instance of inception model 118 | -- batch_size : The images numpy array is split into batches with 119 | batch size batch_size. A reasonable batch size 120 | depends on the hardware. 121 | -- dims : Dimensionality of features returned by Inception 122 | -- device : If set to True, use GPU 123 | -- verbose : If set to True and parameter out_step is given, the 124 | number of calculated batches is reported. 125 | Returns: 126 | -- mu : The mean over samples of the activations of the pool_3 layer of 127 | the inception model. 128 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 129 | the inception model. 130 | """ 131 | act = get_activations(images, model, batch_size, dims, device) 132 | mu = np.mean(act, axis=0) 133 | sigma = np.cov(act, rowvar=False) 134 | return mu, sigma 135 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | 5 | import torch 6 | import torchvision 7 | import numpy 8 | 9 | 10 | class Dict2Args(object): 11 | 12 | """Dict-argparse object converter.""" 13 | 14 | def __init__(self, dict_args): 15 | for key, value in dict_args.items(): 16 | setattr(self, key, value) 17 | 18 | 19 | def generate_images(gen, device, batch_size=64, dim_z=128, distribution=None, 20 | num_classes=None, class_id=None): 21 | """Generate images. 22 | 23 | Priority: num_classes > class_id. 24 | 25 | Args: 26 | gen (nn.Module): generator. 
27 | device (torch.device) 28 | batch_size (int) 29 | dim_z (int) 30 | distribution (str) 31 | num_classes (int, optional) 32 | class_id (int, optional) 33 | 34 | Returns: 35 | torch.tensor 36 | 37 | """ 38 | 39 | z = sample_z(batch_size, dim_z, device, distribution) 40 | if num_classes is None and class_id is None: 41 | y = None 42 | elif num_classes is not None: 43 | y = sample_pseudo_labels(num_classes, batch_size, device) 44 | elif class_id is not None: 45 | y = torch.tensor([class_id] * batch_size, dtype=torch.long).to(device) 46 | else: 47 | y = None 48 | with torch.no_grad(): 49 | fake = gen(z, y) 50 | 51 | return fake 52 | 53 | 54 | def sample_z(batch_size, dim_z, device, distribution=None): 55 | """Sample random noises. 56 | 57 | Args: 58 | batch_size (int) 59 | dim_z (int) 60 | device (torch.device) 61 | distribution (str, optional): default is normal 62 | 63 | Returns: 64 | torch.FloatTensor or torch.cuda.FloatTensor 65 | 66 | """ 67 | 68 | if distribution is None: 69 | distribution = 'normal' 70 | if distribution == 'normal': 71 | return torch.empty(batch_size, dim_z, dtype=torch.float32, device=device).normal_() 72 | else: 73 | return torch.empty(batch_size, dim_z, dtype=torch.float32, device=device).uniform_() 74 | 75 | 76 | def sample_pseudo_labels(num_classes, batch_size, device): 77 | """Sample pseudo-labels. 78 | 79 | Args: 80 | num_classes (int): number of classes in the dataset. 81 | batch_size (int): size of mini-batch. 82 | device (torch.Device): For compatibility. 83 | 84 | Returns: 85 | ~torch.LongTensor or torch.cuda.LongTensor. 86 | 87 | """ 88 | 89 | pseudo_labels = torch.from_numpy( 90 | numpy.random.randint(low=0, high=num_classes, size=(batch_size)) 91 | ) 92 | pseudo_labels = pseudo_labels.type(torch.long).to(device) 93 | return pseudo_labels 94 | 95 | 96 | def save_images(n_iter, count, root, train_image_root, fake, real): 97 | """Save images (torch.tensor). 
98 | 99 | Args: 100 | root (str) 101 | train_image_root (root) 102 | fake (torch.tensor) 103 | real (torch.tensor) 104 | 105 | """ 106 | 107 | fake_path = os.path.join( 108 | train_image_root, 109 | 'fake_{}_iter_{:07d}.png'.format(count, n_iter) 110 | ) 111 | real_path = os.path.join( 112 | train_image_root, 113 | 'real_{}_iter_{:07d}.png'.format(count, n_iter) 114 | ) 115 | torchvision.utils.save_image( 116 | fake, fake_path, nrow=4, normalize=True, scale_each=True 117 | ) 118 | shutil.copy(fake_path, os.path.join(root, 'fake_latest.png')) 119 | torchvision.utils.save_image( 120 | real, real_path, nrow=4, normalize=True, scale_each=True 121 | ) 122 | shutil.copy(real_path, os.path.join(root, 'real_latest.png')) 123 | 124 | 125 | def save_checkpoints(args, n_iter, count, gen, opt_gen, dis, opt_dis): 126 | """Save checkpoints. 127 | 128 | Args: 129 | args (argparse object) 130 | n_iter (int) 131 | gen (nn.Module) 132 | opt_gen (torch.optim) 133 | dis (nn.Module) 134 | opt_dis (torch.optim) 135 | 136 | """ 137 | 138 | count = n_iter // args.checkpoint_interval 139 | gen_dst = os.path.join( 140 | args.results_root, 141 | 'gen_{}_iter_{:07d}.pth.tar'.format(count, n_iter) 142 | ) 143 | torch.save({ 144 | 'model': gen.state_dict(), 'opt': opt_gen.state_dict(), 145 | }, gen_dst) 146 | shutil.copy(gen_dst, os.path.join(args.results_root, 'gen_latest.pth.tar')) 147 | dis_dst = os.path.join( 148 | args.results_root, 149 | 'dis_{}_iter_{:07d}.pth.tar'.format(count, n_iter) 150 | ) 151 | torch.save({ 152 | 'model': dis.state_dict(), 'opt': opt_dis.state_dict(), 153 | }, dis_dst) 154 | shutil.copy(dis_dst, os.path.join(args.results_root, 'dis_latest.pth.tar')) 155 | 156 | 157 | def resume_from_args(args_path, gen_ckpt_path, dis_ckpt_path): 158 | """Load generator & discriminator with their optimizers from args.json. 
159 | 160 | Args: 161 | args_path (str): Path to args.json 162 | gen_ckpt_path (str): Path to generator checkpoint or relative path 163 | from args['results_root'] 164 | dis_ckpt_path (str): Path to discriminator checkpoint or relative path 165 | from args['results_root'] 166 | 167 | Returns: 168 | gen, opt_dis 169 | dis, opt_dis 170 | 171 | """ 172 | 173 | from models.generators import resnet64 174 | from models.discriminators import snresnet64 175 | 176 | with open(args_path) as f: 177 | args = json.load(f) 178 | conditional = args['cGAN'] 179 | num_classes = args['num_classes'] if conditional else 0 180 | # Initialize generator 181 | gen = resnet64.ResNetGenerator( 182 | args['gen_num_features'], args['gen_dim_z'], args['gen_bottom_width'], 183 | num_classes=num_classes, distribution=args['gen_distribution'] 184 | ) 185 | opt_gen = torch.optim.Adam( 186 | gen.parameters(), args['lr'], (args['beta1'], args['beta2']) 187 | ) 188 | # Initialize discriminator 189 | if args['dis_arch_concat']: 190 | dis = snresnet64.SNResNetConcatDiscriminator( 191 | args['dis_num_features'], num_classes, dim_emb=args['dis_emb'] 192 | ) 193 | else: 194 | dis = snresnet64.SNResNetProjectionDiscriminator( 195 | args['dis_num_features'], num_classes 196 | ) 197 | opt_dis = torch.optim.Adam( 198 | dis.parameters(), args['lr'], (args['beta1'], args['beta2']) 199 | ) 200 | if not os.path.exists(gen_ckpt_path): 201 | gen_ckpt_path = os.path.join(args['results_root'], gen_ckpt_path) 202 | gen, opt_gen = load_model_optim(gen_ckpt_path, gen, opt_gen) 203 | if not os.path.exists(dis_ckpt_path): 204 | dis_ckpt_path = os.path.join(args['results_root'], dis_ckpt_path) 205 | dis, opt_dis = load_model_optim(dis_ckpt_path, dis, opt_dis) 206 | return Dict2Args(args), gen, opt_gen, dis, opt_dis 207 | 208 | 209 | def load_model_optim(checkpoint_path, model=None, optim=None): 210 | """Load trained weight. 
211 | 212 | Args: 213 | checkpoint_path (str) 214 | model (nn.Module) 215 | optim (torch.optim) 216 | 217 | Returns: 218 | model 219 | optim 220 | 221 | """ 222 | 223 | ckpt = torch.load(checkpoint_path) 224 | if model is not None: 225 | model.load_state_dict(ckpt['model']) 226 | if optim is not None: 227 | optim.load_state_dict(ckpt['opt']) 228 | return model, optim 229 | 230 | 231 | def load_model(checkpoint_path, model): 232 | """Load trained weight. 233 | 234 | Args: 235 | checkpoint_path (str) 236 | model (nn.Module) 237 | 238 | Returns: 239 | model 240 | 241 | """ 242 | 243 | return load_model_optim(checkpoint_path, model, None)[0] 244 | 245 | 246 | def load_optim(checkpoint_path, optim): 247 | """Load optimizer from checkpoint. 248 | 249 | Args: 250 | checkpoint_path (str) 251 | optim (torch.optim) 252 | 253 | Returns: 254 | optim 255 | 256 | """ 257 | 258 | return load_model_optim(checkpoint_path, None, optim)[1] 259 | -------------------------------------------------------------------------------- /train_64.py: -------------------------------------------------------------------------------- 1 | # Training script for tiny-imagenet. 2 | # Again, this script has a lot of bugs everywhere. 
3 | import argparse 4 | import datetime 5 | import json 6 | import os 7 | import shutil 8 | 9 | import numpy as np 10 | import torch 11 | import torch.utils.data as data 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import torchvision 15 | import torchvision.datasets as datasets 16 | import torchvision.transforms as transforms 17 | import tqdm 18 | 19 | import evaluation 20 | import losses as L 21 | from models.discriminators.snresnet64 import SNResNetConcatDiscriminator 22 | from models.discriminators.snresnet64 import SNResNetProjectionDiscriminator 23 | from models.generators.resnet64 import ResNetGenerator 24 | from models import inception 25 | import utils 26 | 27 | 28 | # Copied from https://github.com/naoto0804/pytorch-AdaIN/blob/master/sampler.py#L5-L15 29 | def InfiniteSampler(n): 30 | # i = 0 31 | i = n - 1 32 | order = np.random.permutation(n) 33 | while True: 34 | yield order[i] 35 | i += 1 36 | if i >= n: 37 | np.random.seed() 38 | order = np.random.permutation(n) 39 | i = 0 40 | 41 | 42 | # Copied from https://github.com/naoto0804/pytorch-AdaIN/blob/master/sampler.py#L18-L26 43 | class InfiniteSamplerWrapper(data.sampler.Sampler): 44 | def __init__(self, data_source): 45 | self.num_samples = len(data_source) 46 | 47 | def __iter__(self): 48 | return iter(InfiniteSampler(self.num_samples)) 49 | 50 | def __len__(self): 51 | return 2 ** 31 52 | 53 | 54 | def prepare_results_dir(args): 55 | """Makedir, init tensorboard if required, save args.""" 56 | if args.test: 57 | import tempfile 58 | args.results_root = tempfile.mkdtemp() 59 | args.max_iteration = 10 60 | args.log_interval = 2 61 | args.eval_interval = 10 62 | args.n_fid_images = 100 63 | args.n_eval_batches = 10 64 | root = os.path.join(args.results_root, "cGAN" if args.cGAN else "SNGAN", 65 | datetime.datetime.now().strftime('%y%m%d_%H%M')) 66 | os.makedirs(root, exist_ok=True) 67 | if not args.no_tensorboard: 68 | from tensorboardX import SummaryWriter 69 | writer = 
SummaryWriter(root) 70 | else: 71 | writer = None 72 | 73 | train_image_root = os.path.join(root, "preview", "train") 74 | eval_image_root = os.path.join(root, "preview", "eval") 75 | os.makedirs(train_image_root, exist_ok=True) 76 | os.makedirs(eval_image_root, exist_ok=True) 77 | 78 | args.results_root = root 79 | args.train_image_root = train_image_root 80 | args.eval_image_root = eval_image_root 81 | 82 | if args.cGAN: 83 | if args.num_classes > args.n_eval_batches: 84 | args.n_eval_batches = args.num_classes 85 | if args.eval_batch_size is None: 86 | args.eval_batch_size = args.batch_size // 4 87 | 88 | if args.calc_FID: 89 | args.n_fid_batches = args.n_fid_images // args.batch_size 90 | 91 | with open(os.path.join(root, 'args.json'), 'w') as f: 92 | json.dump(args.__dict__, f, indent=2) 93 | print(json.dumps(args.__dict__, indent=2)) 94 | return args, writer 95 | 96 | 97 | def decay_lr(opt, max_iter, start_iter, initial_lr): 98 | """Decay learning rate linearly till 0.""" 99 | coeff = -initial_lr / (max_iter - start_iter) 100 | for pg in opt.param_groups: 101 | pg['lr'] += coeff 102 | 103 | 104 | def get_args(): 105 | parser = argparse.ArgumentParser() 106 | # Dataset configuration 107 | parser.add_argument('--cGAN', default=False, action='store_true', 108 | help='to train cGAN, set this ``True``. default: False') 109 | parser.add_argument('--data_root', type=str, default='tiny-imagenet-200', 110 | help='path to dataset root directory. default: tiny-imagenet-200') 111 | parser.add_argument('--batch_size', '-B', type=int, default=64, 112 | help='mini-batch size of training data. default: 64') 113 | parser.add_argument('--eval_batch_size', '-eB', default=None, 114 | help='mini-batch size of evaluation data. default: None') 115 | parser.add_argument('--num_workers', type=int, default=8, 116 | help='Number of workers for training data loader. 
default: 8') 117 | # Generator configuration 118 | parser.add_argument('--gen_num_features', '-gnf', type=int, default=64, 119 | help='Number of features of generator (a.k.a. nplanes or ngf). default: 64') 120 | parser.add_argument('--gen_dim_z', '-gdz', type=int, default=128, 121 | help='Dimension of generator input noise. default: 128') 122 | parser.add_argument('--gen_bottom_width', '-gbw', type=int, default=4, 123 | help='Initial size of hidden variable of generator. default: 4') 124 | parser.add_argument('--gen_distribution', '-gd', type=str, default='normal', 125 | help='Input noise distribution: normal (default) or uniform.') 126 | # Discriminator (Critic) configuration 127 | parser.add_argument('--dis_arch_concat', '-concat', default=False, action='store_true', 128 | help='If use concat discriminator, set this true. default: False') 129 | parser.add_argument('--dis_emb', type=int, default=128, 130 | help='Parameter for concat discriminator. default: 128') 131 | parser.add_argument('--dis_num_features', '-dnf', type=int, default=64, 132 | help='Number of features of discriminator (a.k.a nplanes or ndf). default: 64') 133 | # Optimizer settings 134 | parser.add_argument('--lr', type=float, default=0.0002, 135 | help='Initial learning rate of Adam. default: 0.0002') 136 | parser.add_argument('--beta1', type=float, default=0.0, 137 | help='beta1 (betas[0]) value of Adam. default: 0.0') 138 | parser.add_argument('--beta2', type=float, default=0.9, 139 | help='beta2 (betas[1]) value of Adam. default: 0.9') 140 | parser.add_argument('--lr_decay_start', '-lds', type=int, default=50000, 141 | help='Start point of learning rate decay. default: 50000') 142 | # Training setting 143 | parser.add_argument('--seed', type=int, default=46, 144 | help='Random seed. default: 46 (derived from Nogizaka46)') 145 | parser.add_argument('--max_iteration', '-N', type=int, default=100000, 146 | help='Max iteration number of training. 
default: 100000') 147 | parser.add_argument('--n_dis', type=int, default=5, 148 | help='Number of discriminator updater per generator updater. default: 5') 149 | parser.add_argument('--num_classes', '-nc', type=int, default=0, 150 | help='Number of classes in training data. No need to set. default: 0') 151 | parser.add_argument('--loss_type', type=str, default='hinge', 152 | help='loss function name. hinge (default) or dcgan.') 153 | parser.add_argument('--relativistic_loss', '-relloss', default=False, action='store_true', 154 | help='Apply relativistic loss or not. default: False') 155 | parser.add_argument('--calc_FID', default=False, action='store_true', 156 | help='If calculate FID score, set this ``True``. default: False') 157 | # Log and Save interval configuration 158 | parser.add_argument('--results_root', type=str, default='results', 159 | help='Path to results directory. default: results') 160 | parser.add_argument('--no_tensorboard', action='store_true', default=False, 161 | help='If you dislike tensorboard, set this ``False``. default: True') 162 | parser.add_argument('--no_image', action='store_true', default=False, 163 | help='If you dislike saving images on tensorboard, set this ``True``. default: False') 164 | parser.add_argument('--checkpoint_interval', '-ci', type=int, default=1000, 165 | help='Interval of saving checkpoints (model and optimizer). default: 1000') 166 | parser.add_argument('--log_interval', '-li', type=int, default=100, 167 | help='Interval of showing losses. default: 100') 168 | parser.add_argument('--eval_interval', '-ei', type=int, default=1000, 169 | help='Interval for evaluation (save images and FID calculation). default: 1000') 170 | parser.add_argument('--n_eval_batches', '-neb', type=int, default=100, 171 | help='Number of mini-batches used in evaluation. default: 100') 172 | parser.add_argument('--n_fid_images', '-nfi', type=int, default=5000, 173 | help='Number of images to calculate FID. 
default: 5000') 174 | parser.add_argument('--test', default=False, action='store_true', 175 | help='If test this python program, set this ``True``. default: False') 176 | # Resume training 177 | parser.add_argument('--args_path', default=None, help='Checkpoint args json path. default: None') 178 | parser.add_argument('--gen_ckpt_path', '-gcp', default=None, 179 | help='Generator and optimizer checkpoint path. default: None') 180 | parser.add_argument('--dis_ckpt_path', '-dcp', default=None, 181 | help='Discriminator and optimizer checkpoint path. default: None') 182 | args = parser.parse_args() 183 | return args 184 | 185 | 186 | def sample_from_data(args, device, data_loader): 187 | """Sample real images and labels from data_loader. 188 | 189 | Args: 190 | args (argparse object) 191 | device (torch.device) 192 | data_loader (DataLoader) 193 | 194 | Returns: 195 | real, y 196 | 197 | """ 198 | 199 | real, y = next(data_loader) 200 | real, y = real.to(device), y.to(device) 201 | if not args.cGAN: 202 | y = None 203 | return real, y 204 | 205 | 206 | def sample_from_gen(args, device, num_classes, gen): 207 | """Sample fake images and labels from generator. 
208 | 209 | Args: 210 | args (argparse object) 211 | device (torch.device) 212 | num_classes (int): for pseudo_y 213 | gen (nn.Module) 214 | 215 | Returns: 216 | fake, pseudo_y, z 217 | 218 | """ 219 | 220 | z = utils.sample_z( 221 | args.batch_size, args.gen_dim_z, device, args.gen_distribution 222 | ) 223 | if args.cGAN: 224 | pseudo_y = utils.sample_pseudo_labels( 225 | num_classes, args.batch_size, device 226 | ) 227 | else: 228 | pseudo_y = None 229 | 230 | fake = gen(z, pseudo_y) 231 | return fake, pseudo_y, z 232 | 233 | 234 | def main(): 235 | args = get_args() 236 | # CUDA setting 237 | if not torch.cuda.is_available(): 238 | raise ValueError("Should buy GPU!") 239 | torch.manual_seed(args.seed) 240 | torch.cuda.manual_seed_all(args.seed) 241 | device = torch.device('cuda') 242 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 243 | torch.backends.cudnn.benchmark = True 244 | 245 | def _rescale(img): 246 | return img * 2.0 - 1.0 247 | 248 | def _noise_adder(img): 249 | return torch.empty_like(img, dtype=img.dtype).uniform_(0.0, 1/128.0) + img 250 | 251 | # dataset 252 | train_dataset = datasets.ImageFolder( 253 | os.path.join(args.data_root, 'train'), 254 | transforms.Compose([ 255 | transforms.ToTensor(), _rescale, _noise_adder, 256 | ]) 257 | ) 258 | train_loader = iter(data.DataLoader( 259 | train_dataset, args.batch_size, 260 | sampler=InfiniteSamplerWrapper(train_dataset), 261 | num_workers=args.num_workers, pin_memory=True) 262 | ) 263 | if args.calc_FID: 264 | eval_dataset = datasets.ImageFolder( 265 | os.path.join(args.data_root, 'val'), 266 | transforms.Compose([ 267 | transforms.ToTensor(), _rescale, 268 | ]) 269 | ) 270 | eval_loader = iter(data.DataLoader( 271 | eval_dataset, args.batch_size, 272 | sampler=InfiniteSamplerWrapper(eval_dataset), 273 | num_workers=args.num_workers, pin_memory=True) 274 | ) 275 | else: 276 | eval_loader = None 277 | num_classes = len(train_dataset.classes) 278 | print(' prepared datasets...') 279 | print(' 
Number of training images: {}'.format(len(train_dataset))) 280 | # Prepare directories. 281 | args.num_classes = num_classes 282 | args, writer = prepare_results_dir(args) 283 | # initialize models. 284 | _n_cls = num_classes if args.cGAN else 0 285 | gen = ResNetGenerator( 286 | args.gen_num_features, args.gen_dim_z, args.gen_bottom_width, 287 | activation=F.relu, num_classes=_n_cls, distribution=args.gen_distribution 288 | ).to(device) 289 | if args.dis_arch_concat: 290 | dis = SNResNetConcatDiscriminator( 291 | args.dis_num_features, _n_cls, F.relu, args.dis_emb).to(device) 292 | else: 293 | dis = SNResNetProjectionDiscriminator( 294 | args.dis_num_features, _n_cls, F.relu).to(device) 295 | inception_model = inception.InceptionV3().to(device) if args.calc_FID else None 296 | 297 | opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2)) 298 | opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2)) 299 | 300 | # gen_criterion = getattr(L, 'gen_{}'.format(args.loss_type)) 301 | # dis_criterion = getattr(L, 'dis_{}'.format(args.loss_type)) 302 | gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss) 303 | dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss) 304 | print(' Initialized models...\n') 305 | 306 | if args.args_path is not None: 307 | print(' Load weights...\n') 308 | prev_args, gen, opt_gen, dis, opt_dis = utils.resume_from_args( 309 | args.args_path, args.gen_ckpt_path, args.dis_ckpt_path 310 | ) 311 | 312 | # Training loop 313 | for n_iter in tqdm.tqdm(range(1, args.max_iteration + 1)): 314 | 315 | if n_iter >= args.lr_decay_start: 316 | decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr) 317 | decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr) 318 | 319 | # ==================== Beginning of 1 iteration. 
==================== 320 | _l_g = .0 321 | cumulative_loss_dis = .0 322 | for i in range(args.n_dis): 323 | if i == 0: 324 | fake, pseudo_y, _ = sample_from_gen(args, device, num_classes, gen) 325 | dis_fake = dis(fake, pseudo_y) 326 | if args.relativistic_loss: 327 | real, y = sample_from_data(args, device, train_loader) 328 | dis_real = dis(real, y) 329 | else: 330 | dis_real = None 331 | 332 | loss_gen = gen_criterion(dis_fake, dis_real) 333 | gen.zero_grad() 334 | loss_gen.backward() 335 | opt_gen.step() 336 | _l_g += loss_gen.item() 337 | if n_iter % 10 == 0 and writer is not None: 338 | writer.add_scalar('gen', _l_g, n_iter) 339 | 340 | fake, pseudo_y, _ = sample_from_gen(args, device, num_classes, gen) 341 | real, y = sample_from_data(args, device, train_loader) 342 | 343 | dis_fake, dis_real = dis(fake, pseudo_y), dis(real, y) 344 | loss_dis = dis_criterion(dis_fake, dis_real) 345 | 346 | dis.zero_grad() 347 | loss_dis.backward() 348 | opt_dis.step() 349 | 350 | cumulative_loss_dis += loss_dis.item() 351 | if n_iter % 10 == 0 and i == args.n_dis - 1 and writer is not None: 352 | cumulative_loss_dis /= args.n_dis 353 | writer.add_scalar('dis', cumulative_loss_dis / args.n_dis, n_iter) 354 | # ==================== End of 1 iteration. ==================== 355 | 356 | if n_iter % args.log_interval == 0: 357 | tqdm.tqdm.write( 358 | 'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'.format( 359 | n_iter, args.max_iteration, _l_g, cumulative_loss_dis)) 360 | if not args.no_image: 361 | writer.add_image( 362 | 'fake', torchvision.utils.make_grid( 363 | fake, nrow=4, normalize=True, scale_each=True)) 364 | writer.add_image( 365 | 'real', torchvision.utils.make_grid( 366 | real, nrow=4, normalize=True, scale_each=True)) 367 | # Save previews 368 | utils.save_images( 369 | n_iter, n_iter // args.checkpoint_interval, args.results_root, 370 | args.train_image_root, fake, real 371 | ) 372 | if n_iter % args.checkpoint_interval == 0: 373 | # Save checkpoints! 
374 | utils.save_checkpoints( 375 | args, n_iter, n_iter // args.checkpoint_interval, 376 | gen, opt_gen, dis, opt_dis 377 | ) 378 | if n_iter % args.eval_interval == 0: 379 | # TODO (crcrpar): implement Ineption score, FID, and Geometry score 380 | # Once these criterion are prepared, val_loader will be used. 381 | fid_score = evaluation.evaluate( 382 | args, n_iter, gen, device, inception_model, eval_loader 383 | ) 384 | tqdm.tqdm.write( 385 | '[Eval] iteration: {:07d}/{:07d}, FID: {:07f}'.format( 386 | n_iter, args.max_iteration, fid_score)) 387 | if writer is not None: 388 | writer.add_scalar("FID", fid_score, n_iter) 389 | # Project embedding weights if exists. 390 | embedding_layer = getattr(dis, 'l_y', None) 391 | if embedding_layer is not None: 392 | writer.add_embedding( 393 | embedding_layer.weight.data, 394 | list(range(args.num_classes)), 395 | global_step=n_iter 396 | ) 397 | if args.test: 398 | shutil.rmtree(args.results_root) 399 | 400 | 401 | if __name__ == '__main__': 402 | main() 403 | --------------------------------------------------------------------------------