├── README.md
├── agents
│   ├── __init__.py
│   └── stgan.py
├── configs
│   └── train_stgan.yaml
├── datasets
│   ├── __init__.py
│   └── celeba.py
├── main.py
├── models
│   ├── __init__.py
│   └── stgan.py
├── sample.jpg
└── utils
    ├── __init__.py
    ├── config.py
    └── misc.py

/README.md:
--------------------------------------------------------------------------------
# STGAN (CVPR 2019)

An unofficial **PyTorch** implementation of [**STGAN: A Unified Selective Transfer Network for Arbitrary Image Attribute Editing**](https://arxiv.org/abs/1904.09709).

## Requirements
- [Python 3.6+](https://www.python.org)
- [PyTorch 1.0+](https://pytorch.org)
- [tensorboardX 1.6+](https://github.com/lanpa/tensorboardX)
- [torchsummary](https://github.com/sksq96/pytorch-summary)
- [tqdm](https://github.com/tqdm/tqdm)
- [Pillow](https://github.com/python-pillow/Pillow)
- [easydict](https://github.com/makinacorpus/easydict)

## Sample

From left to right: Origin, Bangs, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, Pale_Skin, Young.

![](sample.jpg)

## Preparation

Please download the [CelebA](http://openaccess.thecvf.com/content_iccv_2015/papers/Liu_Deep_Learning_Face_ICCV_2015_paper.pdf) dataset from this [project page](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). Then organize the directory as:

```
├── data_root
│   ├── image
│   │   ├── 000001.jpg
│   │   ├── 000002.jpg
│   │   ├── 000003.jpg
│   │   └── ...
│   └── anno
│       ├── list_attr_celeba.txt
│       └── ...
```

## Training

- For a quick start, simply use the following command to train:

```console
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config ./configs/train_stgan.yaml
```

- If you want to modify any hyper-parameters, edit them in the configuration file `./configs/train_stgan.yaml` following the explanations below:
    - `exp_name`: the name of the current experiment.
    - `mode`: 'train' or 'test'.
    - `cuda`: whether to use CUDA.
    - `ngpu`: how many GPUs to use. Note: this number should be no more than the length of the `CUDA_VISIBLE_DEVICES` list.
    - `dataset`: the name of the dataset. Note: you can extend this project with other datasets.
    - `data_root`: the root directory of the dataset.
    - `crop_size`: the crop size of the images.
    - `image_size`: the size of the input images during training.
    - `g_conv_dim`: the base number of convolutional filters in G.
    - `d_conv_dim`: the base number of convolutional filters in D.
    - `d_fc_dim`: the dimension of the fully-connected layers in D.
    - `g_layers`: the number of convolutional layers in G. Note: the same for both the encoder and the decoder.
    - `d_layers`: the number of convolutional layers in D.
    - `shortcut_layers`: the number of shortcut connections in G. Note: also the number of STUs.
    - `stu_kernel_size`: the kernel size of the convolutional layers in the STUs.
    - `use_stu`: if set to false, there will be no STUs in the shortcut connections.
    - `one_more_conv`: if set to true, an extra convolutional layer is added between the decoder and the generated image.
    - `attrs`: the list of selected attributes. Note: please refer to `list_attr_celeba.txt` for all available attributes.
    - `checkpoint`: the iteration step number of the checkpoint to resume from. Note: set this to `~` when training from scratch (see the resuming note after this list).
    - `batch_size`: the batch size of the data loader.
    - `beta1`: the beta1 value of the Adam optimizer.
    - `beta2`: the beta2 value of the Adam optimizer.
    - `g_lr`: the base learning rate of G.
    - `d_lr`: the base learning rate of D.
    - `n_critic`: the number of D updates per G update.
    - `thres_int`: the threshold of the target attribute vector during training.
    - `lambda_gp`: the trade-off coefficient of D_loss_gp.
    - `lambda1`: the trade-off coefficient of D_loss_att.
    - `lambda2`: the trade-off coefficient of G_loss_att.
    - `lambda3`: the trade-off coefficient of G_loss_rec.
    - `max_iters`: the maximum number of iteration steps.
    - `lr_decay_iters`: the number of iteration steps per learning-rate decay.
    - `summary_step`: the number of iteration steps per summary operation with tensorboardX.
    - `sample_step`: the number of iteration steps per sampling operation.
    - `checkpoint_step`: the number of iteration steps per checkpoint-saving operation.
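- To resume an interrupted run, set `checkpoint` to the iteration number of a saved checkpoint (e.g. `checkpoint: 200000`) and rerun the same training command: the agent loads `G_200000.pth.tar` and `D_200000.pth.tar` from `experiments/<exp_name>/checkpoints/`, restores both optimizers, and continues from that iteration. For testing, set `mode: test` together with a valid `checkpoint`; the edited images are saved to `experiments/<exp_name>/results/`.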

## Acknowledgements

This code refers to the following two projects:

[1] [TensorFlow implementation of STGAN](https://github.com/csmliu/STGAN)

[2] [PyTorch implementation of StarGAN](https://github.com/yunjey/stargan)
--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
from .stgan import STGANAgent
--------------------------------------------------------------------------------
/agents/stgan.py:
--------------------------------------------------------------------------------
import os
import logging
import time
import datetime
import traceback
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.backends import cudnn
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
from tensorboardX import SummaryWriter

from datasets import *
from models.stgan import Generator, Discriminator
from utils.misc import print_cuda_statistics

cudnn.benchmark = True


class STGANAgent(object):
    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger("STGAN")
        self.logger.info("Creating STGAN architecture...")

        self.G = Generator(len(self.config.attrs), self.config.g_conv_dim, self.config.g_layers,
                           self.config.shortcut_layers, stu_kernel_size=self.config.stu_kernel_size,
                           use_stu=self.config.use_stu, one_more_conv=self.config.one_more_conv)
        self.D = Discriminator(self.config.image_size, len(self.config.attrs), self.config.d_conv_dim,
                               self.config.d_fc_dim, self.config.d_layers)

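        # the dataset name in the config selects a loader factory exported by
        # datasets/__init__.py (e.g. `dataset: celeba` resolves to celeba_loader)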
        self.data_loader = globals()['{}_loader'.format(self.config.dataset)](
            self.config.data_root, self.config.mode, self.config.attrs,
            self.config.crop_size, self.config.image_size, self.config.batch_size)

        self.current_iteration = 0
        self.cuda = torch.cuda.is_available() and self.config.cuda

        if self.cuda:
            self.device = torch.device("cuda")
            self.logger.info("Operation will be on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            self.logger.info("Operation will be on *****CPU***** ")

        self.writer = SummaryWriter(log_dir=self.config.summary_dir)

    def save_checkpoint(self):
        G_state = {
            'state_dict': self.G.state_dict(),
            'optimizer': self.optimizer_G.state_dict(),
        }
        D_state = {
            'state_dict': self.D.state_dict(),
            'optimizer': self.optimizer_D.state_dict(),
        }
        G_filename = 'G_{}.pth.tar'.format(self.current_iteration)
        D_filename = 'D_{}.pth.tar'.format(self.current_iteration)
        torch.save(G_state, os.path.join(self.config.checkpoint_dir, G_filename))
        torch.save(D_state, os.path.join(self.config.checkpoint_dir, D_filename))

    def load_checkpoint(self):
        if self.config.checkpoint is None:
            self.G.to(self.device)
            self.D.to(self.device)
            return
        G_filename = 'G_{}.pth.tar'.format(self.config.checkpoint)
        D_filename = 'D_{}.pth.tar'.format(self.config.checkpoint)
        G_checkpoint = torch.load(os.path.join(self.config.checkpoint_dir, G_filename))
        D_checkpoint = torch.load(os.path.join(self.config.checkpoint_dir, D_filename))
        # strip the 'module.' prefix left by nn.DataParallel so the weights can
        # be loaded into a non-parallel model as well
        G_to_load = {k.replace('module.', ''): v for k, v in G_checkpoint['state_dict'].items()}
        D_to_load = {k.replace('module.', ''): v for k, v in D_checkpoint['state_dict'].items()}
        self.current_iteration = self.config.checkpoint
        self.G.load_state_dict(G_to_load)
        self.D.load_state_dict(D_to_load)
        self.G.to(self.device)
        self.D.to(self.device)
        if self.config.mode == 'train':
            self.optimizer_G.load_state_dict(G_checkpoint['optimizer'])
            self.optimizer_D.load_state_dict(D_checkpoint['optimizer'])

    def denorm(self, x):
        """Map images from [-1, 1] back to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def create_labels(self, c_org, selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        # get hair color indices
        hair_color_indices = []
        for i, attr_name in enumerate(selected_attrs):
            if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                hair_color_indices.append(i)

        c_trg_list = []
        for i in range(len(selected_attrs)):
            c_trg = c_org.clone()
            if i in hair_color_indices:  # set one hair color to 1 and the rest to 0
                c_trg[:, i] = 1
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = 0
            else:
                c_trg[:, i] = (c_trg[:, i] == 0)  # reverse attribute value

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def classification_loss(self, logit, target):
        """Compute binary cross entropy loss."""
        return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def run(self):
        assert self.config.mode in ['train', 'test']
        try:
            if self.config.mode == 'train':
                self.train()
            else:
                self.test()
        except KeyboardInterrupt:
            self.logger.info('You have entered CTRL+C.. Wait to finalize')
        except Exception:
            with open(os.path.join(self.config.log_dir, 'exp_error.log'), 'w+') as log_file:
                traceback.print_exc(file=log_file)
        finally:
            self.finalize()

    def train(self):
        self.optimizer_G = optim.Adam(self.G.parameters(), self.config.g_lr, [self.config.beta1, self.config.beta2])
        self.optimizer_D = optim.Adam(self.D.parameters(), self.config.d_lr, [self.config.beta1, self.config.beta2])
        self.lr_scheduler_G = optim.lr_scheduler.StepLR(self.optimizer_G, step_size=self.config.lr_decay_iters, gamma=0.1)
        self.lr_scheduler_D = optim.lr_scheduler.StepLR(self.optimizer_D, step_size=self.config.lr_decay_iters, gamma=0.1)

        self.load_checkpoint()
        if self.cuda and self.config.ngpu > 1:
            self.G = nn.DataParallel(self.G, device_ids=list(range(self.config.ngpu)))
            self.D = nn.DataParallel(self.D, device_ids=list(range(self.config.ngpu)))

        # fixed validation batch used for periodic sampling
        val_iter = iter(self.data_loader.val_loader)
        x_sample, c_org_sample = next(val_iter)
        x_sample = x_sample.to(self.device)
        c_sample_list = self.create_labels(c_org_sample, self.config.attrs)
        c_sample_list.insert(0, c_org_sample)  # reconstruction

        self.g_lr = self.lr_scheduler_G.get_lr()[0]
        self.d_lr = self.lr_scheduler_D.get_lr()[0]

        data_iter = iter(self.data_loader.train_loader)
        start_time = time.time()
        for i in range(self.current_iteration, self.config.max_iters):
            self.G.train()
            self.D.train()
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # fetch real images and labels
            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader.train_loader)
                x_real, label_org = next(data_iter)

            # generate target domain labels randomly by shuffling the batch
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            c_org = label_org.clone()
            c_trg = label_trg.clone()

            x_real = x_real.to(self.device)        # input images
            c_org = c_org.to(self.device)          # original domain labels
            c_trg = c_trg.to(self.device)          # target domain labels
            label_org = label_org.to(self.device)  # labels for computing classification loss
            label_trg = label_trg.to(self.device)  # labels for computing classification loss

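            # D is trained with a WGAN-GP critic loss plus an attribute
            # classification loss: real images should score higher than fakes,
            # and D must also recover the source attributes from real images;
            # the requested attribute change (attr_diff) is scaled by a random
            # factor so that its magnitude lies within [0, 2 * thres_int)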
            # =================================================================================== #
            #                            2. Train the discriminator                               #
            # =================================================================================== #

            # compute loss with real images
            out_src, out_cls = self.D(x_real)
            d_loss_real = - torch.mean(out_src)
            d_loss_cls = self.classification_loss(out_cls, label_org)

            # compute loss with fake images
            attr_diff = c_trg - c_org
            attr_diff = attr_diff * torch.rand_like(attr_diff) * (2 * self.config.thres_int)
            x_fake = self.G(x_real, attr_diff)
            out_src, out_cls = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # compute loss for gradient penalty
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # backward and optimize
            d_loss_adv = d_loss_real + d_loss_fake + self.config.lambda_gp * d_loss_gp
            d_loss = d_loss_adv + self.config.lambda1 * d_loss_cls
            self.optimizer_D.zero_grad()
            d_loss.backward(retain_graph=True)
            self.optimizer_D.step()

            # summarize
            scalars = {}
            scalars['D/loss'] = d_loss.item()
            scalars['D/loss_adv'] = d_loss_adv.item()
            scalars['D/loss_cls'] = d_loss_cls.item()
            scalars['D/loss_real'] = d_loss_real.item()
            scalars['D/loss_fake'] = d_loss_fake.item()
            scalars['D/loss_gp'] = d_loss_gp.item()

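            # G is updated once every n_critic iterations: it is trained to
            # fool D, to realize the target attributes, and to reconstruct the
            # input image when the attribute difference is all zeros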
            # =================================================================================== #
            #                              3. Train the generator                                 #
            # =================================================================================== #

            if (i + 1) % self.config.n_critic == 0:
                # original-to-target domain
                x_fake = self.G(x_real, attr_diff)
                out_src, out_cls = self.D(x_fake)
                g_loss_adv = - torch.mean(out_src)
                g_loss_cls = self.classification_loss(out_cls, label_trg)

                # target-to-original domain: a zero attribute difference should
                # reproduce the input image
                x_reconst = self.G(x_real, c_org - c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # backward and optimize
                g_loss = g_loss_adv + self.config.lambda3 * g_loss_rec + self.config.lambda2 * g_loss_cls
                self.optimizer_G.zero_grad()
                g_loss.backward()
                self.optimizer_G.step()

                # summarize
                scalars['G/loss'] = g_loss.item()
                scalars['G/loss_adv'] = g_loss_adv.item()
                scalars['G/loss_cls'] = g_loss_cls.item()
                scalars['G/loss_rec'] = g_loss_rec.item()

            self.current_iteration += 1

            # =================================================================================== #
            #                                4. Miscellaneous                                     #
            # =================================================================================== #

            if self.current_iteration % self.config.summary_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                print('Elapsed [{}], Iteration [{}/{}]'.format(et, self.current_iteration, self.config.max_iters))
                for tag, value in scalars.items():
                    self.writer.add_scalar(tag, value, self.current_iteration)

            if self.current_iteration % self.config.sample_step == 0:
                self.G.eval()
                with torch.no_grad():
                    x_sample = x_sample.to(self.device)
                    x_fake_list = [x_sample]
                    for c_trg_sample in c_sample_list:
                        attr_diff = c_trg_sample.to(self.device) - c_org_sample.to(self.device)
                        attr_diff = attr_diff * self.config.thres_int
                        x_fake_list.append(self.G(x_sample, attr_diff.to(self.device)))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    self.writer.add_image('sample', make_grid(self.denorm(x_concat.data.cpu()), nrow=1),
                                          self.current_iteration)
                    save_image(self.denorm(x_concat.data.cpu()),
                               os.path.join(self.config.sample_dir, 'sample_{}.jpg'.format(self.current_iteration)),
                               nrow=1, padding=0)

            if self.current_iteration % self.config.checkpoint_step == 0:
                self.save_checkpoint()

            self.lr_scheduler_G.step()
            self.lr_scheduler_D.step()

    def test(self):
        self.load_checkpoint()
        self.G.to(self.device)

        tqdm_loader = tqdm(self.data_loader.test_loader, total=self.data_loader.test_iterations,
                           desc='Testing at checkpoint {}'.format(self.config.checkpoint))

        self.G.eval()
        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(tqdm_loader):
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.config.attrs)

                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    attr_diff = c_trg - c_org
                    x_fake_list.append(self.G(x_real, attr_diff.to(self.device)))
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.config.result_dir, 'sample_{}.jpg'.format(i + 1))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)

    def finalize(self):
        print('Please wait while finalizing the operation.. Thank you')
        self.writer.export_scalars_to_json(os.path.join(self.config.summary_dir, 'all_scalars.json'))
        self.writer.close()
--------------------------------------------------------------------------------
/configs/train_stgan.yaml:
--------------------------------------------------------------------------------
# meta
exp_name: stgan
mode: train
cuda: true
ngpu: 2

# data
dataset: celeba
data_root: /root/datasets/CelebA/
crop_size: 178
image_size: 128

# model
g_conv_dim: 64
d_conv_dim: 64
d_fc_dim: 1024
g_layers: 5
d_layers: 5
shortcut_layers: 3
stu_kernel_size: 3
use_stu: true
one_more_conv: true
attrs: [Bangs, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, Pale_Skin, Young]
checkpoint: ~

# training
batch_size: 32
beta1: 0.5
beta2: 0.999
g_lr: 0.0002
d_lr: 0.0002
n_critic: 5
thres_int: 0.5
lambda_gp: 10
lambda1: 1
lambda2: 10
lambda3: 100
max_iters: 1000000
lr_decay_iters: 800000

# steps
summary_step: 100
sample_step: 5000
checkpoint_step: 20000
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from .celeba import CelebADataLoader as celeba_loader
--------------------------------------------------------------------------------
/datasets/celeba.py:
--------------------------------------------------------------------------------
import os
import math
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image


def make_dataset(root, mode, selected_attrs):
    assert mode in ['train', 'val', 'test']
    lines = [line.rstrip() for line in open(os.path.join(root, 'anno', 'list_attr_celeba.txt'), 'r')]
    all_attr_names = lines[1].split()
    attr2idx = {}
    idx2attr = {}
    for i, attr_name in enumerate(all_attr_names):
        attr2idx[attr_name] = i
        idx2attr[i] = attr_name

    lines = lines[2:]
    if mode == 'train':
        lines = lines[:-2000]       # train set contains 200599 images
    if mode == 'val':
        lines = lines[-2000:-1800]  # val set contains 200 images
    if mode == 'test':
        lines = lines[-1800:]       # test set contains 1800 images

    items = []
    for i, line in enumerate(lines):
        split = line.split()
        filename = split[0]
        values = split[1:]
        label = []
        for attr_name in selected_attrs:
            idx = attr2idx[attr_name]
            label.append(values[idx] == '1')
        items.append([filename, label])
    return items


class CelebADataset(data.Dataset):
    def __init__(self, root, mode, selected_attrs, transform=None):
        self.items = make_dataset(root, mode, selected_attrs)
        self.root = root
        self.mode = mode
        self.transform = transform

    def __getitem__(self, index):
        filename, label = self.items[index]
        image = Image.open(os.path.join(self.root, 'image', filename))
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.FloatTensor(label)

    def __len__(self):
        return len(self.items)


class CelebADataLoader(object):
    def __init__(self, root, mode, selected_attrs, crop_size=None, image_size=128, batch_size=16):
        if mode not in ['train', 'test']:
            return

        transform = []
        if crop_size is not None:
            transform.append(transforms.CenterCrop(crop_size))
        transform.append(transforms.Resize(image_size))
        transform.append(transforms.ToTensor())
        transform.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

        if mode == 'train':
            val_transform = transforms.Compose(transform)  # build the val transform before the random flip is inserted
            val_set = CelebADataset(root, 'val', selected_attrs, transform=val_transform)
            self.val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4)
            self.val_iterations = int(math.ceil(len(val_set) / batch_size))

            transform.insert(0, transforms.RandomHorizontalFlip())
            train_transform = transforms.Compose(transform)
            train_set = CelebADataset(root, 'train', selected_attrs, transform=train_transform)
            self.train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
            self.train_iterations = int(math.ceil(len(train_set) / batch_size))
        else:
            test_transform = transforms.Compose(transform)
            test_set = CelebADataset(root, 'test', selected_attrs, transform=test_transform)
            self.test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)
            self.test_iterations = int(math.ceil(len(test_set) / batch_size))
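

if __name__ == '__main__':
    # Minimal usage sketch of the loader; the data root below is a placeholder
    # for a CelebA directory prepared as described in the README.
    loader = CelebADataLoader('/path/to/CelebA', 'train', ['Bangs', 'Male'],
                              crop_size=178, image_size=128, batch_size=4)
    images, labels = next(iter(loader.train_loader))
    print(images.shape, labels.shape)  # -> [4, 3, 128, 128] and [4, 2]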
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse

from agents import STGANAgent
from utils.config import *


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--config',
        default=None,
        help='The path of the configuration file in yaml format')
    args = arg_parser.parse_args()
    config = process_config(args.config)
    agent = STGANAgent(config)
    agent.run()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluestyle97/STGAN-pytorch/1136a5dff16f3c799bee467501b28cdd78517e70/models/__init__.py
--------------------------------------------------------------------------------
/models/stgan.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchsummary import summary


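# The selective transfer unit (STU) is implemented as a convolutional GRU
# cell: the state from the coarser scale is upsampled together with the
# attribute-difference vector, and learned reset/update gates control how the
# encoder feature at the current scale is fused with that state before it is
# passed on to the decoder.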
class ConvGRUCell(nn.Module):
    def __init__(self, n_attrs, in_dim, out_dim, kernel_size=3):
        super(ConvGRUCell, self).__init__()
        self.n_attrs = n_attrs
        self.upsample = nn.ConvTranspose2d(in_dim * 2 + n_attrs, out_dim, 4, 2, 1, bias=False)
        self.reset_gate = nn.Sequential(
            nn.Conv2d(in_dim + out_dim, out_dim, kernel_size, 1, (kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.Sigmoid()
        )
        self.update_gate = nn.Sequential(
            nn.Conv2d(in_dim + out_dim, out_dim, kernel_size, 1, (kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.Sigmoid()
        )
        self.hidden = nn.Sequential(
            nn.Conv2d(in_dim + out_dim, out_dim, kernel_size, 1, (kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.Tanh()
        )

    def forward(self, input, old_state, attr):
        n, _, h, w = old_state.size()
        attr = attr.view((n, self.n_attrs, 1, 1)).expand((n, self.n_attrs, h, w))
        state_hat = self.upsample(torch.cat([old_state, attr], 1))
        r = self.reset_gate(torch.cat([input, state_hat], dim=1))
        z = self.update_gate(torch.cat([input, state_hat], dim=1))
        new_state = r * state_hat
        hidden_info = self.hidden(torch.cat([input, new_state], dim=1))
        output = (1 - z) * state_hat + z * hidden_info
        return output, new_state


class Generator(nn.Module):
    def __init__(self, attr_dim, conv_dim=64, n_layers=5, shortcut_layers=2, stu_kernel_size=3, use_stu=True, one_more_conv=True):
        super(Generator, self).__init__()
        self.n_attrs = attr_dim
        self.n_layers = n_layers
        self.shortcut_layers = min(shortcut_layers, n_layers - 1)
        self.use_stu = use_stu

        self.encoder = nn.ModuleList()
        in_channels = 3
        for i in range(self.n_layers):
            self.encoder.append(nn.Sequential(
                nn.Conv2d(in_channels, conv_dim * 2 ** i, 4, 2, 1, bias=False),
                nn.BatchNorm2d(conv_dim * 2 ** i),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)
            ))
            in_channels = conv_dim * 2 ** i

        if use_stu:
            self.stu = nn.ModuleList()
            for i in reversed(range(self.n_layers - 1 - self.shortcut_layers, self.n_layers - 1)):
                self.stu.append(ConvGRUCell(self.n_attrs, conv_dim * 2 ** i, conv_dim * 2 ** i, stu_kernel_size))

        self.decoder = nn.ModuleList()
        for i in range(self.n_layers):
            if i < self.n_layers - 1:
                if i == 0:
                    self.decoder.append(nn.Sequential(
                        nn.ConvTranspose2d(conv_dim * 2 ** (self.n_layers - 1) + attr_dim,
                                           conv_dim * 2 ** (self.n_layers - 1), 4, 2, 1, bias=False),
                        nn.BatchNorm2d(in_channels),
                        nn.ReLU(inplace=True)
                    ))
                elif i <= self.shortcut_layers:  # not <
                    self.decoder.append(nn.Sequential(
                        nn.ConvTranspose2d(conv_dim * 3 * 2 ** (self.n_layers - 1 - i),
                                           conv_dim * 2 ** (self.n_layers - 1 - i), 4, 2, 1, bias=False),
                        nn.BatchNorm2d(conv_dim * 2 ** (self.n_layers - 1 - i)),
                        nn.ReLU(inplace=True)
                    ))
                else:
                    self.decoder.append(nn.Sequential(
                        nn.ConvTranspose2d(conv_dim * 2 ** (self.n_layers - i),
                                           conv_dim * 2 ** (self.n_layers - 1 - i), 4, 2, 1, bias=False),
                        nn.BatchNorm2d(conv_dim * 2 ** (self.n_layers - 1 - i)),
                        nn.ReLU(inplace=True)
                    ))
            else:
                in_dim = conv_dim * 3 if self.shortcut_layers == self.n_layers - 1 else conv_dim * 2
                if one_more_conv:
                    self.decoder.append(nn.Sequential(
                        nn.ConvTranspose2d(in_dim, conv_dim // 4, 4, 2, 1, bias=False),
                        nn.BatchNorm2d(conv_dim // 4),
                        nn.ReLU(inplace=True),
                        nn.ConvTranspose2d(conv_dim // 4, 3, 3, 1, 1, bias=False),
                        nn.Tanh()
                    ))
                else:
                    self.decoder.append(nn.Sequential(
                        nn.ConvTranspose2d(in_dim, 3, 4, 2, 1, bias=False),
                        nn.Tanh()
                    ))

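    # forward pass: encode the input, inject the attribute-difference vector
    # `a` at the bottleneck, then decode while concatenating the (optionally
    # STU-transformed) encoder features along the shortcut connections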
    def forward(self, x, a):
        # propagate encoder layers
        y = []
        x_ = x
        for layer in self.encoder:
            x_ = layer(x_)
            y.append(x_)

        out = y[-1]
        n, _, h, w = out.size()
        attr = a.view((n, self.n_attrs, 1, 1)).expand((n, self.n_attrs, h, w))
        out = self.decoder[0](torch.cat([out, attr], dim=1))
        stu_state = y[-1]

        # propagate shortcut layers
        for i in range(1, self.shortcut_layers + 1):
            if self.use_stu:
                stu_out, stu_state = self.stu[i - 1](y[-(i + 1)], stu_state, a)
                out = torch.cat([out, stu_out], dim=1)
                out = self.decoder[i](out)
            else:
                out = torch.cat([out, y[-(i + 1)]], dim=1)
                out = self.decoder[i](out)

        # propagate non-shortcut layers
        for i in range(self.shortcut_layers + 1, self.n_layers):
            out = self.decoder[i](out)

        return out


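# The discriminator shares one convolutional trunk between two heads: fc_adv
# outputs the WGAN critic score and fc_att predicts the attribute vector.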
class Discriminator(nn.Module):
    def __init__(self, image_size=128, attr_dim=10, conv_dim=64, fc_dim=1024, n_layers=5):
        super(Discriminator, self).__init__()
        layers = []
        in_channels = 3
        for i in range(n_layers):
            layers.append(nn.Sequential(
                nn.Conv2d(in_channels, conv_dim * 2 ** i, 4, 2, 1),
                nn.InstanceNorm2d(conv_dim * 2 ** i, affine=True, track_running_stats=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)
            ))
            in_channels = conv_dim * 2 ** i
        self.conv = nn.Sequential(*layers)
        feature_size = image_size // 2 ** n_layers
        self.fc_adv = nn.Sequential(
            nn.Linear(conv_dim * 2 ** (n_layers - 1) * feature_size ** 2, fc_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(fc_dim, 1)
        )
        self.fc_att = nn.Sequential(
            nn.Linear(conv_dim * 2 ** (n_layers - 1) * feature_size ** 2, fc_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(fc_dim, attr_dim),
        )

    def forward(self, x):
        y = self.conv(x)
        y = y.view(y.size()[0], -1)
        logit_adv = self.fc_adv(y)
        logit_att = self.fc_att(y)
        return logit_adv, logit_att


if __name__ == '__main__':
    gen = Generator(5, n_layers=6, shortcut_layers=5, use_stu=True, one_more_conv=True)
    summary(gen, [(3, 384, 384), (5,)], device='cpu')

    dis = Discriminator(image_size=384, attr_dim=5)
    summary(dis, (3, 384, 384), device='cpu')
--------------------------------------------------------------------------------
/sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluestyle97/STGAN-pytorch/1136a5dff16f3c799bee467501b28cdd78517e70/sample.jpg
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluestyle97/STGAN-pytorch/1136a5dff16f3c799bee467501b28cdd78517e70/utils/__init__.py
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
import os
import yaml
import logging
from logging import Formatter
from logging.handlers import RotatingFileHandler
from easydict import EasyDict

from utils.misc import create_dirs


def setup_logging(log_dir):
    log_file_format = '[%(levelname)s] - %(asctime)s - %(name)s - : %(message)s in %(pathname)s:%(lineno)d'
    log_console_format = '[%(levelname)s]: %(message)s'

    # main logger
    main_logger = logging.getLogger()
    main_logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(Formatter(log_console_format))

    exp_file_handler = RotatingFileHandler('{}exp_debug.log'.format(log_dir), maxBytes=10**6, backupCount=5)
    exp_file_handler.setLevel(logging.DEBUG)
    exp_file_handler.setFormatter(Formatter(log_file_format))

    exp_errors_file_handler = RotatingFileHandler('{}exp_error.log'.format(log_dir), maxBytes=10**6, backupCount=5)
    exp_errors_file_handler.setLevel(logging.WARNING)
    exp_errors_file_handler.setFormatter(Formatter(log_file_format))

    main_logger.addHandler(console_handler)
    main_logger.addHandler(exp_file_handler)
    main_logger.addHandler(exp_errors_file_handler)


def get_config_from_yaml(yaml_file):
    with open(yaml_file, 'r') as config_file:
        try:
            config_dict = yaml.safe_load(config_file)
            config = EasyDict(config_dict)
            return config
        except yaml.YAMLError:
            print('Invalid YAML file format. Please provide a valid yaml file.')
            exit(-1)


def process_config(yaml_file):
    config = get_config_from_yaml(yaml_file)

    print(' *************************************** ')
    print(' The experiment name is {} '.format(config.exp_name))
    print(' The experiment mode is {} '.format(config.mode))
    print(' *************************************** ')

    # create the important directories used by the experiment
    config.summary_dir = os.path.join('experiments', config.exp_name, 'summaries/')
    config.checkpoint_dir = os.path.join('experiments', config.exp_name, 'checkpoints/')
    config.sample_dir = os.path.join('experiments', config.exp_name, 'samples/')
    config.log_dir = os.path.join('experiments', config.exp_name, 'logs/')
    config.result_dir = os.path.join('experiments', config.exp_name, 'results/')
    create_dirs([config.summary_dir, config.checkpoint_dir, config.sample_dir, config.log_dir, config.result_dir])

    # setup logging in the project
    setup_logging(config.log_dir)

    logging.getLogger().info('Hi, this is root.')
    logging.getLogger().info('The configurations have been processed and the directories created.')
    logging.getLogger().info('The pipeline of the project will begin now.')

    return config
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
import os
import time
import logging


def timeit(f):
    """Decorator to time any function."""

    def timed(*args, **kwargs):
        start_time = time.time()
        result = f(*args, **kwargs)
        end_time = time.time()
        seconds = end_time - start_time
        logging.getLogger("Timer").info(" [-] %s : %2.5f sec, which is %2.5f min, which is %2.5f hour" % (f.__name__, seconds, seconds / 60, seconds / 3600))
        return result

    return timed


def print_cuda_statistics():
    logger = logging.getLogger("Cuda Statistics")
    import sys
    from subprocess import call
    import torch
    logger.info('__Python VERSION: {}'.format(sys.version))
    logger.info('__PyTorch VERSION: {}'.format(torch.__version__))
    logger.info('__CUDA VERSION')
    call(["cat", "/usr/local/cuda/version.txt"])
    logger.info('__CUDNN VERSION: {}'.format(torch.backends.cudnn.version()))
    logger.info('__Number CUDA Devices: {}'.format(torch.cuda.device_count()))
    logger.info('__Devices')
    call(["nvidia-smi", "--format=csv",
          "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"])
    logger.info('__Active CUDA Device: GPU {}'.format(torch.cuda.current_device()))
    logger.info('__Available devices {}'.format(torch.cuda.device_count()))
    logger.info('__Current cuda device {}'.format(torch.cuda.current_device()))


def create_dirs(dirs):
    try:
        for dir_ in dirs:
            if not os.path.exists(dir_):
                os.makedirs(dir_)
    except Exception as err:
        logging.getLogger("Dirs Creator").info("Creating directories error: {0}".format(err))
        exit(-1)
--------------------------------------------------------------------------------