├── README.md
├── data_loader.py
├── download.sh
├── image
│   ├── attn_gf1.png
│   ├── attn_gf2.png
│   ├── main_model.PNG
│   ├── sagan_attn.png
│   ├── sagan_celeb.png
│   ├── sagan_lsun.png
│   └── unnamed
├── main.py
├── parameter.py
├── sagan_models.py
├── spectral.py
├── trainer.py
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Self-Attention GAN
**[Han Zhang, Ian Goodfellow, Dimitris Metaxas and Augustus Odena, "Self-Attention Generative Adversarial Networks." arXiv preprint arXiv:1805.08318 (2018)](https://arxiv.org/abs/1805.08318).**

## Meta overview
This repository provides a PyTorch implementation of [SAGAN](https://arxiv.org/abs/1805.08318). Both the wgan-gp and wgan-hinge losses are available, but note that wgan-gp is not compatible with spectral normalization; remove all spectral normalization from the models if you want to use wgan-gp.

Self-attention is applied to the last two layers of both the discriminator and the generator.
![Main model architecture](image/main_model.PNG)
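Both losses are selectable via `--adv_loss`; for reference, the hinge objective used in the training commands below boils down to the following sketch (mirroring trainer.py; the function names are illustrative):

```python
import torch.nn.functional as F

def d_hinge_loss(d_out_real, d_out_fake):
    # Discriminator: push D(real) above +1 and D(fake) below -1.
    return F.relu(1.0 - d_out_real).mean() + F.relu(1.0 + d_out_fake).mean()

def g_hinge_loss(d_out_fake):
    # Generator: raise the critic score of its samples.
    return -d_out_fake.mean()
```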
## Current update status
* [ ] Supervised setting
* [ ] Tensorboard loggings
* [x] **[20180608] Updated the self-attention module. Thanks to my colleague [Cheonbok Park](https://github.com/cheonbok94)! See 'sagan_models.py' for the update. It should be efficient and able to run on large images.**
* [x] Attention visualization (LSUN Church-outdoor)
* [x] Unsupervised setting (no labels used yet)
* [x] Applied: [Spectral Normalization](https://arxiv.org/abs/1802.05957), code from [here](https://github.com/christiancosgrove/pytorch-spectral-normalization-gan)
* [x] Implemented: self-attention module, two-timescale update rule (TTUR), wgan-hinge loss, wgan-gp loss

&nbsp;
&nbsp;

## Results

### Attention result on LSUN (epoch #8)
![Attention visualization on LSUN](image/sagan_attn.png)
Per-pixel attention results of SAGAN on the LSUN church-outdoor dataset. They show that unsupervised training of the self-attention module works, even though the attention maps themselves are not easily interpretable. Better results with regard to the generated images will be added. These visualizations show the self-attention in generator layer3 and layer4, whose feature maps are 16 x 16 and 32 x 32 respectively, each over 64 images. To visualize per-pixel attention, only a few query pixels are chosen, as indicated by the numbers on the leftmost and rightmost columns.
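As a rough sketch of how such maps can be read off the returned attention tensor (the helper below is illustrative and not part of this repository; `Self_Attn.forward` returns attention of shape B x N x N with N = W * H):

```python
import torch

def attention_heatmaps(attention, width, height, query_indices):
    """Turn a (B, N, N) attention tensor into one (B, width, height) map per query pixel."""
    # Row q holds the weights with which query pixel q attends to all N positions.
    return [attention[:, q, :].view(-1, width, height) for q in query_indices]

# e.g. for the 16 x 16 attention of generator layer3 (stand-in tensor shown here):
attn = torch.softmax(torch.randn(64, 256, 256), dim=-1)
maps = attention_heatmaps(attn, 16, 16, query_indices=[0, 85, 170, 255])
```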
### CelebA dataset (epoch on the left, still under training)
![SAGAN results on CelebA](image/sagan_celeb.png)
### LSUN church-outdoor dataset (epoch on the left, still under training)
![SAGAN results on LSUN](image/sagan_lsun.png)
## Prerequisites
* [Python 3.5+](https://www.continuum.io/downloads)
* [PyTorch 0.3.0](http://pytorch.org/)

&nbsp;

## Usage

#### 1. Clone the repository
```bash
$ git clone https://github.com/heykeetae/Self-Attention-GAN.git
$ cd Self-Attention-GAN
```

#### 2. Download the dataset (CelebA or LSUN)
```bash
$ bash download.sh CelebA
or
$ bash download.sh LSUN
```

#### 3. Train
```bash
$ python main.py --batch_size 64 --imsize 64 --dataset celeb --adv_loss hinge --version sagan_celeb
or
$ python main.py --batch_size 64 --imsize 64 --dataset lsun --adv_loss hinge --version sagan_lsun
```

#### 4. Enjoy the results
```bash
$ cd samples/sagan_celeb
or
$ cd samples/sagan_lsun
```
Generated samples are saved there every 100 iterations by default. The sampling rate can be controlled via --sample_step (e.g., --sample_step 100).
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import torch
import torchvision.datasets as dsets
from torchvision import transforms


class Data_Loader():
    def __init__(self, train, dataset, image_path, image_size, batch_size, shuf=True):
        self.dataset = dataset
        self.path = image_path
        self.imsize = image_size
        self.batch = batch_size
        self.shuf = shuf
        self.train = train

    def transform(self, resize, totensor, normalize, centercrop):
        # Build a torchvision transform pipeline from the requested options.
        options = []
        if centercrop:
            options.append(transforms.CenterCrop(160))
        if resize:
            options.append(transforms.Resize((self.imsize, self.imsize)))
        if totensor:
            options.append(transforms.ToTensor())
        if normalize:
            options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        return transforms.Compose(options)

    def load_lsun(self, classes='church_outdoor_train'):
        # Named `transform` (not `transforms`) to avoid shadowing the imported module.
        transform = self.transform(True, True, True, False)
        return dsets.LSUN(self.path, classes=[classes], transform=transform)

    def load_celeb(self):
        transform = self.transform(True, True, True, True)
        return dsets.ImageFolder(self.path + '/CelebA', transform=transform)

    def loader(self):
        if self.dataset == 'lsun':
            dataset = self.load_lsun()
        elif self.dataset == 'celeb':
            dataset = self.load_celeb()
        else:
            raise ValueError("Unknown dataset '{}'; expected 'lsun' or 'celeb'.".format(self.dataset))

        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=self.batch,
                                           shuffle=self.shuf,
                                           num_workers=2,
                                           drop_last=True)
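A minimal usage sketch for `Data_Loader` (illustrative, not repository code; `./data` must already contain a dataset fetched by download.sh):

```python
from data_loader import Data_Loader

# Build a shuffled CelebA loader of 64 x 64 images, batch size 64.
data_loader = Data_Loader(train=True, dataset='celeb', image_path='./data',
                          image_size=64, batch_size=64, shuf=True)
loader = data_loader.loader()
images, labels = next(iter(loader))  # images: (64, 3, 64, 64), normalized to [-1, 1]
```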
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
FILE=$1

if [ $FILE == 'CelebA' ]
then
    URL=https://www.dropbox.com/s/3e5cmqgplchz85o/CelebA_nocrop.zip?dl=0
    ZIP_FILE=./data/CelebA.zip

elif [ $FILE == 'LSUN' ]
then
    URL=https://www.dropbox.com/s/zt7d2hchrw7cp9p/church_outdoor_train_lmdb.zip?dl=0
    ZIP_FILE=./data/church_outdoor_train_lmdb.zip
else
    echo "Available datasets are: CelebA and LSUN"
    exit 1
fi

mkdir -p ./data/
wget -N $URL -O $ZIP_FILE
unzip $ZIP_FILE -d ./data/

if [ $FILE == 'CelebA' ]
then
    mv ./data/CelebA_nocrop ./data/CelebA
fi

rm $ZIP_FILE
--------------------------------------------------------------------------------
/image/attn_gf1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/attn_gf1.png
--------------------------------------------------------------------------------
/image/attn_gf2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/attn_gf2.png
--------------------------------------------------------------------------------
/image/main_model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/main_model.PNG
--------------------------------------------------------------------------------
/image/sagan_attn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/sagan_attn.png
--------------------------------------------------------------------------------
/image/sagan_celeb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/sagan_celeb.png
--------------------------------------------------------------------------------
/image/sagan_lsun.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/sagan_lsun.png
--------------------------------------------------------------------------------
/image/unnamed:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from parameter import *
from trainer import Trainer
# from tester import Tester
from data_loader import Data_Loader
from torch.backends import cudnn
from utils import make_folder

def main(config):
    # For fast training
    cudnn.benchmark = True

    # Data loader
    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                              config.batch_size, shuf=config.train)

    # Create directories if they do not exist
    make_folder(config.model_save_path, config.version)
    make_folder(config.sample_path, config.version)
    make_folder(config.log_path, config.version)
    make_folder(config.attn_path, config.version)

    if config.train:
        if config.model == 'sagan':
            trainer = Trainer(data_loader.loader(), config)
        elif config.model == 'qgan':
            # NOTE: qgan_trainer is referenced but not shipped in this repository.
            trainer = qgan_trainer(data_loader.loader(), config)
        trainer.train()
    else:
        # NOTE: Tester is not shipped either (its import is commented out above).
        tester = Tester(data_loader.loader(), config)
        tester.test()

if __name__ == '__main__':
    config = get_parameters()
    print(config)
    main(config)
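Checkpoints are saved as `<step>_G.pth` / `<step>_D.pth` under `./models/<version>/` (see trainer.py), so a trained generator can be sampled roughly like this (a sketch; the step number, version, and CPU loading are illustrative assumptions):

```python
import torch
from sagan_models import Generator

G = Generator(64, image_size=64, z_dim=128, conv_dim=64)  # must match the training config
G.load_state_dict(torch.load('./models/sagan_celeb/10000_G.pth', map_location='cpu'))
G.eval()

z = torch.randn(1, 128)
sample, attn1, attn2 = G(z)  # sample: (1, 3, 64, 64) in [-1, 1]; see utils.denorm
```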
--------------------------------------------------------------------------------
/parameter.py:
--------------------------------------------------------------------------------
import argparse

def str2bool(v):
    # Note the trailing comma: ('true') would be a plain string,
    # and `in` would then perform a substring test.
    return v.lower() in ('true',)

def get_parameters():

    parser = argparse.ArgumentParser()

    # Model hyper-parameters
    parser.add_argument('--model', type=str, default='sagan', choices=['sagan', 'qgan'])
    parser.add_argument('--adv_loss', type=str, default='wgan-gp', choices=['wgan-gp', 'hinge'])
    parser.add_argument('--imsize', type=int, default=32)
    parser.add_argument('--g_num', type=int, default=5)
    parser.add_argument('--z_dim', type=int, default=128)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    parser.add_argument('--lambda_gp', type=float, default=10)
    parser.add_argument('--version', type=str, default='sagan_1')

    # Training setting
    parser.add_argument('--total_step', type=int, default=1000000, help='how many times to update the generator')
    parser.add_argument('--d_iters', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--g_lr', type=float, default=0.0001)
    parser.add_argument('--d_lr', type=float, default=0.0004)
    parser.add_argument('--lr_decay', type=float, default=0.95)
    parser.add_argument('--beta1', type=float, default=0.0)
    parser.add_argument('--beta2', type=float, default=0.9)

    # using pretrained
    parser.add_argument('--pretrained_model', type=int, default=None)

    # Misc
    parser.add_argument('--train', type=str2bool, default=True)
    parser.add_argument('--parallel', type=str2bool, default=False)
    # The default must be one of the supported choices ('cifar' was not).
    parser.add_argument('--dataset', type=str, default='lsun', choices=['lsun', 'celeb'])
    parser.add_argument('--use_tensorboard', type=str2bool, default=False)

    # Path
    parser.add_argument('--image_path', type=str, default='./data')
    parser.add_argument('--log_path', type=str, default='./logs')
    parser.add_argument('--model_save_path', type=str, default='./models')
    parser.add_argument('--sample_path', type=str, default='./samples')
    parser.add_argument('--attn_path', type=str, default='./attn')

    # Step size
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=100)
    parser.add_argument('--model_save_step', type=float, default=1.0)

    return parser.parse_args()
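A quick sketch of how these options resolve (illustrative; `sys.argv` is patched only to make the example runnable outside a shell):

```python
import sys
from parameter import get_parameters

# Simulate the README's training command.
sys.argv = ['main.py', '--batch_size', '64', '--imsize', '64',
            '--dataset', 'celeb', '--adv_loss', 'hinge', '--version', 'sagan_celeb']
config = get_parameters()
print(config.imsize, config.adv_loss, config.g_lr, config.d_lr)
# -> 64 hinge 0.0001 0.0004  (note the TTUR: the discriminator lr is 4x the generator lr)
```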
--------------------------------------------------------------------------------
/sagan_models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from spectral import SpectralNorm
import numpy as np

class Self_Attn(nn.Module):
    """Self-attention layer"""
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim
        self.activation = activation

        # 1x1 convolutions project the input into query/key/value spaces;
        # queries and keys are reduced to in_dim // 8 channels.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Learnable residual weight, initialized to 0 so the block starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x W x H)
        returns:
            out : self-attention value + input feature
            attention : B x N x N (N = W * H)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # row i: weights of pixel i over all N positions
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B x C x N
        out = out.view(m_batchsize, C, width, height)

        out = self.gamma * out + x
        return out, attention
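
# ---------------------------------------------------------------------------
# Annotation (not original repository code): a minimal shape check for the
# module above; 128 channels matches what the generator's attn1 sees on its
# 16 x 16 feature map. Under a recent PyTorch:
#
#   attn = Self_Attn(128, 'relu')
#   out, attention = attn(torch.randn(4, 128, 16, 16))
#   out.shape        -> torch.Size([4, 128, 16, 16])   # same as the input
#   attention.shape  -> torch.Size([4, 256, 256])      # N x N, N = 16 * 16
#
# (The repo's PyTorch 0.3 would need the input wrapped in a Variable.)
# ---------------------------------------------------------------------------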
class Generator(nn.Module):
    """Generator."""

    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3
        mult = 2 ** repeat_num  # 8 for imsize 64
        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())

        curr_dim = conv_dim * mult

        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())

        curr_dim = int(curr_dim / 2)

        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())

        if self.imsize == 64:
            layer4 = []
            curr_dim = int(curr_dim / 2)
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2)

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

        # NOTE: the attention dims (and the unconditional use of l4 in forward)
        # assume imsize == 64 with conv_dim == 64.
        self.attn1 = Self_Attn(128, 'relu')
        self.attn2 = Self_Attn(64, 'relu')

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        out = self.l1(z)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)

        return out, p1, p2


class Discriminator(nn.Module):
    """Discriminator, Auxiliary Classifier."""

    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
        layer1.append(nn.LeakyReLU(0.1))

        curr_dim = conv_dim

        layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer2.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer3.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        if self.imsize == 64:
            layer4 = []
            layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
            layer4.append(nn.LeakyReLU(0.1))
            self.l4 = nn.Sequential(*layer4)
            curr_dim = curr_dim * 2

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.Conv2d(curr_dim, 1, 4))
        self.last = nn.Sequential(*last)

        # NOTE: these attention dims also assume imsize == 64 with conv_dim == 64.
        self.attn1 = Self_Attn(256, 'relu')
        self.attn2 = Self_Attn(512, 'relu')

    def forward(self, x):
        out = self.l1(x)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)

        return out.squeeze(), p1, p2
--------------------------------------------------------------------------------
/spectral.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import Parameter

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """Wraps a module and rescales its weight by its spectral norm,
    estimated with power iteration (Miyato et al., 2018)."""
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # Each power-iteration step refines the estimate of w's top singular vectors.
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = u^T W v approximates the largest singular value of w.
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            getattr(self.module, self.name + "_u")
            getattr(self.module, self.name + "_v")
            getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
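A small sketch of the wrapper's effect (illustrative, not repository code): after forward passes, the effective `weight` is the raw weight divided by the power-iteration estimate of its largest singular value, so its spectral norm sits near 1.

```python
import torch
from spectral import SpectralNorm

conv = SpectralNorm(torch.nn.Conv2d(3, 64, 4, 2, 1))
for _ in range(10):
    _ = conv(torch.randn(1, 3, 64, 64))  # each call runs one power-iteration update

w = conv.module.weight.view(conv.module.weight.shape[0], -1)
print(torch.svd(w)[1][0])  # top singular value of the normalized weight, close to 1
```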
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------

import os
import time
import torch
import datetime

import torch.nn as nn
from torch.autograd import Variable
from torchvision.utils import save_image

from sagan_models import Generator, Discriminator
from utils import *

class Trainer(object):
    def __init__(self, data_loader, config):

        # Data loader
        self.data_loader = data_loader

        # Exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Paths (versioned subdirectories)
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path, self.version)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with a trained model
        if self.pretrained_model:
            self.load_pretrained_model()



    def train(self):

        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        # Start with a trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):

            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            try:
                real_images, _ = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                real_images, _ = next(data_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            real_images = tensor2var(real_images)
            d_out_real, dr1, dr2 = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = - torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # Compute loss with fake images
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1, gf2 = self.G(z)
            d_out_fake, df1, df2 = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()


            # Backward + optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()
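
            # Annotation: the block below implements the WGAN-GP gradient
            # penalty (Gulrajani et al., 2017): D is evaluated at random
            # interpolates x_hat = alpha * x_real + (1 - alpha) * x_fake and
            # ||grad_x_hat D(x_hat)||_2 is pushed toward 1.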
            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty
                alpha = torch.rand(real_images.size(0), 1, 1, 1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data + (1 - alpha) * fake_images.data, requires_grad=True)
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

                # Backward + optimize
                d_loss = self.lambda_gp * d_loss_gp

                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G ================== #
            # Create random noise
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, _, _ = self.G(z)

            # Compute loss with fake images (identical for both adversarial losses)
            g_out_fake, _, _ = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp' or self.adv_loss == 'hinge':
                g_loss_fake = - g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()


            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                      " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
                      format(elapsed, step + 1, self.total_step, (step + 1),
                             self.total_step, d_loss_real.data[0],
                             self.G.attn1.gamma.mean().data[0], self.G.attn2.gamma.mean().data[0]))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                save_image(denorm(fake_images.data),
                           os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))

            if (step + 1) % model_save_step == 0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, '{}_G.pth'.format(step + 1)))
                torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path, '{}_D.pth'.format(step + 1)))

    def build_model(self):

        self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Optimizers (TTUR: the discriminator lr is larger than the generator lr)
        self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr, [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # Print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        self.G.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))
        print('Loaded trained models (step: {})..!'.format(self.pretrained_model))

    def reset_grad(self):
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        real_images, _ = next(data_iter)
        save_image(denorm(real_images), os.path.join(self.sample_path, 'real.png'))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import torch
from torch.autograd import Variable


def make_folder(path, version):
    if not os.path.exists(os.path.join(path, version)):
        os.makedirs(os.path.join(path, version))


def tensor2var(x, grad=False):
    # Move to GPU when available and wrap in a Variable (PyTorch 0.3-style API).
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=grad)

def var2tensor(x):
    return x.data.cpu()

def var2numpy(x):
    return x.data.cpu().numpy()

def denorm(x):
    # Map images from [-1, 1] (tanh output) back to [0, 1] for saving.
    out = (x + 1) / 2
    return out.clamp_(0, 1)
--------------------------------------------------------------------------------
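A tiny round trip with the helpers above (illustrative):

```python
import torch
from utils import tensor2var, var2tensor, denorm

x = torch.randn(2, 3, 64, 64).clamp(-1, 1)  # stand-in for tanh outputs in [-1, 1]
v = tensor2var(x)                           # moved to GPU when one is available
img = denorm(var2tensor(v))                 # back on CPU, rescaled to [0, 1]
print(img.min().item() >= 0 and img.max().item() <= 1)  # True
```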