├── README.md
├── data_loader.py
├── download.sh
├── image
│   ├── attn_gf1.png
│   ├── attn_gf2.png
│   ├── main_model.PNG
│   ├── sagan_attn.png
│   ├── sagan_celeb.png
│   ├── sagan_lsun.png
│   └── unnamed
├── main.py
├── parameter.py
├── sagan_models.py
├── spectral.py
├── trainer.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Attention GAN
2 | **[Han Zhang, Ian Goodfellow, Dimitris Metaxas and Augustus Odena, "Self-Attention Generative Adversarial Networks." arXiv preprint arXiv:1805.08318 (2018)](https://arxiv.org/abs/1805.08318).**
3 |
4 | ## Meta overview
5 | This repository provides a PyTorch implementation of [SAGAN](https://arxiv.org/abs/1805.08318). Both the wgan-gp and wgan-hinge losses are available, but note that wgan-gp appears to be incompatible with spectral normalization; to train with wgan-gp, remove all spectral normalization from the models.
6 |
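As a reference, here is a condensed sketch of the two discriminator objectives as they appear in `trainer.py` (for wgan-gp, the gradient penalty term is computed separately):

```python
import torch

def d_loss(d_out_real, d_out_fake, adv_loss='hinge'):
    if adv_loss == 'hinge':
        # Hinge: push real scores above +1 and fake scores below -1.
        return (torch.nn.ReLU()(1.0 - d_out_real).mean()
                + torch.nn.ReLU()(1.0 + d_out_fake).mean())
    # wgan-gp: Wasserstein critic loss; the gradient penalty is added separately.
    return -d_out_real.mean() + d_out_fake.mean()
```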
7 | Self-attention is applied to the last two layers of both the discriminator and the generator.
8 |
9 | ![Main model architecture](image/main_model.PNG)
10 |
11 | ## Current update status
12 | * [ ] Supervised setting
13 | * [ ] Tensorboard loggings
14 | * [x] **[20180608] Updated the self-attention module. Thanks to my colleague [Cheonbok Park](https://github.com/cheonbok94)! See 'sagan_models.py' for the update. It should be more efficient and able to run on larger images.**
15 | * [x] Attention visualization (LSUN Church-outdoor)
16 | * [x] Unsupervised setting (no labels used yet)
17 | * [x] Applied: [Spectral Normalization](https://arxiv.org/abs/1802.05957), code from [here](https://github.com/christiancosgrove/pytorch-spectral-normalization-gan)
18 | * [x] Implemented: self-attention module, two-timescale update rule (TTUR), wgan-hinge loss, wgan-gp loss
19 |
20 |
21 |
22 |
23 | ## Results
24 |
25 | ### Attention result on LSUN (epoch #8)
26 | ![Attention visualization on LSUN](image/sagan_attn.png)
27 | Per-pixel attention results of SAGAN on the LSUN church-outdoor dataset. They show that unsupervised training of the self-attention module works, even though the attention maps themselves are hard to interpret. Better results for the generated images will be added. These figures visualize the self-attention of generator layer3 and layer4, whose feature maps are 16 x 16 and 32 x 32 respectively, each over 64 images. To visualize per-pixel attention, only a few query pixels are chosen, as indicated by the numbers on the leftmost and rightmost columns.
28 |
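To inspect the attention of a single query pixel, one can reshape a row of the returned map back onto the feature-map grid. A minimal sketch, assuming a generator `G` and latent batch `z` as in `trainer.py` (the query index is hypothetical):

```python
fake_images, p1, p2 = G(z)        # p1: (B, 256, 256) from the 16 x 16 layer
query = 120                       # hypothetical query-pixel index in [0, 256)
attn = p1[0, query].view(16, 16)  # where pixel `query` of image 0 attends
```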
29 | ### CelebA dataset (epoch on the left, still under training)
30 | ![CelebA samples](image/sagan_celeb.png)
31 |
32 | ### LSUN church-outdoor dataset (epoch on the left, still under training)
33 | ![LSUN church-outdoor samples](image/sagan_lsun.png)
34 |
35 | ## Prerequisites
36 | * [Python 3.5+](https://www.continuum.io/downloads)
37 | * [PyTorch 0.3.0](http://pytorch.org/)
38 |
39 |
40 |
41 | ## Usage
42 |
43 | #### 1. Clone the repository
44 | ```bash
45 | $ git clone https://github.com/heykeetae/Self-Attention-GAN.git
46 | $ cd Self-Attention-GAN
47 | ```
48 |
49 | #### 2. Download datasets (CelebA or LSUN)
50 | ```bash
51 | $ bash download.sh CelebA
52 | or
53 | $ bash download.sh LSUN
54 | ```
55 |
56 |
57 | #### 3. Train
58 | ##### (i) Train
59 | ```bash
60 | $ python main.py --batch_size 64 --imsize 64 --dataset celeb --adv_loss hinge --version sagan_celeb
61 | or
62 | $ python main.py --batch_size 64 --imsize 64 --dataset lsun --adv_loss hinge --version sagan_lsun
63 | ```
64 | #### 4. Enjoy the results
65 | ```bash
66 | $ cd samples/sagan_celeb
67 | or
68 | $ cd samples/sagan_lsun
69 |
70 | ```
71 | Generated samples are saved to these folders every 100 iterations by default. The sampling rate can be controlled via --sample_step (e.g., --sample_step 100).
72 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.datasets as dsets
3 | from torchvision import transforms
4 |
5 |
6 | class Data_Loader():
7 |     def __init__(self, train, dataset, image_path, image_size, batch_size, shuf=True):
8 |         self.dataset = dataset
9 |         self.path = image_path
10 |         self.imsize = image_size
11 |         self.batch = batch_size
12 |         self.shuf = shuf
13 |         self.train = train
14 |
15 |     def transform(self, resize, totensor, normalize, centercrop):
16 |         options = []
17 |         if centercrop:
18 |             options.append(transforms.CenterCrop(160))
19 |         if resize:
20 |             options.append(transforms.Resize((self.imsize, self.imsize)))
21 |         if totensor:
22 |             options.append(transforms.ToTensor())
23 |         if normalize:
24 |             options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
25 |         transform = transforms.Compose(options)
26 |         return transform
27 |
28 |     def load_lsun(self, classes='church_outdoor_train'):
29 |         transform = self.transform(True, True, True, False)
30 |         dataset = dsets.LSUN(self.path, classes=[classes], transform=transform)
31 |         return dataset
32 |
33 |     def load_celeb(self):
34 |         transform = self.transform(True, True, True, True)
35 |         dataset = dsets.ImageFolder(self.path + '/CelebA', transform=transform)
36 |         return dataset
37 |
38 |     def loader(self):
39 |         if self.dataset == 'lsun':
40 |             dataset = self.load_lsun()
41 |         elif self.dataset == 'celeb':
42 |             dataset = self.load_celeb()
43 |
44 |         loader = torch.utils.data.DataLoader(dataset=dataset,
45 |                                              batch_size=self.batch,
46 |                                              shuffle=self.shuf,
47 |                                              num_workers=2,
48 |                                              drop_last=True)
49 |         return loader
--------------------------------------------------------------------------------
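A minimal usage sketch for `Data_Loader`, assuming CelebA has already been downloaded to `./data/CelebA` via `download.sh`:

```python
from data_loader import Data_Loader

# train=True, 64x64 images, batches of 64; pixels are normalized to [-1, 1]
loader = Data_Loader(True, 'celeb', './data', 64, 64, shuf=True).loader()
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([64, 3, 64, 64])
```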
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | FILE=$1
3 |
4 | if [ "$FILE" == 'CelebA' ]
5 | then
6 |     URL=https://www.dropbox.com/s/3e5cmqgplchz85o/CelebA_nocrop.zip?dl=0
7 |     ZIP_FILE=./data/CelebA.zip
8 | elif [ "$FILE" == 'LSUN' ]
9 | then
10 |     URL=https://www.dropbox.com/s/zt7d2hchrw7cp9p/church_outdoor_train_lmdb.zip?dl=0
11 |     ZIP_FILE=./data/church_outdoor_train_lmdb.zip
12 | else
13 |     echo "Available datasets are: CelebA and LSUN"
14 |     exit 1
15 | fi
16 |
17 | mkdir -p ./data/
18 | wget -N "$URL" -O "$ZIP_FILE"
19 | unzip "$ZIP_FILE" -d ./data/
20 |
21 | if [ "$FILE" == 'CelebA' ]
22 | then
23 |     mv ./data/CelebA_nocrop ./data/CelebA
24 | fi
25 |
26 | rm "$ZIP_FILE"
--------------------------------------------------------------------------------
/image/attn_gf1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/attn_gf1.png
--------------------------------------------------------------------------------
/image/attn_gf2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/attn_gf2.png
--------------------------------------------------------------------------------
/image/main_model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/main_model.PNG
--------------------------------------------------------------------------------
/image/sagan_attn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/sagan_attn.png
--------------------------------------------------------------------------------
/image/sagan_celeb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/sagan_celeb.png
--------------------------------------------------------------------------------
/image/sagan_lsun.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heykeetae/Self-Attention-GAN/8714a54ba5027d680190791ba3a6bb08f9c9a129/image/sagan_lsun.png
--------------------------------------------------------------------------------
/image/unnamed:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from parameter import *
2 | from trainer import Trainer
3 | # from tester import Tester
4 | from data_loader import Data_Loader
5 | from torch.backends import cudnn
6 | from utils import make_folder
7 |
8 |
9 | def main(config):
10 |     # For fast training
11 |     cudnn.benchmark = True
12 |
13 |     # Data loader
14 |     data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
15 |                               config.batch_size, shuf=config.train)
16 |
17 |     # Create directories if they do not exist
18 |     make_folder(config.model_save_path, config.version)
19 |     make_folder(config.sample_path, config.version)
20 |     make_folder(config.log_path, config.version)
21 |     make_folder(config.attn_path, config.version)
22 |
23 |     if config.train:
24 |         if config.model == 'sagan':
25 |             trainer = Trainer(data_loader.loader(), config)
26 |         elif config.model == 'qgan':
27 |             # Note: a qgan trainer is not included in this repository.
28 |             trainer = qgan_trainer(data_loader.loader(), config)
29 |         trainer.train()
30 |     else:
31 |         # Note: tester.py is not included in this repository.
32 |         tester = Tester(data_loader.loader(), config)
33 |         tester.test()
34 |
35 |
36 | if __name__ == '__main__':
37 |     config = get_parameters()
38 |     print(config)
39 |     main(config)
/parameter.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def str2bool(v):
5 |     return v.lower() in ('true',)
6 |
7 |
8 | def get_parameters():
9 |
10 |     parser = argparse.ArgumentParser()
11 |
12 |     # Model hyper-parameters
13 |     parser.add_argument('--model', type=str, default='sagan', choices=['sagan', 'qgan'])
14 |     parser.add_argument('--adv_loss', type=str, default='wgan-gp', choices=['wgan-gp', 'hinge'])
15 |     parser.add_argument('--imsize', type=int, default=32)
16 |     parser.add_argument('--g_num', type=int, default=5)
17 |     parser.add_argument('--z_dim', type=int, default=128)
18 |     parser.add_argument('--g_conv_dim', type=int, default=64)
19 |     parser.add_argument('--d_conv_dim', type=int, default=64)
20 |     parser.add_argument('--lambda_gp', type=float, default=10)
21 |     parser.add_argument('--version', type=str, default='sagan_1')
22 |
23 |     # Training settings
24 |     parser.add_argument('--total_step', type=int, default=1000000, help='how many times to update the generator')
25 |     parser.add_argument('--d_iters', type=int, default=5)
26 |     parser.add_argument('--batch_size', type=int, default=64)
27 |     parser.add_argument('--num_workers', type=int, default=2)
28 |     parser.add_argument('--g_lr', type=float, default=0.0001)
29 |     parser.add_argument('--d_lr', type=float, default=0.0004)
30 |     parser.add_argument('--lr_decay', type=float, default=0.95)
31 |     parser.add_argument('--beta1', type=float, default=0.0)
32 |     parser.add_argument('--beta2', type=float, default=0.9)
33 |
34 |     # Using a pretrained model
35 |     parser.add_argument('--pretrained_model', type=int, default=None)
36 |
37 |     # Misc
38 |     parser.add_argument('--train', type=str2bool, default=True)
39 |     parser.add_argument('--parallel', type=str2bool, default=False)
40 |     parser.add_argument('--dataset', type=str, default='lsun', choices=['lsun', 'celeb'])
41 |     parser.add_argument('--use_tensorboard', type=str2bool, default=False)
42 |
43 |     # Paths
44 |     parser.add_argument('--image_path', type=str, default='./data')
45 |     parser.add_argument('--log_path', type=str, default='./logs')
46 |     parser.add_argument('--model_save_path', type=str, default='./models')
47 |     parser.add_argument('--sample_path', type=str, default='./samples')
48 |     parser.add_argument('--attn_path', type=str, default='./attn')
49 |
50 |     # Step sizes
51 |     parser.add_argument('--log_step', type=int, default=10)
52 |     parser.add_argument('--sample_step', type=int, default=100)
53 |     parser.add_argument('--model_save_step', type=float, default=1.0)
54 |
55 |     return parser.parse_args()
--------------------------------------------------------------------------------
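Note that the defaults encode the two-timescale update rule (TTUR): the discriminator learning rate is four times the generator's, with Adam betas (0.0, 0.9). A small sketch of parsing, with the command-line arguments simulated in-process:

```python
import sys
from parameter import get_parameters

# Simulate: python main.py --batch_size 64 --imsize 64 --dataset celeb --adv_loss hinge
sys.argv = ['main.py', '--batch_size', '64', '--imsize', '64',
            '--dataset', 'celeb', '--adv_loss', 'hinge']
config = get_parameters()
print(config.d_lr / config.g_lr)  # 4.0 -- the TTUR ratio
```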
/sagan_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from spectral import SpectralNorm
6 |
7 |
8 | class Self_Attn(nn.Module):
9 |     """Self-attention layer."""
10 |
11 |     def __init__(self, in_dim, activation):
12 |         super(Self_Attn, self).__init__()
13 |         self.chanel_in = in_dim
14 |         self.activation = activation
15 |
16 |         self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
17 |         self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
18 |         self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
19 |         self.gamma = nn.Parameter(torch.zeros(1))  # learned residual weight, starts at 0
20 |
21 |         self.softmax = nn.Softmax(dim=-1)
22 |
23 |     def forward(self, x):
24 |         """
25 |         inputs :
26 |             x : input feature maps (B x C x W x H)
27 |         returns :
28 |             out : self-attention value + input feature
29 |             attention : B x N x N (N = Width * Height)
30 |         """
31 |         m_batchsize, C, width, height = x.size()
32 |         proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
33 |         proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x C' x N
34 |         energy = torch.bmm(proj_query, proj_key)  # B x N x N
35 |         attention = self.softmax(energy)  # row i: weights over all source positions for query i
36 |         proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x C x N
37 |
38 |         out = torch.bmm(proj_value, attention.permute(0, 2, 1))
39 |         out = out.view(m_batchsize, C, width, height)
40 |
41 |         out = self.gamma * out + x
42 |         return out, attention
43 |
44 |
45 | class Generator(nn.Module):
46 |     """Generator."""
47 |
48 |     def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
49 |         super(Generator, self).__init__()
50 |         self.imsize = image_size
51 |         layer1 = []
52 |         layer2 = []
53 |         layer3 = []
54 |         last = []
55 |
56 |         repeat_num = int(np.log2(self.imsize)) - 3
57 |         mult = 2 ** repeat_num  # 8 for imsize 64
58 |         layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
59 |         layer1.append(nn.BatchNorm2d(conv_dim * mult))
60 |         layer1.append(nn.ReLU())
61 |
62 |         curr_dim = conv_dim * mult
63 |
64 |         layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
65 |         layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
66 |         layer2.append(nn.ReLU())
67 |
68 |         curr_dim = int(curr_dim / 2)
69 |
70 |         layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
71 |         layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
72 |         layer3.append(nn.ReLU())
73 |
74 |         if self.imsize == 64:
75 |             layer4 = []
76 |             curr_dim = int(curr_dim / 2)
77 |             layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
78 |             layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
79 |             layer4.append(nn.ReLU())
80 |             self.l4 = nn.Sequential(*layer4)
81 |             curr_dim = int(curr_dim / 2)
82 |
83 |         self.l1 = nn.Sequential(*layer1)
84 |         self.l2 = nn.Sequential(*layer2)
85 |         self.l3 = nn.Sequential(*layer3)
86 |
87 |         last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
88 |         last.append(nn.Tanh())
89 |         self.last = nn.Sequential(*last)
90 |
91 |         self.attn1 = Self_Attn(128, 'relu')
92 |         self.attn2 = Self_Attn(64, 'relu')
93 |
94 |     def forward(self, z):
95 |         z = z.view(z.size(0), z.size(1), 1, 1)
96 |         out = self.l1(z)
97 |         out = self.l2(out)
98 |         out = self.l3(out)
99 |         out, p1 = self.attn1(out)
100 |         out = self.l4(out)
101 |         out, p2 = self.attn2(out)
102 |         out = self.last(out)
103 |
104 |         return out, p1, p2
105 |
106 |
107 | class Discriminator(nn.Module):
108 |     """Discriminator, auxiliary classifier."""
109 |
110 |     def __init__(self, batch_size=64, image_size=64, conv_dim=64):
111 |         super(Discriminator, self).__init__()
112 |         self.imsize = image_size
113 |         layer1 = []
114 |         layer2 = []
115 |         layer3 = []
116 |         last = []
117 |
118 |         layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
119 |         layer1.append(nn.LeakyReLU(0.1))
120 |
121 |         curr_dim = conv_dim
122 |
123 |         layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
124 |         layer2.append(nn.LeakyReLU(0.1))
125 |         curr_dim = curr_dim * 2
126 |
127 |         layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
128 |         layer3.append(nn.LeakyReLU(0.1))
129 |         curr_dim = curr_dim * 2
130 |
131 |         if self.imsize == 64:
132 |             layer4 = []
133 |             layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
134 |             layer4.append(nn.LeakyReLU(0.1))
135 |             self.l4 = nn.Sequential(*layer4)
136 |             curr_dim = curr_dim * 2
137 |
138 |         self.l1 = nn.Sequential(*layer1)
139 |         self.l2 = nn.Sequential(*layer2)
140 |         self.l3 = nn.Sequential(*layer3)
141 |
142 |         last.append(nn.Conv2d(curr_dim, 1, 4))
143 |         self.last = nn.Sequential(*last)
144 |
145 |         self.attn1 = Self_Attn(256, 'relu')
146 |         self.attn2 = Self_Attn(512, 'relu')
147 |
148 |     def forward(self, x):
149 |         out = self.l1(x)
150 |         out = self.l2(out)
151 |         out = self.l3(out)
152 |         out, p1 = self.attn1(out)
153 |         out = self.l4(out)
154 |         out, p2 = self.attn2(out)
155 |         out = self.last(out)
156 |
157 |         return out.squeeze(), p1, p2
--------------------------------------------------------------------------------
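A quick shape check for the models (a sketch for the 64 x 64 setting; on PyTorch 0.3 the inputs would need `Variable` wrapping, on newer versions this runs as-is on CPU):

```python
import torch
from sagan_models import Generator, Discriminator

G = Generator(batch_size=4, image_size=64, z_dim=128, conv_dim=64)
D = Discriminator(batch_size=4, image_size=64, conv_dim=64)

z = torch.randn(4, 128)
fake, gf1, gf2 = G(z)
print(fake.shape)            # (4, 3, 64, 64)
print(gf1.shape, gf2.shape)  # attention maps: (4, 256, 256), (4, 1024, 1024)

score, df1, df2 = D(fake)
print(score.shape)           # (4,) -- one critic score per image
```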
/spectral.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Parameter
4 |
5 |
6 | def l2normalize(v, eps=1e-12):
7 |     return v / (v.norm() + eps)
8 |
9 |
10 | class SpectralNorm(nn.Module):
11 |     def __init__(self, module, name='weight', power_iterations=1):
12 |         super(SpectralNorm, self).__init__()
13 |         self.module = module
14 |         self.name = name
15 |         self.power_iterations = power_iterations
16 |         if not self._made_params():
17 |             self._make_params()
18 |
19 |     def _update_u_v(self):
20 |         u = getattr(self.module, self.name + "_u")
21 |         v = getattr(self.module, self.name + "_v")
22 |         w = getattr(self.module, self.name + "_bar")
23 |
24 |         height = w.data.shape[0]
25 |         # Power iteration: u and v converge to the leading singular vectors of w.
26 |         for _ in range(self.power_iterations):
27 |             v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
28 |             u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
29 |
30 |         # sigma approximates the spectral norm (largest singular value) of w.
31 |         sigma = u.dot(w.view(height, -1).mv(v))
32 |         setattr(self.module, self.name, w / sigma.expand_as(w))
33 |
34 |     def _made_params(self):
35 |         try:
36 |             getattr(self.module, self.name + "_u")
37 |             getattr(self.module, self.name + "_v")
38 |             getattr(self.module, self.name + "_bar")
39 |             return True
40 |         except AttributeError:
41 |             return False
42 |
43 |     def _make_params(self):
44 |         w = getattr(self.module, self.name)
45 |
46 |         height = w.data.shape[0]
47 |         width = w.view(height, -1).data.shape[1]
48 |
49 |         u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
50 |         v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
51 |         u.data = l2normalize(u.data)
52 |         v.data = l2normalize(v.data)
53 |         w_bar = Parameter(w.data)
54 |
55 |         del self.module._parameters[self.name]
56 |
57 |         self.module.register_parameter(self.name + "_u", u)
58 |         self.module.register_parameter(self.name + "_v", v)
59 |         self.module.register_parameter(self.name + "_bar", w_bar)
60 |
61 |     def forward(self, *args):
62 |         self._update_u_v()
63 |         return self.module.forward(*args)
--------------------------------------------------------------------------------
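A minimal sketch of wrapping a layer, as done throughout `sagan_models.py`; each forward pass runs one power iteration to estimate the weight's largest singular value and divides the weight by it:

```python
import torch
import torch.nn as nn
from spectral import SpectralNorm

conv = SpectralNorm(nn.Conv2d(3, 64, 4, 2, 1))
x = torch.randn(8, 3, 64, 64)
y = conv(x)     # weight rescaled to (approximately) unit spectral norm
print(y.shape)  # (8, 64, 32, 32)
```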
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import datetime
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | from torchvision.utils import save_image
9 |
10 | from sagan_models import Generator, Discriminator
11 | from utils import *
12 |
13 |
14 | class Trainer(object):
15 |     def __init__(self, data_loader, config):
16 |
17 |         # Data loader
18 |         self.data_loader = data_loader
19 |
20 |         # Exact model and loss
21 |         self.model = config.model
22 |         self.adv_loss = config.adv_loss
23 |
24 |         # Model hyper-parameters
25 |         self.imsize = config.imsize
26 |         self.g_num = config.g_num
27 |         self.z_dim = config.z_dim
28 |         self.g_conv_dim = config.g_conv_dim
29 |         self.d_conv_dim = config.d_conv_dim
30 |         self.parallel = config.parallel
31 |
32 |         self.lambda_gp = config.lambda_gp
33 |         self.total_step = config.total_step
34 |         self.d_iters = config.d_iters
35 |         self.batch_size = config.batch_size
36 |         self.num_workers = config.num_workers
37 |         self.g_lr = config.g_lr
38 |         self.d_lr = config.d_lr
39 |         self.lr_decay = config.lr_decay
40 |         self.beta1 = config.beta1
41 |         self.beta2 = config.beta2
42 |         self.pretrained_model = config.pretrained_model
43 |
44 |         self.dataset = config.dataset
45 |         self.use_tensorboard = config.use_tensorboard
46 |         self.image_path = config.image_path
47 |         self.log_step = config.log_step
48 |         self.sample_step = config.sample_step
49 |         self.model_save_step = config.model_save_step
50 |         self.version = config.version
51 |
52 |         # Paths
53 |         self.log_path = os.path.join(config.log_path, self.version)
54 |         self.sample_path = os.path.join(config.sample_path, self.version)
55 |         self.model_save_path = os.path.join(config.model_save_path, self.version)
56 |
57 |         self.build_model()
58 |
59 |         if self.use_tensorboard:
60 |             self.build_tensorboard()
61 |
62 |         # Start with a trained model
63 |         if self.pretrained_model:
64 |             self.load_pretrained_model()
65 |
66 |     def train(self):
67 |
68 |         # Data iterator
69 |         data_iter = iter(self.data_loader)
70 |         step_per_epoch = len(self.data_loader)
71 |         model_save_step = int(self.model_save_step * step_per_epoch)
72 |
73 |         # Fixed input for debugging
74 |         fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
75 |
76 |         # Start with a trained model
77 |         if self.pretrained_model:
78 |             start = self.pretrained_model + 1
79 |         else:
80 |             start = 0
81 |
82 |         # Start time
83 |         start_time = time.time()
84 |         for step in range(start, self.total_step):
85 |
86 |             # ================== Train D ================== #
87 |             self.D.train()
88 |             self.G.train()
89 |
90 |             try:
91 |                 real_images, _ = next(data_iter)
92 |             except StopIteration:
93 |                 data_iter = iter(self.data_loader)
94 |                 real_images, _ = next(data_iter)
95 |
96 |             # Compute loss with real images
97 |             # dr1, dr2, df1, df2, gf1, gf2 are attention maps
98 |             real_images = tensor2var(real_images)
99 |             d_out_real, dr1, dr2 = self.D(real_images)
100 |             if self.adv_loss == 'wgan-gp':
101 |                 d_loss_real = - torch.mean(d_out_real)
102 |             elif self.adv_loss == 'hinge':
103 |                 d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
104 |
105 |             # Compute loss with fake images
106 |             z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
107 |             fake_images, gf1, gf2 = self.G(z)
108 |             d_out_fake, df1, df2 = self.D(fake_images)
109 |
110 |             if self.adv_loss == 'wgan-gp':
111 |                 d_loss_fake = d_out_fake.mean()
112 |             elif self.adv_loss == 'hinge':
113 |                 d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
114 |
115 |             # Backward + optimize
116 |             d_loss = d_loss_real + d_loss_fake
117 |             self.reset_grad()
118 |             d_loss.backward()
119 |             self.d_optimizer.step()
120 |
121 |             if self.adv_loss == 'wgan-gp':
122 |                 # Compute gradient penalty on interpolates of real and fake images
123 |                 alpha = torch.rand(real_images.size(0), 1, 1, 1).cuda().expand_as(real_images)
124 |                 interpolated = Variable(alpha * real_images.data + (1 - alpha) * fake_images.data, requires_grad=True)
125 |                 out, _, _ = self.D(interpolated)
126 |
127 |                 grad = torch.autograd.grad(outputs=out,
128 |                                            inputs=interpolated,
129 |                                            grad_outputs=torch.ones(out.size()).cuda(),
130 |                                            retain_graph=True,
131 |                                            create_graph=True,
132 |                                            only_inputs=True)[0]
133 |
134 |                 grad = grad.view(grad.size(0), -1)
135 |                 grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
136 |                 d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
137 |
138 |                 # Backward + optimize (separate step for the penalty term)
139 |                 d_loss = self.lambda_gp * d_loss_gp
140 |
141 |                 self.reset_grad()
142 |                 d_loss.backward()
143 |                 self.d_optimizer.step()
144 |
145 |             # ================== Train G ================== #
146 |             # Create random noise
147 |             z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
148 |             fake_images, _, _ = self.G(z)
149 |
150 |             # Compute loss with fake images
151 |             g_out_fake, _, _ = self.D(fake_images)  # batch x n
152 |             if self.adv_loss == 'wgan-gp':
153 |                 g_loss_fake = - g_out_fake.mean()
154 |             elif self.adv_loss == 'hinge':
155 |                 g_loss_fake = - g_out_fake.mean()
156 |
157 |             self.reset_grad()
158 |             g_loss_fake.backward()
159 |             self.g_optimizer.step()
160 |
161 |             # Print out log info
162 |             if (step + 1) % self.log_step == 0:
163 |                 elapsed = time.time() - start_time
164 |                 elapsed = str(datetime.timedelta(seconds=elapsed))
165 |                 print("Elapsed [{}], G_step [{}/{}], D_step [{}/{}], d_out_real: {:.4f}, "
166 |                       "ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
167 |                       format(elapsed, step + 1, self.total_step, (step + 1),
168 |                              self.total_step, d_loss_real.data[0],
169 |                              self.G.attn1.gamma.mean().data[0], self.G.attn2.gamma.mean().data[0]))
170 |
171 |             # Sample images
172 |             if (step + 1) % self.sample_step == 0:
173 |                 fake_images, _, _ = self.G(fixed_z)
174 |                 save_image(denorm(fake_images.data),
175 |                            os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))
176 |
177 |             # Save model checkpoints
178 |             if (step + 1) % model_save_step == 0:
179 |                 torch.save(self.G.state_dict(),
180 |                            os.path.join(self.model_save_path, '{}_G.pth'.format(step + 1)))
181 |                 torch.save(self.D.state_dict(),
182 |                            os.path.join(self.model_save_path, '{}_D.pth'.format(step + 1)))
183 |
184 |     def build_model(self):
185 |
186 |         self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
187 |         self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
188 |         if self.parallel:
189 |             self.G = nn.DataParallel(self.G)
190 |             self.D = nn.DataParallel(self.D)
191 |
192 |         # Loss and optimizers (TTUR: d_lr is larger than g_lr)
193 |         self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr, [self.beta1, self.beta2])
194 |         self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr, [self.beta1, self.beta2])
195 |
196 |         self.c_loss = torch.nn.CrossEntropyLoss()
197 |
198 |         # Print networks
199 |         print(self.G)
200 |         print(self.D)
201 |
202 |     def build_tensorboard(self):
203 |         from logger import Logger
204 |         self.logger = Logger(self.log_path)
205 |
206 |     def load_pretrained_model(self):
207 |         self.G.load_state_dict(torch.load(os.path.join(
208 |             self.model_save_path, '{}_G.pth'.format(self.pretrained_model))))
209 |         self.D.load_state_dict(torch.load(os.path.join(
210 |             self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))
211 |         print('loaded trained models (step: {})..!'.format(self.pretrained_model))
212 |
213 |     def reset_grad(self):
214 |         self.d_optimizer.zero_grad()
215 |         self.g_optimizer.zero_grad()
216 |
217 |     def save_sample(self, data_iter):
218 |         real_images, _ = next(data_iter)
219 |         save_image(denorm(real_images), os.path.join(self.sample_path, 'real.png'))
--------------------------------------------------------------------------------
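For reference, the wgan-gp branch above condenses to the following gradient-penalty computation (a sketch in current PyTorch idiom, assuming a critic `D` that returns `(score, attn1, attn2)` and `real`/`fake` batches on the same device):

```python
import torch

def gradient_penalty(D, real, fake, lambda_gp=10.0):
    # Sample random per-image interpolation coefficients.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device).expand_as(real)
    interp = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    out, _, _ = D(interp)
    # Gradient of the critic score w.r.t. the interpolated images.
    grad = torch.autograd.grad(outputs=out, inputs=interp,
                               grad_outputs=torch.ones_like(out),
                               create_graph=True, retain_graph=True)[0]
    grad_norm = grad.view(grad.size(0), -1).norm(2, dim=1)
    # Penalize deviation of the gradient norm from 1.
    return lambda_gp * ((grad_norm - 1) ** 2).mean()
```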
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.autograd import Variable
4 |
5 |
6 | def make_folder(path, version):
7 |     if not os.path.exists(os.path.join(path, version)):
8 |         os.makedirs(os.path.join(path, version))
9 |
10 |
11 | def tensor2var(x, grad=False):
12 |     if torch.cuda.is_available():
13 |         x = x.cuda()
14 |     return Variable(x, requires_grad=grad)
15 |
16 |
17 | def var2tensor(x):
18 |     return x.data.cpu()
19 |
20 |
21 | def var2numpy(x):
22 |     return x.data.cpu().numpy()
23 |
24 |
25 | def denorm(x):
26 |     out = (x + 1) / 2
27 |     return out.clamp_(0, 1)
--------------------------------------------------------------------------------
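`denorm` inverts the `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transform used in `data_loader.py`, mapping pixels from [-1, 1] back to [0, 1] before images are saved. A quick check:

```python
import torch
from utils import denorm

x = torch.tensor([-1.0, 0.0, 1.0])
print(denorm(x))  # tensor([0.0000, 0.5000, 1.0000])
```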