├── LICENSE ├── README.md ├── cfg.py ├── datasets.py ├── exps ├── eval.sh └── sngan_cifar10.sh ├── functions.py ├── models ├── __init__.py ├── gen_resblock.py ├── sngan_64.py ├── sngan_cifar10.py └── sngan_stl10.py ├── requirements.txt ├── test.py ├── train.py └── utils ├── __init__.py ├── cal_fid_stat.py ├── fid_score.py ├── inception_score.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Xinyu Gong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SNGAN.pytorch 2 | An unofficial Pytorch implementation of [Spectral Normalization for Generative Adversarial Networks](https://openreview.net/pdf?id=B1QRgziT-). 
3 | For the official Chainer implementation, please refer to [https://github.com/pfnet-research/sngan_projection](https://github.com/pfnet-research/sngan_projection) 4 | 5 | Our implementation achieves an Inception score of **8.21** and an FID score of **14.21** on the unconditional CIFAR-10 image generation task. 6 | In comparison, the original paper reports **8.22** and **21.7**, respectively. 7 | 8 | ## Set-up 9 | 10 | ### Install libraries 11 | ```bash 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ### Prepare the FID statistics file 16 | ```bash 17 | mkdir fid_stat 18 | ``` 19 | Download the pre-calculated statistics for CIFAR-10, 20 | [fid_stats_cifar10_train.npz](http://bioinf.jku.at/research/ttur/ttur_stats/fid_stats_cifar10_train.npz), to `./fid_stat`. 21 | 22 | ### Train 23 | ```bash 24 | sh exps/sngan_cifar10.sh 25 | ``` 26 | 27 | ### Test 28 | ```bash 29 | mkdir pre_trained 30 | ``` 31 | Download the pre-trained SNGAN model [sngan_cifar10.pth](https://drive.google.com/file/d/1koEJbx9anP2-BEMrqX6jgWXAvEUXG0AU/view?usp=sharing) to `./pre_trained`. 32 | Run the following script: 33 | ```bash 34 | sh exps/eval.sh 35 | ``` 36 | 37 | ## Acknowledgement 38 | 39 | 1. Inception Score code from [OpenAI's Improved GAN](https://github.com/openai/improved-gan/tree/master/inception_score) (official). 40 | 2. FID code and statistics file from [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) (official). 41 | 3. The Spectral Norm GAN code is inspired by [https://github.com/pfnet-research/sngan_projection](https://github.com/pfnet-research/sngan_projection) (official).
42 | -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-25 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | import argparse 8 | 9 | 10 | def str2bool(v): 11 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 12 | return True 13 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError('Boolean value expected.') 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | '--max_epoch', 23 | type=int, 24 | default=200, 25 | help='number of epochs of training') 26 | parser.add_argument( 27 | '--max_iter', 28 | type=int, 29 | default=None, 30 | help='set the max iteration number') 31 | parser.add_argument( 32 | '-gen_bs', 33 | '--gen_batch_size', 34 | type=int, 35 | default=64, 36 | help='size of the batches') 37 | parser.add_argument( 38 | '-dis_bs', 39 | '--dis_batch_size', 40 | type=int, 41 | default=64, 42 | help='size of the batches') 43 | parser.add_argument( 44 | '--g_lr', 45 | type=float, 46 | default=0.0002, 47 | help='adam: gen learning rate') 48 | parser.add_argument( 49 | '--d_lr', 50 | type=float, 51 | default=0.0002, 52 | help='adam: disc learning rate') 53 | parser.add_argument( 54 | '--lr_decay', 55 | action='store_true', 56 | help='learning rate decay or not') 57 | parser.add_argument( 58 | '--beta1', 59 | type=float, 60 | default=0.0, 61 | help='adam: decay of first order momentum of gradient') 62 | parser.add_argument( 63 | '--beta2', 64 | type=float, 65 | default=0.9, 66 | help='adam: decay of first order momentum of gradient') 67 | parser.add_argument( 68 | '--num_workers', 69 | type=int, 70 | default=8, 71 | help='number of cpu threads to use during batch generation') 72 | parser.add_argument( 73 | '--latent_dim', 74 | type=int, 75 | 
default=128, 76 | help='dimensionality of the latent space') 77 | parser.add_argument( 78 | '--img_size', 79 | type=int, 80 | default=32, 81 | help='size of each image dimension') 82 | parser.add_argument( 83 | '--channels', 84 | type=int, 85 | default=3, 86 | help='number of image channels') 87 | parser.add_argument( 88 | '--n_critic', 89 | type=int, 90 | default=1, 91 | help='number of training steps for discriminator per iter') # functions.train steps G only when global_steps % n_critic == 0 92 | parser.add_argument( 93 | '--val_freq', 94 | type=int, 95 | default=20, 96 | help='interval between each validation') 97 | parser.add_argument( 98 | '--print_freq', 99 | type=int, 100 | default=50, 101 | help='interval between each verbose') # functions.train prints a progress line when iter_idx % print_freq == 0 102 | parser.add_argument( 103 | '--load_path', 104 | type=str, 105 | help='The reload model path') 106 | parser.add_argument( 107 | '--exp_name', 108 | type=str, 109 | help='The name of exp') 110 | parser.add_argument( 111 | '--d_spectral_norm', 112 | type=str2bool, # parsed with str2bool, so the CLI takes e.g. True/False 113 | default=False, 114 | help='add spectral_norm on discriminator?') 115 | parser.add_argument( 116 | '--g_spectral_norm', 117 | type=str2bool, 118 | default=False, 119 | help='add spectral_norm on generator?') 120 | parser.add_argument( 121 | '--dataset', 122 | type=str, 123 | default='cifar10', 124 | help='dataset type') 125 | parser.add_argument( 126 | '--data_path', 127 | type=str, 128 | default='./data', 129 | help='The path of data set') 130 | parser.add_argument('--init_type', type=str, default='normal', 131 | choices=['normal', 'orth', 'xavier_uniform', 'false'], 132 | help='The init type') 133 | parser.add_argument('--gf_dim', type=int, default=64, 134 | help='The base channel num of gen') 135 | parser.add_argument('--df_dim', type=int, default=64, 136 | help='The base channel num of disc') 137 | parser.add_argument( 138 | '--model', 139 | type=str, 140 | default='sngan_cifar10', 141 | help='path of model') 142 | parser.add_argument('--eval_batch_size', type=int, default=100) # batch size used by functions.validate when sampling images for IS/FID 143 | parser.add_argument('--num_eval_imgs', 
type=int, default=50000) # images sampled per evaluation; functions.validate generates num_eval_imgs // eval_batch_size full batches 144 | parser.add_argument( 145 | '--bottom_width', 146 | type=int, 147 | default=4, 148 | help="the base resolution of the GAN") # side length of the generator's first feature map (l1 output is viewed as bottom_width x bottom_width) 149 | parser.add_argument('--random_seed', type=int, default=12345) 150 | 151 | opt = parser.parse_args() 152 | return opt 153 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.datasets as datasets 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import Dataset # NOTE(review): not used in this file 5 | 6 | 7 | class ImageDataset(object): # builds train/valid/test DataLoaders for args.dataset ('cifar10' or 'stl10') 8 | def __init__(self, args): 9 | if args.dataset.lower() == 'cifar10': 10 | Dt = datasets.CIFAR10 11 | transform = transforms.Compose([ 12 | transforms.Resize(args.img_size), 13 | transforms.ToTensor(), 14 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # (x - 0.5) / 0.5 per channel 15 | ]) 16 | args.n_classes = 10 # NOTE(review): only the cifar10 branch sets args.n_classes; the stl10 branch leaves it unset - confirm nothing reads it for stl10 17 | elif args.dataset.lower() == 'stl10': 18 | Dt = datasets.STL10 19 | transform = transforms.Compose([ 20 | transforms.Resize(args.img_size), 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 23 | ]) 24 | else: 25 | raise NotImplementedError('Unknown dataset: {}'.format(args.dataset)) 26 | 27 | if args.dataset.lower() == 'stl10': # STL10 selects data with split= rather than train= 28 | self.train = torch.utils.data.DataLoader( 29 | Dt(root=args.data_path, split='train+unlabeled', transform=transform, download=True), 30 | batch_size=args.dis_batch_size, shuffle=True, 31 | num_workers=args.num_workers, pin_memory=True) 32 | 33 | self.valid = torch.utils.data.DataLoader( 34 | Dt(root=args.data_path, split='test', transform=transform), 35 | batch_size=args.dis_batch_size, shuffle=False, 36 | num_workers=args.num_workers, pin_memory=True) 37 | 38 | self.test = self.valid # test and valid are the same loader object 39 | else: 40 | self.train = torch.utils.data.DataLoader( 41 | Dt(root=args.data_path, train=True, transform=transform, download=True), 42 | batch_size=args.dis_batch_size,
shuffle=True, 43 | num_workers=args.num_workers, pin_memory=True) 44 | 45 | self.valid = torch.utils.data.DataLoader( 46 | Dt(root=args.data_path, train=False, transform=transform), 47 | batch_size=args.dis_batch_size, shuffle=False, 48 | num_workers=args.num_workers, pin_memory=True) 49 | 50 | self.test = self.valid 51 | -------------------------------------------------------------------------------- /exps/eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=4 4 | python test.py \ 5 | --img_size 32 \ 6 | --model sngan_cifar10 \ 7 | --latent_dim 128 \ 8 | --gf_dim 256 \ 9 | --g_spectral_norm False \ 10 | --load_path pre_trained/sngan_cifar10.pth \ 11 | --exp_name test_sngan_cifar10 -------------------------------------------------------------------------------- /exps/sngan_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | python train.py \ 5 | -gen_bs 128 \ 6 | -dis_bs 64 \ 7 | --dataset cifar10 \ 8 | --img_size 32 \ 9 | --max_iter 50000 \ 10 | --model sngan_cifar10 \ 11 | --latent_dim 128 \ 12 | --gf_dim 256 \ 13 | --df_dim 128 \ 14 | --g_spectral_norm False \ 15 | --d_spectral_norm True \ 16 | --g_lr 0.0002 \ 17 | --d_lr 0.0002 \ 18 | --beta1 0.0 \ 19 | --beta2 0.9 \ 20 | --init_type xavier_uniform \ 21 | --n_critic 5 \ 22 | --val_freq 20 \ 23 | --exp_name sngan_cifar10 -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-25 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | import os 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torchvision.utils import make_grid 12 | from imageio import imsave 13 | from tqdm import 
tqdm 14 | from copy import deepcopy 15 | import logging 16 | 17 | from utils.inception_score import get_inception_score 18 | from utils.fid_score import calculate_fid_given_paths 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def train(args, gen_net: nn.Module, dis_net: nn.Module, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch, 25 | writer_dict, schedulers=None): 26 | writer = writer_dict['writer'] 27 | gen_step = 0 28 | 29 | # train mode 30 | gen_net = gen_net.train() 31 | dis_net = dis_net.train() 32 | 33 | for iter_idx, (imgs, _) in enumerate(tqdm(train_loader)): 34 | global_steps = writer_dict['train_global_steps'] 35 | 36 | # Adversarial ground truths 37 | real_imgs = imgs.type(torch.cuda.FloatTensor) 38 | 39 | # Sample noise as generator input 40 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], args.latent_dim))) 41 | 42 | # --------------------- 43 | # Train Discriminator 44 | # --------------------- 45 | dis_optimizer.zero_grad() 46 | 47 | real_validity = dis_net(real_imgs) 48 | fake_imgs = gen_net(z).detach() 49 | assert fake_imgs.size() == real_imgs.size() 50 | 51 | fake_validity = dis_net(fake_imgs) 52 | 53 | # cal loss 54 | d_loss = torch.mean(nn.ReLU(inplace=True)(1.0 - real_validity)) + \ 55 | torch.mean(nn.ReLU(inplace=True)(1 + fake_validity)) 56 | d_loss.backward() 57 | dis_optimizer.step() 58 | 59 | writer.add_scalar('d_loss', d_loss.item(), global_steps) 60 | 61 | # ----------------- 62 | # Train Generator 63 | # ----------------- 64 | if global_steps % args.n_critic == 0: 65 | gen_optimizer.zero_grad() 66 | 67 | gen_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.gen_batch_size, args.latent_dim))) 68 | gen_imgs = gen_net(gen_z) 69 | fake_validity = dis_net(gen_imgs) 70 | 71 | # cal loss 72 | g_loss = -torch.mean(fake_validity) 73 | g_loss.backward() 74 | gen_optimizer.step() 75 | 76 | # adjust learning rate 77 | if schedulers: 78 | gen_scheduler, dis_scheduler = schedulers 79 | 
g_lr = gen_scheduler.step(global_steps) 80 | d_lr = dis_scheduler.step(global_steps) 81 | writer.add_scalar('LR/g_lr', g_lr, global_steps) 82 | writer.add_scalar('LR/d_lr', d_lr, global_steps) 83 | 84 | # moving average weight 85 | for p, avg_p in zip(gen_net.parameters(), gen_avg_param): 86 | avg_p.mul_(0.999).add_(0.001, p.data) 87 | 88 | writer.add_scalar('g_loss', g_loss.item(), global_steps) 89 | gen_step += 1 90 | 91 | # verbose 92 | if gen_step and iter_idx % args.print_freq == 0: 93 | tqdm.write( 94 | "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % 95 | (epoch, args.max_epoch, iter_idx % len(train_loader), len(train_loader), d_loss.item(), g_loss.item())) 96 | 97 | writer_dict['train_global_steps'] = global_steps + 1 98 | 99 | 100 | def validate(args, fixed_z, fid_stat, gen_net: nn.Module, writer_dict): 101 | writer = writer_dict['writer'] 102 | global_steps = writer_dict['valid_global_steps'] 103 | 104 | # eval mode 105 | gen_net = gen_net.eval() 106 | 107 | # generate images 108 | sample_imgs = gen_net(fixed_z) 109 | img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True) 110 | 111 | # get fid and inception score 112 | fid_buffer_dir = os.path.join(args.path_helper['sample_path'], 'fid_buffer') 113 | os.makedirs(fid_buffer_dir) 114 | 115 | eval_iter = args.num_eval_imgs // args.eval_batch_size 116 | img_list = list() 117 | for iter_idx in tqdm(range(eval_iter), desc='sample images'): 118 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim))) 119 | 120 | # Generate a batch of images 121 | gen_imgs = gen_net(z).mul_(127.5).add_(127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy() 122 | for img_idx, img in enumerate(gen_imgs): 123 | file_name = os.path.join(fid_buffer_dir, f'iter{iter_idx}_b{img_idx}.png') 124 | imsave(file_name, img) 125 | img_list.extend(list(gen_imgs)) 126 | 127 | # get inception score 128 | logger.info('=> calculate inception score') 129 | mean, 
std = get_inception_score(img_list) 130 | 131 | # get fid score 132 | logger.info('=> calculate fid score') 133 | fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat], inception_path=None) 134 | 135 | os.system('rm -r {}'.format(fid_buffer_dir)) 136 | 137 | writer.add_image('sampled_images', img_grid, global_steps) 138 | writer.add_scalar('Inception_score/mean', mean, global_steps) 139 | writer.add_scalar('Inception_score/std', std, global_steps) 140 | writer.add_scalar('FID_score', fid_score, global_steps) 141 | 142 | writer_dict['valid_global_steps'] = global_steps + 1 143 | 144 | return mean, fid_score 145 | 146 | 147 | class LinearLrDecay(object): 148 | def __init__(self, optimizer, start_lr, end_lr, decay_start_step, decay_end_step): 149 | 150 | assert start_lr > end_lr 151 | self.optimizer = optimizer 152 | self.delta = (start_lr - end_lr) / (decay_end_step - decay_start_step) 153 | self.decay_start_step = decay_start_step 154 | self.decay_end_step = decay_end_step 155 | self.start_lr = start_lr 156 | self.end_lr = end_lr 157 | 158 | def step(self, current_step): 159 | if current_step <= self.decay_start_step: 160 | lr = self.start_lr 161 | elif current_step >= self.decay_end_step: 162 | lr = self.end_lr 163 | else: 164 | lr = self.start_lr - self.delta * (current_step - self.decay_start_step) 165 | for param_group in self.optimizer.param_groups: 166 | param_group['lr'] = lr 167 | return lr 168 | 169 | 170 | def load_params(model, new_param): 171 | for p, new_p in zip(model.parameters(), new_param): 172 | p.data.copy_(new_p) 173 | 174 | 175 | def copy_params(model): 176 | flatten = deepcopy(list(p.data for p in model.parameters())) 177 | return flatten 178 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-25 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link 
: None 5 | # @Version : 0.0 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import models.sngan_cifar10 12 | import models.sngan_stl10 # NOTE(review): models.sngan_64 is not imported here - confirm whether that is intentional 13 | -------------------------------------------------------------------------------- /models/gen_resblock.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 3/26/20 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | import torch.nn as nn 8 | 9 | 10 | class GenBlock(nn.Module): # generator residual block: BN-act-(2x upsample)-conv-BN-act-conv, added to a shortcut 11 | def __init__(self, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1, 12 | activation=nn.ReLU(), upsample=False, n_classes=0): 13 | super(GenBlock, self).__init__() 14 | self.activation = activation 15 | self.upsample = upsample 16 | self.learnable_sc = in_channels != out_channels or upsample # 1x1-conv shortcut when channels change or when upsampling 17 | hidden_channels = out_channels if hidden_channels is None else hidden_channels 18 | self.n_classes = n_classes # stored only; not used by this block 19 | self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad) 20 | self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, padding=pad) 21 | 22 | self.b1 = nn.BatchNorm2d(in_channels) 23 | self.b2 = nn.BatchNorm2d(hidden_channels) 24 | if self.learnable_sc: 25 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 26 | 27 | def upsample_conv(self, x, conv): # nearest-neighbour 2x upsampling, then `conv` 28 | return conv(nn.UpsamplingNearest2d(scale_factor=2)(x)) 29 | 30 | def residual(self, x): 31 | h = x 32 | h = self.b1(h) 33 | h = self.activation(h) 34 | h = self.upsample_conv(h, self.c1) if self.upsample else self.c1(h) 35 | h = self.b2(h) 36 | h = self.activation(h) 37 | h = self.c2(h) 38 | return h 39 | 40 | def shortcut(self, x): # identity unless learnable_sc, then (upsample +) 1x1 conv 41 | if self.learnable_sc: 42 | x = self.upsample_conv(x, self.c_sc) if self.upsample else self.c_sc(x) 43 | return x 44 | else: 45 | return x 46 | 47 | def forward(self, x): 48 | return
self.residual(x) + self.shortcut(x) -------------------------------------------------------------------------------- /models/sngan_64.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class GenBlock(nn.Module): # NOTE(review): duplicate of models/gen_resblock.GenBlock, which sngan_cifar10 imports instead 5 | def __init__(self, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1, 6 | activation=nn.ReLU(), upsample=False, n_classes=0): 7 | super(GenBlock, self).__init__() 8 | self.activation = activation 9 | self.upsample = upsample 10 | self.learnable_sc = in_channels != out_channels or upsample 11 | hidden_channels = out_channels if hidden_channels is None else hidden_channels 12 | self.n_classes = n_classes 13 | self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad) 14 | self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, padding=pad) 15 | 16 | self.b1 = nn.BatchNorm2d(in_channels) 17 | self.b2 = nn.BatchNorm2d(hidden_channels) 18 | if self.learnable_sc: 19 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 20 | 21 | def upsample_conv(self, x, conv): 22 | return conv(nn.UpsamplingNearest2d(scale_factor=2)(x)) 23 | 24 | def residual(self, x): 25 | h = x 26 | h = self.b1(h) 27 | h = self.activation(h) 28 | h = self.upsample_conv(h, self.c1) if self.upsample else self.c1(h) 29 | h = self.b2(h) 30 | h = self.activation(h) 31 | h = self.c2(h) 32 | return h 33 | 34 | def shortcut(self, x): 35 | if self.learnable_sc: 36 | x = self.upsample_conv(x, self.c_sc) if self.upsample else self.c_sc(x) 37 | return x 38 | else: 39 | return x 40 | 41 | def forward(self, x): 42 | return self.residual(x) + self.shortcut(x) 43 | 44 | 45 | class Generator(nn.Module): 46 | def __init__(self, args, activation=nn.ReLU(), n_classes=0): 47 | super(Generator, self).__init__() 48 | self.bottom_width = args.bottom_width 49 | self.activation = activation 50 | self.n_classes = n_classes 51 | self.ch = args.gf_dim # every generator stage uses gf_dim channels 52 | self.l1 =
nn.Linear(args.latent_dim, (self.bottom_width ** 2) * self.ch) 53 | self.block2 = GenBlock(self.ch, self.ch, activation=activation, upsample=True, n_classes=n_classes) 54 | self.block3 = GenBlock(self.ch, self.ch, activation=activation, upsample=True, n_classes=n_classes) 55 | self.block4 = GenBlock(self.ch, self.ch, activation=activation, upsample=True, n_classes=n_classes) # NOTE(review): three 2x upsampling blocks give bottom_width * 8 pixels (32 with the default bottom_width=4) - confirm the setting used for 64x64 images 56 | self.b5 = nn.BatchNorm2d(self.ch) 57 | self.c5 = nn.Conv2d(self.ch, 3, kernel_size=3, stride=1, padding=1) 58 | 59 | def forward(self, z): 60 | 61 | h = z 62 | h = self.l1(h).view(-1, self.ch, self.bottom_width, self.bottom_width) 63 | h = self.block2(h) 64 | h = self.block3(h) 65 | h = self.block4(h) 66 | h = self.b5(h) 67 | h = self.activation(h) 68 | h = nn.Tanh()(self.c5(h)) 69 | return h 70 | 71 | 72 | """Discriminator""" 73 | 74 | 75 | def _downsample(x): 76 | # Downsample (Mean Avg Pooling with 2x2 kernel) 77 | return nn.AvgPool2d(kernel_size=2)(x) 78 | 79 | 80 | class OptimizedDisBlock(nn.Module): # first discriminator block: conv-act-conv-downsample, with no activation before the first conv 81 | def __init__(self, args, in_channels, out_channels, ksize=3, pad=1, activation=nn.ReLU()): 82 | super(OptimizedDisBlock, self).__init__() 83 | self.activation = activation 84 | 85 | self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=pad) 86 | self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=ksize, padding=pad) 87 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 88 | if args.d_spectral_norm: 89 | self.c1 = nn.utils.spectral_norm(self.c1) 90 | self.c2 = nn.utils.spectral_norm(self.c2) 91 | self.c_sc = nn.utils.spectral_norm(self.c_sc) 92 | 93 | def residual(self, x): 94 | h = x 95 | h = self.c1(h) 96 | h = self.activation(h) 97 | h = self.c2(h) 98 | h = _downsample(h) 99 | return h 100 | 101 | def shortcut(self, x): 102 | return self.c_sc(_downsample(x)) # shortcut: 2x2 average pool, then 1x1 conv 103 | 104 | def forward(self, x): 105 | return self.residual(x) + self.shortcut(x) 106 | 107 | 108 | class DisBlock(nn.Module): # pre-activation residual block with optional 2x2 average-pool downsampling 109 | def __init__(self, args, in_channels, 
out_channels, hidden_channels=None, ksize=3, pad=1, 110 | activation=nn.ReLU(), downsample=False): 111 | super(DisBlock, self).__init__() 112 | self.activation = activation 113 | self.downsample = downsample 114 | self.learnable_sc = (in_channels != out_channels) or downsample # 1x1-conv shortcut when channels change or when downsampling 115 | hidden_channels = in_channels if hidden_channels is None else hidden_channels 116 | self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad) 117 | self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, padding=pad) 118 | if args.d_spectral_norm: 119 | self.c1 = nn.utils.spectral_norm(self.c1) 120 | self.c2 = nn.utils.spectral_norm(self.c2) 121 | 122 | if self.learnable_sc: 123 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 124 | if args.d_spectral_norm: 125 | self.c_sc = nn.utils.spectral_norm(self.c_sc) 126 | 127 | def residual(self, x): 128 | h = x 129 | h = self.activation(h) 130 | h = self.c1(h) 131 | h = self.activation(h) 132 | h = self.c2(h) 133 | if self.downsample: 134 | h = _downsample(h) 135 | return h 136 | 137 | def shortcut(self, x): 138 | if self.learnable_sc: 139 | x = self.c_sc(x) 140 | if self.downsample: 141 | return _downsample(x) 142 | else: 143 | return x 144 | else: 145 | return x 146 | 147 | def forward(self, x): 148 | return self.residual(x) + self.shortcut(x) 149 | 150 | 151 | class Discriminator(nn.Module): 152 | def __init__(self, args, activation=nn.ReLU()): 153 | super(Discriminator, self).__init__() 154 | self.ch = args.df_dim 155 | self.activation = activation 156 | self.block1 = OptimizedDisBlock(args, 3, self.ch) # NOTE(review): block1 uses OptimizedDisBlock's default nn.ReLU(), not the `activation` argument 157 | self.block2 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=True) 158 | self.block3 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=False) 159 | self.block4 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=False) 160 | self.l5 = nn.Linear(self.ch, 1, bias=False) # one score per image; forward applies no output nonlinearity 161 | if args.d_spectral_norm: 162 | self.l5 = 
nn.utils.spectral_norm(self.l5) 163 | 164 | def forward(self, x): 165 | h = x 166 | h = self.block1(h) 167 | h = self.block2(h) 168 | h = self.block3(h) 169 | h = self.block4(h) 170 | h = self.activation(h) 171 | # Global sum pooling over H and W (a sum, not an average) 172 | h = h.sum(2).sum(2) 173 | output = self.l5(h) 174 | 175 | return output 176 | -------------------------------------------------------------------------------- /models/sngan_cifar10.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .gen_resblock import GenBlock # shared generator residual block 3 | 4 | 5 | class Generator(nn.Module): 6 | def __init__(self, args, activation=nn.ReLU(), n_classes=0): 7 | super(Generator, self).__init__() 8 | self.bottom_width = args.bottom_width 9 | self.activation = activation 10 | self.n_classes = n_classes 11 | self.ch = args.gf_dim 12 | self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * self.ch) 13 | self.block2 = GenBlock(self.ch, self.ch, activation=activation, upsample=True, n_classes=n_classes) 14 | self.block3 = GenBlock(self.ch, self.ch, activation=activation, upsample=True, n_classes=n_classes) 15 | self.block4 = GenBlock(self.ch, self.ch, activation=activation, upsample=True, n_classes=n_classes) # three 2x upsampling blocks: bottom_width -> bottom_width * 8 (4 -> 32 with the default) 16 | self.b5 = nn.BatchNorm2d(self.ch) 17 | self.c5 = nn.Conv2d(self.ch, 3, kernel_size=3, stride=1, padding=1) 18 | 19 | def forward(self, z): 20 | 21 | h = z 22 | h = self.l1(h).view(-1, self.ch, self.bottom_width, self.bottom_width) 23 | h = self.block2(h) 24 | h = self.block3(h) 25 | h = self.block4(h) 26 | h = self.b5(h) 27 | h = self.activation(h) 28 | h = nn.Tanh()(self.c5(h)) 29 | return h 30 | 31 | 32 | """Discriminator""" 33 | 34 | 35 | def _downsample(x): 36 | # Downsample (Mean Avg Pooling with 2x2 kernel) 37 | return nn.AvgPool2d(kernel_size=2)(x) 38 | 39 | 40 | class OptimizedDisBlock(nn.Module): 41 | def __init__(self, args, in_channels, out_channels, ksize=3, pad=1, activation=nn.ReLU()): 42 | super(OptimizedDisBlock,
self).__init__() 43 | self.activation = activation 44 | 45 | self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=pad) 46 | self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=ksize, padding=pad) 47 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 48 | if args.d_spectral_norm: 49 | self.c1 = nn.utils.spectral_norm(self.c1) 50 | self.c2 = nn.utils.spectral_norm(self.c2) 51 | self.c_sc = nn.utils.spectral_norm(self.c_sc) 52 | 53 | def residual(self, x): 54 | h = x 55 | h = self.c1(h) 56 | h = self.activation(h) 57 | h = self.c2(h) 58 | h = _downsample(h) 59 | return h 60 | 61 | def shortcut(self, x): 62 | return self.c_sc(_downsample(x)) # shortcut: 2x2 average pool, then 1x1 conv 63 | 64 | def forward(self, x): 65 | return self.residual(x) + self.shortcut(x) 66 | 67 | 68 | class DisBlock(nn.Module): # pre-activation residual block with optional 2x2 average-pool downsampling 69 | def __init__(self, args, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1, 70 | activation=nn.ReLU(), downsample=False): 71 | super(DisBlock, self).__init__() 72 | self.activation = activation 73 | self.downsample = downsample 74 | self.learnable_sc = (in_channels != out_channels) or downsample # 1x1-conv shortcut when channels change or when downsampling 75 | hidden_channels = in_channels if hidden_channels is None else hidden_channels 76 | self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad) 77 | self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, padding=pad) 78 | if args.d_spectral_norm: 79 | self.c1 = nn.utils.spectral_norm(self.c1) 80 | self.c2 = nn.utils.spectral_norm(self.c2) 81 | 82 | if self.learnable_sc: 83 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 84 | if args.d_spectral_norm: 85 | self.c_sc = nn.utils.spectral_norm(self.c_sc) 86 | 87 | def residual(self, x): 88 | h = x 89 | h = self.activation(h) 90 | h = self.c1(h) 91 | h = self.activation(h) 92 | h = self.c2(h) 93 | if self.downsample: 94 | h = _downsample(h) 95 | return h 96 | 97 | def shortcut(self, x): 98 | if self.learnable_sc: 99 | x = self.c_sc(x) 
100 | if self.downsample: 101 | return _downsample(x) 102 | else: 103 | return x 104 | else: 105 | return x 106 | 107 | def forward(self, x): 108 | return self.residual(x) + self.shortcut(x) 109 | 110 | 111 | class Discriminator(nn.Module): 112 | def __init__(self, args, activation=nn.ReLU()): 113 | super(Discriminator, self).__init__() 114 | self.ch = args.df_dim 115 | self.activation = activation 116 | self.block1 = OptimizedDisBlock(args, 3, self.ch) # NOTE(review): block1 uses OptimizedDisBlock's default nn.ReLU(), not the `activation` argument 117 | self.block2 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=True) 118 | self.block3 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=False) 119 | self.block4 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=False) 120 | self.l5 = nn.Linear(self.ch, 1, bias=False) # one score per image; forward applies no output nonlinearity 121 | if args.d_spectral_norm: 122 | self.l5 = nn.utils.spectral_norm(self.l5) 123 | 124 | def forward(self, x): 125 | h = x 126 | h = self.block1(h) 127 | h = self.block2(h) 128 | h = self.block3(h) 129 | h = self.block4(h) 130 | h = self.activation(h) 131 | # Global sum pooling over H and W (a sum, not an average) 132 | h = h.sum(2).sum(2) 133 | output = self.l5(h) 134 | 135 | return output 136 | -------------------------------------------------------------------------------- /models/sngan_stl10.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class GenBlock(nn.Module): # NOTE(review): duplicate of models/gen_resblock.GenBlock 5 | def __init__(self, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1, 6 | activation=nn.ReLU(), upsample=False, n_classes=0): 7 | super(GenBlock, self).__init__() 8 | self.activation = activation 9 | self.upsample = upsample 10 | self.learnable_sc = in_channels != out_channels or upsample 11 | hidden_channels = out_channels if hidden_channels is None else hidden_channels 12 | self.n_classes = n_classes 13 | self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad) 14 | self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, 
padding=pad) 15 | 16 | self.b1 = nn.BatchNorm2d(in_channels) 17 | self.b2 = nn.BatchNorm2d(hidden_channels) 18 | if self.learnable_sc: 19 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 20 | 21 | def upsample_conv(self, x, conv): 22 | return conv(nn.UpsamplingNearest2d(scale_factor=2)(x)) 23 | 24 | def residual(self, x): 25 | h = x 26 | h = self.b1(h) 27 | h = self.activation(h) 28 | h = self.upsample_conv(h, self.c1) if self.upsample else self.c1(h) 29 | h = self.b2(h) 30 | h = self.activation(h) 31 | h = self.c2(h) 32 | return h 33 | 34 | def shortcut(self, x): 35 | if self.learnable_sc: 36 | x = self.upsample_conv(x, self.c_sc) if self.upsample else self.c_sc(x) 37 | return x 38 | else: 39 | return x 40 | 41 | def forward(self, x): 42 | return self.residual(x) + self.shortcut(x) 43 | 44 | 45 | class Generator(nn.Module): 46 | def __init__(self, args, activation=nn.ReLU(), n_classes=0): 47 | super(Generator, self).__init__() 48 | self.bottom_width = args.bottom_width 49 | self.activation = activation 50 | self.n_classes = n_classes 51 | self.ch = 512 # NOTE(review): hard-coded channel width; args.gf_dim is not used by this generator 52 | self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * self.ch) 53 | self.block2 = GenBlock(512, 256, activation=activation, upsample=True, n_classes=n_classes) 54 | self.block3 = GenBlock(256, 128, activation=activation, upsample=True, n_classes=n_classes) 55 | self.block4 = GenBlock(128, 64, activation=activation, upsample=True, n_classes=n_classes) # channels halve 512 -> 64 while resolution doubles three times 56 | self.b5 = nn.BatchNorm2d(64) 57 | self.c5 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1) 58 | 59 | def forward(self, z): 60 | 61 | h = z 62 | h = self.l1(h).view(-1, self.ch, self.bottom_width, self.bottom_width) 63 | h = self.block2(h) 64 | h = self.block3(h) 65 | h = self.block4(h) 66 | h = self.b5(h) 67 | h = self.activation(h) 68 | h = nn.Tanh()(self.c5(h)) 69 | return h 70 | 71 | 72 | """Discriminator""" 73 | 74 | 75 | def _downsample(x): 76 | # Downsample (Mean Avg Pooling with 2x2 kernel) 77 | return
nn.AvgPool2d(kernel_size=2)(x) 78 | 79 | 80 | class OptimizedDisBlock(nn.Module): 81 | def __init__(self, args, in_channels, out_channels, ksize=3, pad=1, activation=nn.ReLU()): 82 | super(OptimizedDisBlock, self).__init__() 83 | self.activation = activation 84 | 85 | self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=pad) 86 | self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=ksize, padding=pad) 87 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 88 | if args.d_spectral_norm: 89 | self.c1 = nn.utils.spectral_norm(self.c1) 90 | self.c2 = nn.utils.spectral_norm(self.c2) 91 | self.c_sc = nn.utils.spectral_norm(self.c_sc) 92 | 93 | def residual(self, x): 94 | h = x 95 | h = self.c1(h) 96 | h = self.activation(h) 97 | h = self.c2(h) 98 | h = _downsample(h) 99 | return h 100 | 101 | def shortcut(self, x): 102 | return self.c_sc(_downsample(x)) 103 | 104 | def forward(self, x): 105 | return self.residual(x) + self.shortcut(x) 106 | 107 | 108 | class DisBlock(nn.Module): 109 | def __init__(self, args, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1, 110 | activation=nn.ReLU(), downsample=False): 111 | super(DisBlock, self).__init__() 112 | self.activation = activation 113 | self.downsample = downsample 114 | self.learnable_sc = (in_channels != out_channels) or downsample 115 | hidden_channels = in_channels if hidden_channels is None else hidden_channels 116 | self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad) 117 | self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, padding=pad) 118 | if args.d_spectral_norm: 119 | self.c1 = nn.utils.spectral_norm(self.c1) 120 | self.c2 = nn.utils.spectral_norm(self.c2) 121 | 122 | if self.learnable_sc: 123 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 124 | if args.d_spectral_norm: 125 | self.c_sc = nn.utils.spectral_norm(self.c_sc) 126 | 127 | def residual(self, x): 128 | h = x 129 
| h = self.activation(h) 130 | h = self.c1(h) 131 | h = self.activation(h) 132 | h = self.c2(h) 133 | if self.downsample: 134 | h = _downsample(h) 135 | return h 136 | 137 | def shortcut(self, x): 138 | if self.learnable_sc: 139 | x = self.c_sc(x) 140 | if self.downsample: 141 | return _downsample(x) 142 | else: 143 | return x 144 | else: 145 | return x 146 | 147 | def forward(self, x): 148 | return self.residual(x) + self.shortcut(x) 149 | 150 | 151 | class Discriminator(nn.Module): 152 | def __init__(self, args, activation=nn.ReLU()): 153 | super(Discriminator, self).__init__() 154 | self.activation = activation 155 | self.block1 = OptimizedDisBlock(args, 3, 64) 156 | self.block2 = DisBlock(args, 64, 128, activation=activation, downsample=True) 157 | self.block3 = DisBlock(args, 128, 256, activation=activation, downsample=True) 158 | self.block4 = DisBlock(args, 256, 512, activation=activation, downsample=True) 159 | self.block5 = DisBlock(args, 512, 1024, activation=activation, downsample=False) 160 | 161 | self.l6 = nn.Linear(1024, 1, bias=False) 162 | if args.d_spectral_norm: 163 | self.l6 = nn.utils.spectral_norm(self.l6) 164 | 165 | def forward(self, x): 166 | h = x 167 | h = self.block1(h) 168 | h = self.block2(h) 169 | h = self.block3(h) 170 | h = self.block4(h) 171 | h = self.block5(h) 172 | h = self.activation(h) 173 | # Global average pooling 174 | h = h.sum(2).sum(2) 175 | output = self.l6(h) 176 | 177 | return output 178 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | six 3 | imageio 4 | dateutil 5 | numpy 6 | tensorboard==1.12.2 7 | tensorboardX==1.6 8 | tensorflow-gpu==1.12.0 9 | torch==1.1.0 10 | torchvision==0.3.0 11 | tqdm==4.29.1 12 | -------------------------------------------------------------------------------- /test.py: 
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Date : 2019-07-25
# @Author : Xinyu Gong (xy_gong@tamu.edu)
# @Link : None
# @Version : 0.0

