├── .gitignore
├── LICENSE
├── README.md
├── configs
│   └── aging_gan.yaml
├── dataset.py
├── gan_module.py
├── infer.py
├── main.py
├── models.py
├── preprocessing
│   ├── __init__.py
│   ├── preprocess_cacd.py
│   └── preprocess_utk.py
├── pretrained_model
│   └── state_dict.pth
├── requirements.txt
└── timing.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
__pycache__
__MACOSX
lightning_logs
checkpoints
/test_images
/.venv
/mygraph.png

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Hasnain Raza

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fast-AgingGAN
This repository holds code for a face aging deep learning model. It is based on CycleGAN, where we translate young faces to old and vice versa.

# Samples
Top row is the input image, bottom row is the aged output from the GAN.
![Sample](https://user-images.githubusercontent.com/4294680/86517626-b4d54100-be2a-11ea-8cf1-7e4e088f96a3.png)
![Second-Sample](https://user-images.githubusercontent.com/4294680/86517663-f5cd5580-be2a-11ea-9e39-51ddf8be2084.png)

# Timing
The model runs at 66 fps on a GTX 1080 at an image size of 512x512. Because of the way it is trained, no face detection pipeline is needed: as long as the 512x512 input contains a face of roughly 256x256 pixels, the model works fine.

# Demo
To try out the pretrained model on your images, use the following command:
```bash
python infer.py --image_dir 'path/to/your/image/directory'
```
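If you would rather call the generator from your own code, the sketch below mirrors what `infer.py` does for a single image. This is a minimal sketch, not part of the repo: `face.jpg` is a placeholder path, and it assumes you run from the repo root so the pretrained weights resolve.
```python
# Minimal single-image inference sketch; mirrors the steps in infer.py.
import torch
from PIL import Image
from torchvision import transforms

from models import Generator

model = Generator(ngf=32, n_residual_blocks=9)
model.load_state_dict(torch.load('pretrained_model/state_dict.pth', map_location='cpu'))
model.eval()

preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

img = preprocess(Image.open('face.jpg').convert('RGB')).unsqueeze(0)  # placeholder file name
with torch.no_grad():
    aged = model(img)
# The generator outputs values in [-1, 1]; map back to [0, 1] for display or saving.
aged = (aged.squeeze(0).permute(1, 2, 0).numpy() + 1.0) / 2.0
```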
# Training
To train your own model on the CACD or UTKFace datasets, use the preprocessing scripts in the `preprocessing` directory to prepare the dataset.
If you are going to use CACD, use the following command:
```bash
python preprocessing/preprocess_cacd.py --image_dir '/path/to/cacd/images' --metadata '/path/to/the/cacd/metadata/file' --output_dir 'path/to/save/processed/data'
```
If using UTKFace, use the following:
```bash
python preprocessing/preprocess_utk.py --data_dir '/path/to/utk/images' --output_dir 'path/to/save/processed/data'
```

Once the dataset is processed, go into `configs/aging_gan.yaml` and modify the paths to point to the processed dataset you just created. Change any other hyperparameters if you wish, then run training with:
```bash
python main.py
```

# Tensorboard
While training is running, you can observe the losses and the GAN-generated images in TensorBoard; just point it to the `lightning_logs` directory like so:
```bash
tensorboard --logdir=lightning_logs --bind_all
```

--------------------------------------------------------------------------------
/configs/aging_gan.yaml:
--------------------------------------------------------------------------------
# Data configs
domainA_dir: '/home/hasnain/Datasets/processedCACDDomains/trainA'
domainB_dir: '/home/hasnain/Datasets/processedCACDDomains/trainB'

# Network configs
ngf: 32
ndf: 32
n_blocks: 9

# Loss configs
adv_weight: 2
cycle_weight: 10
identity_weight: 7

# Optimizer configs
lr: 0.0001
weight_decay: 0.0001

# Training configs
img_size: 256
batch_size: 3
num_workers: 4
epochs: 100
augment_rotation: 80
gpus:
  - 1

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

IMG_EXTENSIONS = ["png", "jpg"]


class ImagetoImageDataset(Dataset):
    def __init__(self, domainA_dir, domainB_dir, transforms=None):
        self.imagesA = [os.path.join(domainA_dir, x) for x in os.listdir(domainA_dir) if
                        x.lower().endswith(tuple(IMG_EXTENSIONS))]
        self.imagesB = [os.path.join(domainB_dir, x) for x in os.listdir(domainB_dir) if
                        x.lower().endswith(tuple(IMG_EXTENSIONS))]

        self.transforms = transforms

        self.lenA = len(self.imagesA)
        self.lenB = len(self.imagesB)

    def __len__(self):
        # Iterate over the larger domain; the smaller one is re-sampled in __getitem__.
        return max(self.lenA, self.lenB)

    def __getitem__(self, idx):
        idx_a = idx_b = idx
        # If idx runs past the end of the smaller domain, draw a random image from it.
        if idx_a >= self.lenA:
            idx_a = np.random.randint(self.lenA)
        if idx_b >= self.lenB:
            idx_b = np.random.randint(self.lenB)

        imageA = np.array(Image.open(self.imagesA[idx_a]).convert("RGB"))
        imageB = np.array(Image.open(self.imagesB[idx_b]).convert("RGB"))

        if self.transforms is not None:
            imageA = self.transforms(imageA)
            imageB = self.transforms(imageB)

        return imageA, imageB
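A note on `dataset.py`: the two domains are unaligned and may differ in size, so `__len__` is the size of the larger domain and any index past the end of the smaller one is re-drawn uniformly at random. A minimal usage sketch, assuming `trainA`/`trainB` are placeholder directories of face images:
```python
from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import ImagetoImageDataset

tfm = transforms.Compose([
    transforms.ToPILImage(),        # the dataset yields numpy arrays, as in gan_module.py
    transforms.Resize((256, 256)),  # equal sizes so default collation can batch
    transforms.ToTensor(),
])
ds = ImagetoImageDataset('trainA', 'trainB', transforms=tfm)  # placeholder dirs
loader = DataLoader(ds, batch_size=4, shuffle=True)
imgA, imgB = next(iter(loader))
print(imgA.shape, imgB.shape)  # torch.Size([4, 3, 256, 256]) each
```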
--------------------------------------------------------------------------------
/gan_module.py:
--------------------------------------------------------------------------------
import itertools

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid

from dataset import ImagetoImageDataset
from models import Generator, Discriminator


class AgingGAN(pl.LightningModule):

    def __init__(self, hparams):
        super(AgingGAN, self).__init__()
        self.save_hyperparameters(hparams)
        self.genA2B = Generator(hparams['ngf'], n_residual_blocks=hparams['n_blocks'])
        self.genB2A = Generator(hparams['ngf'], n_residual_blocks=hparams['n_blocks'])
        self.disGA = Discriminator(hparams['ndf'])
        self.disGB = Discriminator(hparams['ndf'])

        # cache for generated images, reused in the discriminator step
        self.generated_A = None
        self.generated_B = None
        self.real_A = None
        self.real_B = None

    def forward(self, x):
        return self.genA2B(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        real_A, real_B = batch

        if optimizer_idx == 0:
            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = self.genA2B(real_B)
            loss_identity_B = F.l1_loss(same_B, real_B) * self.hparams['identity_weight']
            # G_B2A(A) should equal A if real A is fed
            same_A = self.genB2A(real_A)
            loss_identity_A = F.l1_loss(same_A, real_A) * self.hparams['identity_weight']

            # GAN loss (least squares): push discriminator scores for fakes towards 1
            fake_B = self.genA2B(real_A)
            pred_fake = self.disGB(fake_B)
            loss_GAN_A2B = F.mse_loss(pred_fake, torch.ones_like(pred_fake)) * self.hparams['adv_weight']

            fake_A = self.genB2A(real_B)
            pred_fake = self.disGA(fake_A)
            loss_GAN_B2A = F.mse_loss(pred_fake, torch.ones_like(pred_fake)) * self.hparams['adv_weight']

            # Cycle loss
            recovered_A = self.genB2A(fake_B)
            loss_cycle_ABA = F.l1_loss(recovered_A, real_A) * self.hparams['cycle_weight']

            recovered_B = self.genA2B(fake_A)
            loss_cycle_BAB = F.l1_loss(recovered_B, real_B) * self.hparams['cycle_weight']

            # Total loss
            g_loss = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB

            output = {
                'loss': g_loss,
                'log': {'Loss/Generator': g_loss}
            }
            self.log('Loss/Generator', g_loss)

            self.generated_B = fake_B
            self.generated_A = fake_A

            self.real_B = real_B
            self.real_A = real_A

            # Log image grids to tensorboard every 500 batches
            if batch_idx % 500 == 0:
                self.genA2B.eval()
                self.genB2A.eval()
                # Regenerate in eval mode so batch norm uses its running statistics,
                # and log these eval-mode samples (the cached ones were train-mode).
                fake_A = self.genB2A(real_B)
                fake_B = self.genA2B(real_A)
                self.logger.experiment.add_image('Real/A', make_grid(real_A, normalize=True, scale_each=True),
                                                 self.current_epoch)
                self.logger.experiment.add_image('Real/B', make_grid(real_B, normalize=True, scale_each=True),
                                                 self.current_epoch)
                self.logger.experiment.add_image('Generated/A',
                                                 make_grid(fake_A, normalize=True, scale_each=True),
                                                 self.current_epoch)
                self.logger.experiment.add_image('Generated/B',
                                                 make_grid(fake_B, normalize=True, scale_each=True),
                                                 self.current_epoch)
                self.genA2B.train()
                self.genB2A.train()
            return output

        if optimizer_idx == 1:
            # Real loss
            pred_real = self.disGA(real_A)
            loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real))

            # Fake loss
            fake_A = self.generated_A
            pred_fake = self.disGA(fake_A.detach())
            loss_D_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake))

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5

            # Real loss
            pred_real = self.disGB(real_B)
            loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real))

            # Fake loss
            fake_B = self.generated_B
            pred_fake = self.disGB(fake_B.detach())
            loss_D_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake))

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            d_loss = loss_D_A + loss_D_B
            output = {
                'loss': d_loss,
                'log': {'Loss/Discriminator': d_loss}
            }
            self.log('Loss/Discriminator', d_loss)

            return output

    def configure_optimizers(self):
        g_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()),
                                   lr=self.hparams['lr'], betas=(0.5, 0.999),
                                   weight_decay=self.hparams['weight_decay'])
        d_optim = torch.optim.Adam(itertools.chain(self.disGA.parameters(),
                                                   self.disGB.parameters()),
                                   lr=self.hparams['lr'],
                                   betas=(0.5, 0.999),
                                   weight_decay=self.hparams['weight_decay'])
        return [g_optim, d_optim], []

    def train_dataloader(self):
        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.hparams['img_size'] + 50, self.hparams['img_size'] + 50)),
            transforms.RandomCrop(self.hparams['img_size']),
            # transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
            # transforms.RandomPerspective(p=0.5),
            transforms.RandomRotation(degrees=(0, int(self.hparams['augment_rotation']))),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        dataset = ImagetoImageDataset(self.hparams['domainA_dir'], self.hparams['domainB_dir'], train_transform)
        return DataLoader(dataset,
                          batch_size=self.hparams['batch_size'],
                          num_workers=self.hparams['num_workers'],
                          shuffle=True)
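A note on `gan_module.py`: the adversarial term is a least-squares GAN loss (MSE against all-ones/all-zeros targets), combined with cycle-consistency and identity L1 terms weighted per the config. Also note that `infer.py` loads a bare generator `state_dict`, not a Lightning checkpoint, so after training you need to export the A→B generator yourself. A hypothetical export sketch, assuming `load_from_checkpoint` can rebuild the module from the saved hyperparameters; the checkpoint path is a placeholder from a training run:
```python
import torch

from gan_module import AgingGAN

# Placeholder path; adjust to an actual checkpoint under lightning_logs/.
ckpt_path = 'lightning_logs/version_0/checkpoints/epoch=99.ckpt'
model = AgingGAN.load_from_checkpoint(ckpt_path)
# Save only the A->B (young->old) generator in the format infer.py loads.
torch.save(model.genA2B.state_dict(), 'pretrained_model/state_dict.pth')
```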
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
import os
import random
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from models import Generator

parser = ArgumentParser()
parser.add_argument(
    '--image_dir', default='/Downloads/CACD_VS/', help='The image directory')


@torch.no_grad()
def main():
    args = parser.parse_args()
    image_paths = [os.path.join(args.image_dir, x) for x in os.listdir(args.image_dir) if
                   x.endswith('.png') or x.endswith('.jpg')]
    model = Generator(ngf=32, n_residual_blocks=9)
    ckpt = torch.load('pretrained_model/state_dict.pth', map_location='cpu')
    model.load_state_dict(ckpt)
    model.eval()
    trans = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    # Plot at most six images; taking six unconditionally, as before, would index
    # past the end of image_paths when the directory holds fewer than six.
    nr_images = min(len(image_paths), 6)
    fig, ax = plt.subplots(2, nr_images, figsize=(20, 10), squeeze=False)
    random.shuffle(image_paths)
    for i in range(nr_images):
        img = Image.open(image_paths[i]).convert('RGB')
        img = trans(img).unsqueeze(0)
        aged_face = model(img)
        # Outputs are in [-1, 1]; map them back to [0, 1] for display.
        aged_face = (aged_face.squeeze().permute(1, 2, 0).numpy() + 1.0) / 2.0
        ax[0, i].imshow((img.squeeze().permute(1, 2, 0).numpy() + 1.0) / 2.0)
        ax[1, i].imshow(aged_face)
    # plt.show()
    plt.savefig("mygraph.png")


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser

import yaml
from pytorch_lightning import Trainer

from gan_module import AgingGAN

parser = ArgumentParser()
parser.add_argument('--config', default='configs/aging_gan.yaml', help='Config to use for training')


def main():
    args = parser.parse_args()
    with open(args.config) as file:
        config = yaml.load(file, Loader=yaml.FullLoader)
    print(config)
    model = AgingGAN(config)
    trainer = Trainer(max_epochs=config['epochs'], gpus=config['gpus'], auto_scale_batch_size='binsearch')
    trainer.fit(model)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      nn.BatchNorm2d(in_features),
                      nn.ReLU(),
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      nn.BatchNorm2d(in_features)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, ngf, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(3, ngf, 7),
                 nn.BatchNorm2d(ngf),
                 nn.ReLU()]

        # Downsampling (two stride-2 convs: ngf -> 2*ngf -> 4*ngf)
        in_features = ngf
        out_features = in_features * 2
        for _ in range(2):
            model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                      nn.BatchNorm2d(out_features),
                      nn.ReLU()]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling (two stride-2 transposed convs back down to ngf)
        out_features = in_features // 2
        for _ in range(2):
            model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                      nn.BatchNorm2d(out_features),
                      nn.ReLU()]
            in_features = out_features
            out_features = in_features // 2

        # Output layer (in_features equals ngf again at this point)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, 3, 7),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self, ndf):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [nn.Conv2d(3, ndf, 4, stride=2, padding=1),
                 nn.LeakyReLU(0.2, inplace=True)]

        model += [nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1),
                  nn.BatchNorm2d(ndf * 2),
                  nn.LeakyReLU(0.2, inplace=True)]

        model += [nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1),
                  nn.InstanceNorm2d(ndf * 4),
                  nn.LeakyReLU(0.2, inplace=True)]

        model += [nn.Conv2d(ndf * 4, ndf * 8, 4, padding=1),
                  nn.InstanceNorm2d(ndf * 8),
                  nn.LeakyReLU(0.2, inplace=True)]

        # FCN classification layer
        model += [nn.Conv2d(ndf * 8, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x = self.model(x)
        # Average pooling and flatten to one score per image
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
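A quick shape sanity check for `models.py` (a sketch, not part of the repo): the generator is fully convolutional and preserves the spatial size of its input, while the discriminator average-pools its final feature map down to a single real/fake score per image.
```python
import torch

from models import Generator, Discriminator

g = Generator(ngf=32, n_residual_blocks=9)
d = Discriminator(ndf=32)

x = torch.rand(1, 3, 256, 256)
print(g(x).shape)  # torch.Size([1, 3, 256, 256]) -- same spatial size as the input
print(d(x).shape)  # torch.Size([1, 1]) -- one real/fake score per image
```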
--------------------------------------------------------------------------------
/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasnainRaz/Fast-AgingGAN/f452c3afe63ccc65ac86b52f2338f8173f023925/preprocessing/__init__.py

--------------------------------------------------------------------------------
/preprocessing/preprocess_cacd.py:
--------------------------------------------------------------------------------
import os
import shutil
from argparse import ArgumentParser

from scipy.io import loadmat

parser = ArgumentParser()
parser.add_argument('--image_dir',
                    default='/Users/hasnainraza/Downloads/CACD2000/',
                    help='The CACD2000 images dir')
parser.add_argument('--metadata',
                    default='/Users/hasnainraza/Downloads/celebrity2000_meta.mat',
                    help='The metadata for the CACD2000')
parser.add_argument('--output_dir',
                    default='/Users/hasnainraza/Downloads/CACDDomains',
                    help='The directory to write processed images')


def main():
    args = parser.parse_args()
    metadata = loadmat(args.metadata)['celebrityImageData'][0][0]
    # First field of the struct holds the ages, last field the image file names.
    ages = [x[0] for x in metadata[0]]
    names = [x[0][0] for x in metadata[-1]]

    ages_to_keep_a = [x for x in range(18, 30)]
    ages_to_keep_b = [x for x in range(55, 100)]

    domainA, domainB = [], []
    for age, name in zip(ages, names):
        if age in ages_to_keep_a:
            domainA.append(name)
        if age in ages_to_keep_b:
            domainB.append(name)

    # Trim to equal sizes so the two domains stay balanced.
    N = min(len(domainA), len(domainB))
    domainA = domainA[:N]
    domainB = domainB[:N]
    print(f'Images in A {len(domainA)} and B {len(domainB)}')

    domainA_dir = os.path.join(args.output_dir, 'trainA')
    domainB_dir = os.path.join(args.output_dir, 'trainB')

    os.makedirs(domainA_dir, exist_ok=True)
    os.makedirs(domainB_dir, exist_ok=True)

    for imageA, imageB in zip(domainA, domainB):
        shutil.copy(os.path.join(args.image_dir, imageA), os.path.join(domainA_dir, imageA))
        shutil.copy(os.path.join(args.image_dir, imageB), os.path.join(domainB_dir, imageB))


if __name__ == '__main__':
    main()
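The positional indexing into the `.mat` struct above (`metadata[0]` for ages, `metadata[-1]` for file names) silently breaks if the struct layout differs. A short inspection sketch to confirm the field order before trusting it; the metadata path is a placeholder:
```python
from scipy.io import loadmat

# Placeholder path; point this at the real celebrity2000_meta.mat.
meta = loadmat('celebrity2000_meta.mat')['celebrityImageData'][0][0]
print(meta.dtype.names)  # confirm which field holds the ages and which the names
```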
--------------------------------------------------------------------------------
/preprocessing/preprocess_utk.py:
--------------------------------------------------------------------------------
import os
import shutil
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--data_dir',
                    default='/Users/hasnainraza/Downloads/UTKFace/',
                    help='The UTKFace aligned images dir')
parser.add_argument('--output_dir',
                    default='/Users/hasnainraza/Downloads/FacesProcessed',
                    help='The directory to write processed images')


def main():
    args = parser.parse_args()
    image_names = [x for x in os.listdir(args.data_dir) if x.endswith('.jpg')]
    print(f"Total images found: {len(image_names)}")

    # UTKFace encodes the age as the first underscore-separated field of the file name.
    ages = [int(x.split('_')[0]) for x in image_names]

    ages_to_keep_a = [x for x in range(18, 29)]
    ages_to_keep_b = [x for x in range(40, 120)]
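(Aside, not in the script: UTKFace names follow `[age]_[gender]_[race]_[date&time].jpg`, and a name that deviates from that pattern makes `int(x.split('_')[0])` raise and abort the run. A defensive variant of the parse, as a sketch that could replace the `ages = ...` line above:)
```python
def parse_age(filename):
    """Age from a UTKFace name like '25_0_1_20170116.jpg'; None if malformed."""
    try:
        return int(filename.split('_')[0])
    except ValueError:
        return None


# Keep only the images whose names parse, instead of crashing on the rest.
parsed = [(x, parse_age(x)) for x in image_names]
image_names = [x for x, a in parsed if a is not None]
ages = [a for x, a in parsed if a is not None]
```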
    domainA, domainB = [], []
    for image_name, age in zip(image_names, ages):
        if age in ages_to_keep_a:
            domainA.append(image_name)
        elif age in ages_to_keep_b:
            domainB.append(image_name)

    # Trim to equal sizes so the two domains stay balanced.
    N = min(len(domainA), len(domainB))
    domainA = domainA[:N]
    domainB = domainB[:N]

    print(f"Images in A: {len(domainA)} and B: {len(domainB)}")

    domainA_dir = os.path.join(args.output_dir, 'trainA')
    domainB_dir = os.path.join(args.output_dir, 'trainB')

    os.makedirs(domainA_dir, exist_ok=True)
    os.makedirs(domainB_dir, exist_ok=True)

    for imageA, imageB in zip(domainA, domainB):
        shutil.copy(os.path.join(args.data_dir, imageA), os.path.join(domainA_dir, imageA))
        shutil.copy(os.path.join(args.data_dir, imageB), os.path.join(domainB_dir, imageB))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/pretrained_model/state_dict.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasnainRaz/Fast-AgingGAN/f452c3afe63ccc65ac86b52f2338f8173f023925/pretrained_model/state_dict.pth

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HasnainRaz/Fast-AgingGAN/f452c3afe63ccc65ac86b52f2338f8173f023925/requirements.txt

--------------------------------------------------------------------------------
/timing.py:
--------------------------------------------------------------------------------
from timeit import default_timer as timer

import torch


@torch.no_grad()
def time_model(model, input_size):
    model.eval()
    count, duration = 0, 0
    for i in range(50):
        start = timer()
        _ = model(torch.rand(size=input_size))
        # Skip the first 10 iterations as warmup.
        if i < 10:
            continue
        duration += timer() - start
        count += 1

    return duration / count


def main():
    from models import Generator
    model = Generator(32, 9)
    duration = time_model(model, [1, 3, 512, 512])
    print("Average seconds per forward pass (excluding warmup):", duration)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
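Note that the README's 66 fps figure is a GPU number, while `timing.py` above runs on CPU by default. A GPU timing sketch, assuming a CUDA device is available: CUDA kernels launch asynchronously, so synchronize around the timed region or the numbers will mostly measure kernel-launch overhead.
```python
from timeit import default_timer as timer

import torch

from models import Generator

model = Generator(32, 9).cuda().eval()
x = torch.rand(1, 3, 512, 512, device='cuda')

with torch.no_grad():
    for _ in range(10):  # warmup iterations, not timed
        model(x)
    torch.cuda.synchronize()
    start = timer()
    for _ in range(40):
        model(x)
    torch.cuda.synchronize()  # wait for all queued kernels before reading the clock

print('avg seconds/image:', (timer() - start) / 40)
```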