# NOTE: the former `from __future__` imports were dropped; this project is
# Python 3 only (it uses f-strings), so they were no-ops.

import cfg
import models
from functions import validate
from utils.utils import set_log_dir, create_logger
from utils.inception_score import _init_inception
from utils.fid_score import create_inception_graph, check_or_download_inception

import torch
import os
import numpy as np
from tensorboardX import SummaryWriter

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def main():
    """Evaluate a pre-trained generator and log its Inception score and FID.

    Options come from ``cfg.parse_args()``. ``args.load_path`` must be an
    existing ``.pth`` file holding either a training checkpoint (its
    ``'avg_gen_state_dict'`` entry is loaded) or a bare generator state dict.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    assert args.exp_name
    assert args.load_path.endswith('.pth')
    assert os.path.exists(args.load_path)
    args.path_helper = set_log_dir('logs_eval', args.exp_name)
    logger = create_logger(args.path_helper['log_path'], phase='test')

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network: resolve models.<args.model>.Generator by attribute
    # lookup instead of eval()-ing a command-line string
    gen_net = getattr(getattr(models, args.model), 'Generator')(args=args).cuda()

    # fid stat (same dataset -> statistics-file mapping as train.py)
    dataset_name = args.dataset.lower()
    if dataset_name == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif dataset_name == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {dataset_name}')
    assert os.path.exists(fid_stat)

    # initial
    fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (25, args.latent_dim)))

    # set writer
    logger.info(f'=> resuming from {args.load_path}')
    checkpoint_file = args.load_path
    checkpoint = torch.load(checkpoint_file)

    if 'avg_gen_state_dict' in checkpoint:
        # full training checkpoint: evaluate the parameter-averaged generator
        gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        epoch = checkpoint['epoch']
        logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {epoch})')
    else:
        # bare generator state dict
        gen_net.load_state_dict(checkpoint)
        logger.info(f'=> loaded checkpoint {checkpoint_file}')

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'valid_global_steps': 0,
    }
    inception_score, fid_score = validate(args, fixed_z, fid_stat, gen_net, writer_dict)
    logger.info(f'Inception score: {inception_score}, FID score: {fid_score}.')


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /train.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Date : 2019-07-25
# @Author : Xinyu Gong (xy_gong@tamu.edu)
# @Link : None
# @Version : 0.0

# NOTE: the former `from __future__` imports were dropped (Python 3 only).

import cfg
import models
import datasets
from functions import train, validate, LinearLrDecay, load_params, copy_params
from utils.utils import set_log_dir, save_checkpoint, create_logger
from utils.inception_score import _init_inception
from utils.fid_score import create_inception_graph, check_or_download_inception

import torch
import os
import numpy as np
import torch.nn as nn
from tensorboardX import SummaryWriter
from tqdm import tqdm
from copy import deepcopy

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def main():
    """Train an SNGAN and periodically evaluate the averaged generator.

    Options come from ``cfg.parse_args()``. When ``args.load_path`` is set,
    training resumes from ``<load_path>/Model/checkpoint.pth``; otherwise a new
    log directory is created under ``logs/``. Every ``args.val_freq`` epochs
    (and at the last epoch) the parameter-averaged generator is scored with
    ``validate``. A checkpoint is saved after every epoch and flagged as best
    whenever the FID improves.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network: resolve models.<args.model>.{Generator,Discriminator}
    # by attribute lookup instead of eval()-ing a command-line string
    model_module = getattr(models, args.model)
    gen_net = getattr(model_module, 'Generator')(args=args).cuda()
    dis_net = getattr(model_module, 'Discriminator')(args=args).cuda()

    # weight init
    def weights_init(m):
        """Initialise Conv2d weights per ``args.init_type`` and BatchNorm2d affine params."""
        classname = m.__class__.__name__
        if 'Conv2d' in classname:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # in-place initialiser; un-suffixed `xavier_uniform` is a deprecated alias
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(args.init_type))
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set optimizer
    gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gen_net.parameters()),
                                     args.g_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, dis_net.parameters()),
                                     args.d_lr, (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    dataset_name = args.dataset.lower()
    if dataset_name == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif dataset_name == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {dataset_name}')
    assert os.path.exists(fid_stat)

    # epoch number for dis_net: each generator step consumes n_critic batches
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # restore the running parameter average through a throw-away copy
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)), desc='total progress'):
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch, writer_dict,
              lr_schedulers)

        is_last_epoch = epoch == int(args.max_epoch) - 1
        if (epoch and epoch % args.val_freq == 0) or is_last_epoch:
            # evaluate the averaged weights, then restore the live ones
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat, gen_net, writer_dict)
            logger.info(f'Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}.')
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'gen_state_dict': gen_net.state_dict(),
            'dis_state_dict': dis_net.state_dict(),
            'avg_gen_state_dict': avg_gen_net.state_dict(),
            'gen_optimizer': gen_optimizer.state_dict(),
            'dis_optimizer': dis_optimizer.state_dict(),
            'best_fid': best_fid,
            'path_helper': args.path_helper
        }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /utils/__init__.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Date : 2019-07-25
# @Author : Xinyu Gong (xy_gong@tamu.edu)
# @Link : None
# @Version : 0.0

from utils import utils

# --------------------------------------------------------------------------------
# /utils/cal_fid_stat.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Date : 2019-07-26
# @Author : Xinyu Gong (xy_gong@tamu.edu)
# @Link : None
# @Version : 0.0


import os
import glob
import argparse
import numpy as np
from imageio import imread
import tensorflow as tf

import utils.fid_score as fid


def parse_args():
    """Parse the command line of the FID-statistics script.

    Returns:
        argparse.Namespace with ``data_path`` (directory of ``.jpg`` images,
        required) and ``output_file`` (destination ``.npz`` file).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        type=str,
        required=True,
        help='set path to training set jpg images dir')
    parser.add_argument(
        '--output_file',
        type=str,
        default='fid_stat/fid_stats_cifar10_train.npz',
        help='path for where to store the statistics')

    opt = parser.parse_args()
    print(opt)
    return opt


def main():
    """Compute FID statistics (mu, sigma) for a directory of .jpg images.

    All images are loaded into memory, pushed through the Inception graph, and
    the resulting mean and covariance are written with ``np.savez_compressed``
    to ``--output_file``.

    Raises:
        ValueError: if ``--data_path`` contains no ``.jpg`` files.
    """
    args = parse_args()

    ########
    # PATHS
    ########
    data_path = args.data_path
    output_path = args.output_file
    # if you have downloaded and extracted
    #   http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    # set this path to the directory where the extracted files are, otherwise
    # just set it to None and the script will later download the files for you
    inception_path = None
    print("check for inception model..", end=" ", flush=True)
    inception_path = fid.check_or_download_inception(inception_path)  # download inception if necessary
    print("ok")

    # loads all images into memory (this might require a lot of RAM!)
    print("load images..", end=" ", flush=True)
    image_list = glob.glob(os.path.join(data_path, '*.jpg'))
    if not image_list:
        # Fail fast: an empty array would otherwise end in a ZeroDivisionError
        # inside fid.get_activations (batch_size is clamped to 0 images).
        raise ValueError('no .jpg images found in %s' % data_path)
    images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])
    print("%d images found and loaded" % len(images))

    # Create the destination directory up front (the default 'fid_stat/' may
    # not exist) so the expensive computation below is not lost on save.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("create inception graph..", end=" ", flush=True)
    fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
    print("ok")

    print("calculte FID stats..", end=" ", flush=True)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
        np.savez_compressed(output_path, mu=mu, sigma=sigma)
    print("finished")


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /utils/fid_score.py
# --------------------------------------------------------------------------------
#!/usr/bin/env python3
""" Calculates the Frechet Inception Distance (FID) to evaluate GANs.

The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.

``calculate_fid_given_paths`` compares two distributions, each given either
as a directory of PNG/JPEG images or as pre-computed summary statistics
stored in an ``.npz`` file (arrays ``mu`` and ``sigma``).

The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.
"""

# NOTE: the former `from __future__` import was dropped; this module is
# Python 3 only (pathlib, print(..., flush=True)).
import numpy as np
import os
import tensorflow as tf
from imageio import imread
from scipy import linalg
import pathlib
import warnings


class InvalidFIDException(Exception):
    pass


def create_inception_graph(pth):
    """Import the Inception GraphDef stored at *pth* into the default TF graph.

    The graph is imported under the name ``FID_Inception_Net``, which is the
    prefix the tensor names used below (``FID_Inception_Net/pool_3:0``,
    ``FID_Inception_Net/ExpandDims:0``) rely on.
    """
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(pth, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='FID_Inception_Net')


# -------------------------------------------------------------------------------


# code for handling inception net derived from
#   https://github.com/openai/improved-gan/blob/master/inception_score/model.py
def _get_inception_layer(sess):
    """Prepares inception net for batched usage and returns pool_3 layer. """
    layername = 'FID_Inception_Net/pool_3:0'
    pool3 = sess.graph.get_tensor_by_name(layername)
    ops = pool3.graph.get_operations()
    # Replace a leading dimension of 1 in every op output's static shape with
    # None so batches of any size can be fed (the shipped graph presumably
    # fixes the batch size to 1). Relies on private TF 1.x TensorShape
    # internals (`_dims`, `_shape_val`).
    for op in ops:
        for o in op.outputs:
            shape = o.get_shape()
            if shape._dims != []:
                shape = [s.value for s in shape]
                new_shape = []
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
    return pool3


# -------------------------------------------------------------------------------


def get_activations(images, sess, batch_size=50, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- images     : Numpy array of dimension (n_images, hi, wi, 3). The values
                    must lie between 0 and 255.
    -- sess       : current session
    -- batch_size : the images numpy array is split into batches with batch size
                    batch_size. A reasonable batch size depends on the available
                    hardware. It is clamped to n_images if larger.
    -- verbose    : If set to True, the number of calculated batches is reported.
    Returns:
    -- A numpy array of dimension (n_used, 2048) with the pool_3 activations,
       where n_used = (n_images // batch_size) * batch_size: a trailing
       partial batch is NOT evaluated.
    """
    inception_layer = _get_inception_layer(sess)
    d0 = images.shape[0]
    if batch_size > d0:
        print("warning: batch size is bigger than the data size. setting batch size to data size")
        batch_size = d0
    n_batches = d0 // batch_size
    n_used_imgs = n_batches * batch_size
    pred_arr = np.empty((n_used_imgs, 2048))
    for i in range(n_batches):
        if verbose:
            print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
        start = i * batch_size
        end = start + batch_size
        batch = images[start:end]
        pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
        pred_arr[start:end] = pred.reshape(batch_size, -1)
    if verbose:
        print(" done")
    return pred_arr


# -------------------------------------------------------------------------------


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : The sample mean over activations of the pool_3 layer for
               generated samples (e.g. as returned by
               'calculate_activation_statistics').
    -- mu2   : The sample mean over activations of the pool_3 layer, precalculated
               on a representative data set.
    -- sigma1: The covariance matrix over activations of the pool_3 layer for
               generated samples.
    -- sigma2: The covariance matrix over activations of the pool_3 layer,
               precalculated on a representative data set.
    -- eps   : Added to both covariance diagonals when the matrix square root of
               their product is not finite.

    Returns:
    --   : The Frechet Distance.

    Raises:
    -- AssertionError : if the means or covariances differ in shape.
    -- ValueError     : if the square root keeps an imaginary diagonal
                        component larger than 1e-3.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
        warnings.warn(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


# -------------------------------------------------------------------------------


def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
    """Calculation of the statistics used by the FID.
    Params:
    -- images     : Numpy array of dimension (n_images, hi, wi, 3). The values
                    must lie between 0 and 255.
    -- sess       : current session
    -- batch_size : the images numpy array is split into batches with batch size
                    batch_size. A reasonable batch size depends on the available hardware.
    -- verbose    : If set to True, the number of calculated batches is reported.
    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations(images, sess, batch_size, verbose)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


# ------------------
# The following methods are implemented to obtain a batched version of the activations.
# This has the advantage to reduce memory requirements, at the cost of slightly reduced efficiency.
# - Pyrestone
# ------------------


def load_image_batch(files):
    """Convenience method for batch-loading images
    Params:
    -- files    : list of paths to image files. Images need to have same dimensions for all files.
    Returns:
    -- A numpy array of dimensions (num_images,hi, wi, 3) representing the image pixel values.
    """
    return np.array([imread(str(fn)).astype(np.float32) for fn in files])


def get_activations_from_files(files, sess, batch_size=50, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files      : list of paths to image files. Images need to have same dimensions for all files.
    -- sess       : current session
    -- batch_size : the files are read and evaluated in batches of batch_size
                    images. It is clamped to len(files) if larger.
    -- verbose    : If set to True, the number of calculated batches is reported.
    Returns:
    -- A numpy array of dimension (n_used, 2048) with the pool_3 activations,
       where n_used = (len(files) // batch_size) * batch_size: a trailing
       partial batch is NOT evaluated.
    """
    inception_layer = _get_inception_layer(sess)
    d0 = len(files)
    if batch_size > d0:
        print("warning: batch size is bigger than the data size. setting batch size to data size")
        batch_size = d0
    n_batches = d0 // batch_size
    n_used_imgs = n_batches * batch_size
    pred_arr = np.empty((n_used_imgs, 2048))
    for i in range(n_batches):
        if verbose:
            print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
        start = i * batch_size
        end = start + batch_size
        batch = load_image_batch(files[start:end])
        pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
        pred_arr[start:end] = pred.reshape(batch_size, -1)
        del batch  # clean up memory
    if verbose:
        print(" done")
    return pred_arr


def calculate_activation_statistics_from_files(files, sess, batch_size=50, verbose=False):
    """Calculation of the statistics used by the FID.
    Params:
    -- files      : list of paths to image files. Images need to have same dimensions for all files.
    -- sess       : current session
    -- batch_size : the files are read and evaluated in batches of batch_size images.
    -- verbose    : If set to True, the number of calculated batches is reported.
    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations_from_files(files, sess, batch_size, verbose)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


# -------------------------------------------------------------------------------


# -------------------------------------------------------------------------------
# Helpers for obtaining the Inception graph file and for computing the FID
# between two paths (directories of images or .npz statistics files).
# -------------------------------------------------------------------------------
def check_or_download_inception(inception_path):
    """ Checks if the path to the inception file is valid, or downloads
        the file if it is not present.

    Params:
    -- inception_path : directory expected to contain
                        'classify_image_graph_def.pb'; None means '/tmp'.
    Returns:
    -- The path (str) of 'classify_image_graph_def.pb'.
    """
    INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    if inception_path is None:
        inception_path = '/tmp'
    inception_path = pathlib.Path(inception_path)
    model_file = inception_path / 'classify_image_graph_def.pb'
    if not model_file.exists():
        print("Downloading Inception model")
        from urllib import request
        import tarfile
        fn, _ = request.urlretrieve(INCEPTION_URL)
        try:
            with tarfile.open(fn, mode='r') as f:
                f.extract('classify_image_graph_def.pb', str(model_file.parent))
        finally:
            # urlretrieve() without a filename downloads into a temporary
            # file; delete it once the graph has been extracted.
            request.urlcleanup()
    return str(model_file)


def _handle_path(path, sess, low_profile=False):
    """Return (mu, sigma) for *path*.

    *path* is either an ``.npz`` file with ``mu``/``sigma`` arrays or a
    directory whose ``*.jpg``/``*.png`` images are run through Inception;
    ``low_profile`` loads those images batch by batch instead of all at once.
    """
    if path.endswith('.npz'):
        with np.load(path) as f:
            m, s = f['mu'][:], f['sigma'][:]
    else:
        path = pathlib.Path(path)
        files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
        if low_profile:
            m, s = calculate_activation_statistics_from_files(files, sess)
        else:
            x = load_image_batch(files)
            m, s = calculate_activation_statistics(x, sess)
            del x  # clean up memory
    return m, s

def calculate_fid_given_paths(paths, inception_path, low_profile=False):
    """Calculate the FID between the two distributions named by ``paths``.

    Params:
    -- paths          : sequence whose first two entries are each an ``.npz``
                        statistics file or a directory of ``.jpg``/``.png`` images.
    -- inception_path : unused; kept for backward compatibility. When a
                        directory is given, the Inception graph must already be
                        imported into the default TF graph (see
                        ``create_inception_graph``).
    -- low_profile    : load image directories batch by batch to save memory.
    Returns:
    -- The Frechet Inception Distance (a scalar).
    Raises:
    -- RuntimeError : if any entry of ``paths`` does not exist.
    """
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError("Invalid path: %s" % p)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # Leaving the `with` block closes the session; no explicit close needed.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        m1, s1 = _handle_path(paths[0], sess, low_profile=low_profile)
        m2, s2 = _handle_path(paths[1], sess, low_profile=low_profile)
        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
        del m1, s1, m2, s2

    return fid_value

# --------------------------------------------------------------------------------
# /utils/inception_score.py
# --------------------------------------------------------------------------------
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
# NOTE: the former `from __future__` imports were dropped; this project is
# Python 3 only.
from tqdm import tqdm

import os.path
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

import math
import sys

MODEL_DIR = '/tmp/imagenet'
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# Softmax output tensor of the Inception graph; set by _init_inception().
softmax = None

# Session config shared by the sessions opened in this module: allocate GPU
# memory on demand instead of reserving it all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True


# Call get_inception_score with a list of rank-3 numpy arrays (presumably
# H x W x 3 images) with values ranging from 0 to 255; _init_inception() must
# have been called first so that the module-level `softmax` tensor exists.
def get_inception_score(images, splits=10, batch_size=100):
    """Compute the Inception score of a set of images.

    Args:
        images: non-empty list of rank-3 numpy arrays with values in [0, 255]
            (only the first image is type/range checked).
        splits: number of equal-size chunks the predictions are split into;
            one score is computed per chunk.
        batch_size: number of images fed to the network per ``sess.run``.

    Returns:
        (mean, std) over splits of ``exp(mean_x KL(p(y|x) || p(y)))``.

    Raises:
        AssertionError: if the first image fails the type/range checks.
        RuntimeError: if ``_init_inception()`` has not been called yet.
    """
    assert isinstance(images, list)
    assert isinstance(images[0], np.ndarray)
    assert images[0].ndim == 3
    assert np.max(images[0]) > 10  # catches images scaled to [0, 1] by mistake
    assert np.min(images[0]) >= 0.0
    if softmax is None:
        raise RuntimeError('_init_inception() must be called before get_inception_score()')
    # each image becomes a (1, ...) float32 array so batches can be concatenated
    inps = [np.expand_dims(img.astype(np.float32), 0) for img in images]
    with tf.Session(config=config) as sess:
        preds = []
        n_batches = int(math.ceil(float(len(inps)) / float(batch_size)))
        for i in tqdm(range(n_batches), desc="Calculate inception score"):
            inp = np.concatenate(inps[i * batch_size:(i + 1) * batch_size], 0)
            pred = sess.run(softmax, {'ExpandDims:0': inp})
            preds.append(pred)
        preds = np.concatenate(preds, 0)
        scores = []
        for i in range(splits):
            part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
            # KL divergence between each conditional p(y|x) and the split's marginal p(y)
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
    return np.mean(scores), np.std(scores)


# _init_inception() is not called automatically: it must be called explicitly
# (train.py and test.py do so) before get_inception_score(), because it builds
# the module-level `softmax` tensor.
def _init_inception():
    """Load the Inception graph into the default TF graph and build ``softmax``.

    Downloads ``DATA_URL`` into ``MODEL_DIR`` if the archive is not there yet,
    unpacks it, imports ``classify_image_graph_def.pb`` into the default graph
    (root name scope), and sets the module-level ``softmax`` tensor used by
    ``get_inception_score``.
    """
    global softmax
    os.makedirs(MODEL_DIR, exist_ok=True)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(MODEL_DIR, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
    # NOTE(review): extractall trusts member paths of an archive fetched over
    # plain http; acceptable only because DATA_URL is a fixed, known source.
    # The `with` closes the archive handle instead of leaking it.
    with tarfile.open(filepath, 'r:gz') as tar:
        tar.extractall(MODEL_DIR)
    with tf.gfile.FastGFile(os.path.join(
            MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
    # Works with an arbitrary minibatch size.
    with tf.Session(config=config) as sess:
        pool3 = sess.graph.get_tensor_by_name('pool_3:0')
        ops = pool3.graph.get_operations()
        # Replace a leading dimension of 1 in every op output's static shape
        # with None so batches of any size can be fed. Relies on private
        # TF 1.x TensorShape internals (`_dims`, `_shape_val`).
        for op in ops:
            for o in op.outputs:
                shape = o.get_shape()
                if shape._dims != []:
                    shape = [s.value for s in shape]
                    new_shape = []
                    for j, s in enumerate(shape):
                        if s == 1 and j == 0:
                            new_shape.append(None)
                        else:
                            new_shape.append(s)
                    o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
        # Rebuild the classifier head on the reshaped pool_3 output, reusing
        # the weight matrix of the original logits MatMul.
        w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
        logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
        softmax = tf.nn.softmax(logits)

# --------------------------------------------------------------------------------
# /utils/utils.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Date : 2019-07-25
# @Author : Xinyu Gong (xy_gong@tamu.edu)
# @Link : None
# @Version : 0.0
import os
import torch
import dateutil.tz
from datetime import datetime
import time
import logging

# Console handler attached to the root logger by create_logger(). It is kept
# at module level so repeated calls reuse it instead of stacking duplicates.
_console_handler = None


def create_logger(log_dir, phase='train'):
    """Configure and return the root logger at INFO level.

    Records go to ``<log_dir>/<YYYY-mm-dd-HH-MM>_<phase>.log`` (via
    ``logging.basicConfig``, which only takes effect while the root logger has
    no handlers) and to a single console handler. Calling this function again
    does not add another console handler, so console lines are not duplicated.

    Args:
        log_dir: existing directory for the log file.
        phase: tag embedded in the log-file name.

    Returns:
        The root ``logging.Logger``.
    """
    global _console_handler
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=final_log_file, format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if _console_handler is None:
        _console_handler = logging.StreamHandler()
    if _console_handler not in logger.handlers:
        logger.addHandler(_console_handler)

    return logger


def set_log_dir(root_dir, exp_name):
    """Create a timestamped experiment directory tree under *root_dir*.

    Layout: ``<root_dir>/<exp_name>_<YYYY_mm_dd_HH_MM_SS>/{Model,Log,Samples}``.
    The experiment directory must not exist yet (``os.makedirs`` raises
    ``FileExistsError`` otherwise).

    Returns:
        dict with keys ``'prefix'``, ``'ckpt_path'``, ``'log_path'`` and
        ``'sample_path'`` mapping to the created directories.
    """
    path_dict = {}
    os.makedirs(root_dir, exist_ok=True)

    # set log path
    exp_path = os.path.join(root_dir, exp_name)
    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    prefix = exp_path + '_' + timestamp
    os.makedirs(prefix)
    path_dict['prefix'] = prefix

    # set checkpoint path
    ckpt_path = os.path.join(prefix, 'Model')
    os.makedirs(ckpt_path)
    path_dict['ckpt_path'] = ckpt_path

    log_path = os.path.join(prefix, 'Log')
    os.makedirs(log_path)
    path_dict['log_path'] = log_path

    # set sample image path for fid calculation
    sample_path = os.path.join(prefix, 'Samples')
    os.makedirs(sample_path)
    path_dict['sample_path'] = sample_path

    return path_dict


def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth'):
    """Save *states* to ``output_dir/filename``; also to ``checkpoint_best.pth`` when *is_best*."""
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))

# --------------------------------------------------------------------------------