├── README.md
├── examples
│   ├── dosgan_identity.jpg
│   ├── dosganc_identity.jpg
│   ├── facescrub_intra.png
│   └── season.jpg
├── main_dosgan.py
├── model.py
├── solver_dosgan.py
└── split2train_val.py


/README.md:
--------------------------------------------------------------------------------
1 | # DosGAN-PyTorch
2 | PyTorch implementation of [Exploring Explicit Domain Supervision for Latent Space Disentanglement in Unpaired Image-to-Image Translation](https://arxiv.org/abs/1902.03782).
3 |
4 | ![Intra-dataset translation on FaceScrub](examples/facescrub_intra.png)
5 |
6 |
7 | # Dependency:
8 | Python 2.7
9 |
10 | PyTorch 0.4.0
11 |
12 | # Usage:
13 | ### Multiple identity translation
14 | 1. Download the FaceScrub dataset following http://www.vintage.winklerbros.net/facescrub.html and save it to `root_dir`.
15 |
16 | 2. Split it into a training set `train_dir` and a testing set `val_dir`:
17 |
18 |    `$ python split2train_val.py root_dir train_dir val_dir`
19 |
20 | 3. Train a classifier for domain feature extraction and save it to `dosgan_cls`:
21 |
22 |    `$ python main_dosgan.py --mode cls --model_dir dosgan_cls --train_data_path train_dir --test_data_path val_dir`
23 |
24 | 4. Train DosGAN:
25 |
26 |    `$ python main_dosgan.py --mode train --model_dir dosgan --cls_save_dir dosgan_cls/models --train_data_path train_dir --test_data_path val_dir`
27 |
28 | 5. Train DosGAN-c:
29 |
30 |    `$ python main_dosgan.py --mode train --model_dir dosgan_c --cls_save_dir dosgan_cls/models --non_conditional false --train_data_path train_dir --test_data_path val_dir`
31 |
32 | 6. Test DosGAN:
33 |
34 |    `$ python main_dosgan.py --mode test --model_dir dosgan --cls_save_dir dosgan_cls/models --train_data_path train_dir --test_data_path val_dir`
35 |
36 | 7. Test DosGAN-c:
37 |
38 |    `$ python main_dosgan.py --mode test --model_dir dosgan_c --cls_save_dir dosgan_cls/models --non_conditional false --train_data_path train_dir --test_data_path val_dir`
39 | ### Other multiple domain translation
40 | 1. For other datasets, arrange the training and test sets as in the layout below (a quick way to verify the layout is shown after these steps):
41 |
42 |        data
43 |        ├── YOUR_DATASET_train_dir
44 |            ├── domain1
45 |            |   ├── 1.jpg
46 |            |   ├── 2.jpg
47 |            |   └── ...
48 |            ├── domain2
49 |            |   ├── 1.jpg
50 |            |   ├── 2.jpg
51 |            |   └── ...
52 |            ├── domain3
53 |            |   ├── 1.jpg
54 |            |   ├── 2.jpg
55 |            |   └── ...
56 |            ...
57 |
58 |        data
59 |        ├── YOUR_DATASET_val_dir
60 |            ├── domain1
61 |            |   ├── 1.jpg
62 |            |   ├── 2.jpg
63 |            |   └── ...
64 |            ├── domain2
65 |            |   ├── 1.jpg
66 |            |   ├── 2.jpg
67 |            |   └── ...
68 |            ├── domain3
69 |            |   ├── 1.jpg
70 |            |   ├── 2.jpg
71 |            |   └── ...
72 |            ...
73 |
74 | 2. Taking multiple season translation as an example ([season dataset](https://github.com/AAnoosheh/ComboGAN)), train a classifier for season domain feature extraction and save it to `dosgan_season_cls`:
75 |
76 |    `$ python main_dosgan.py --mode cls --model_dir dosgan_season_cls --ft_num 64 --c_dim 4 --image_size 256 --train_data_path season_train_dir --test_data_path season_val_dir`
77 |
78 | 3. Train DosGAN for multiple season translation:
79 |
80 |    `$ python main_dosgan.py --mode train --model_dir dosgan_season --cls_save_dir dosgan_season_cls/models --ft_num 64 --c_dim 4 --image_size 256 --lambda_fs 0.15 --num_iters 300000 --train_data_path season_train_dir --test_data_path season_val_dir`
81 |
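Before training on a new dataset, it can help to confirm that the layout above is read the way `main_dosgan.py` expects: it loads data with `torchvision.datasets.ImageFolder`, which treats each sub-folder as one domain. A minimal, illustrative check (not part of the repository's scripts):

```python
# Quick layout check: ImageFolder maps each sub-folder to one domain label.
from torchvision import datasets

train_set = datasets.ImageFolder('data/YOUR_DATASET_train_dir')
print('domains found (use as --c_dim):', len(train_set.classes))
print('training images:', len(train_set))
```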
82 |
83 | # Results:
84 | ### 1. Multiple identity translation
85 |
86 | **# Results of DosGAN**:
87 |
88 | ![DosGAN identity translation results](examples/dosgan_identity.jpg)
89 |
90 | **# Results of DosGAN-c**:
91 |
92 | ![DosGAN-c identity translation results](examples/dosganc_identity.jpg)
93 |
94 | ### 2. Multiple season translation:
95 |
96 | ![Multiple season translation results](examples/season.jpg)
97 |
--------------------------------------------------------------------------------
/examples/dosgan_identity.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/DosGAN-PyTorch/2a2b443089f8de7f15ba48fec4f5cd2121214daa/examples/dosgan_identity.jpg
--------------------------------------------------------------------------------
/examples/dosganc_identity.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/DosGAN-PyTorch/2a2b443089f8de7f15ba48fec4f5cd2121214daa/examples/dosganc_identity.jpg
--------------------------------------------------------------------------------
/examples/facescrub_intra.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/DosGAN-PyTorch/2a2b443089f8de7f15ba48fec4f5cd2121214daa/examples/facescrub_intra.png
--------------------------------------------------------------------------------
/examples/season.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/DosGAN-PyTorch/2a2b443089f8de7f15ba48fec4f5cd2121214daa/examples/season.jpg
--------------------------------------------------------------------------------
/main_dosgan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from solver_dosgan import Solver
4 | from torch.backends import cudnn
5 | from torchvision import transforms, datasets
6 | import torch.utils.data as data
7 | def str2bool(v):
8 |     return v.lower() == 'true'
9 | def train_trans(config):
10 |     return transforms.Compose([
11 |         transforms.RandomHorizontalFlip(),
12 |         transforms.Resize((config.image_size, config.image_size)),
13 |         transforms.ToTensor(),
14 |         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
15 |     ])
16 | def test_trans(config):
17 |     return transforms.Compose([
18 |         transforms.Resize((config.image_size, config.image_size)),
19 |         transforms.ToTensor(),
20 |         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
21 |     ])
22 | def main(config):
23 |     # For fast training.
24 |     cudnn.benchmark = True
25 |
26 |     # Create directories if they do not exist.
27 |     config.log_dir = os.path.join(config.model_dir, 'logs')
28 |     config.model_save_dir = os.path.join(config.model_dir, 'models')
29 |     config.sample_dir = os.path.join(config.model_dir, 'samples')
30 |     config.result_dir = os.path.join(config.model_dir, 'results')
31 |
32 |     if not os.path.exists(config.log_dir):
33 |         os.makedirs(config.log_dir)
34 |     if not os.path.exists(config.model_save_dir):
35 |         os.makedirs(config.model_save_dir)
36 |     if not os.path.exists(config.sample_dir):
37 |         os.makedirs(config.sample_dir)
38 |     if not os.path.exists(config.result_dir):
39 |         os.makedirs(config.result_dir)
40 |
41 |     # Data loader.
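    # ImageFolder assigns integer labels from the sorted sub-folder names, so
    # train_data_path and test_data_path must contain the same set of domain
    # folders for the labels (and --c_dim) to line up.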
42 |
43 |     train_dataset = datasets.ImageFolder(config.train_data_path, train_trans(config))
44 |
45 |     test_dataset = datasets.ImageFolder(config.test_data_path, test_trans(config))
46 |     data_loader_train = data.DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, drop_last=True)
47 |     print('train dataset loaded')
48 |     data_loader_test = data.DataLoader(dataset=test_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, drop_last=True)
49 |     print('test dataset loaded')
50 |
51 |
52 |
53 |
54 |
55 |     # Solver for training and testing DosGAN.
56 |     solver = Solver(data_loader_train, data_loader_test, config)
57 |
58 |     if config.mode == 'train':
59 |         if config.non_conditional:
60 |             solver.train()
61 |         else:
62 |             solver.train_conditional()
63 |     elif config.mode == 'test':
64 |         solver.test()
65 |     elif config.mode == 'cls':
66 |         solver.cls()
67 |
68 |
69 |
70 | if __name__ == '__main__':
71 |     parser = argparse.ArgumentParser()
72 |
73 |     # Model configuration.
74 |     parser.add_argument('--c_dim', type=int, default=531, help='number of domains')
75 |     parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
76 |     parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
77 |     parser.add_argument('--n_blocks', type=int, default=0, help='number of residual blocks in C')
78 |     parser.add_argument('--image_size', type=int, default=128, help='image resolution')
79 |     parser.add_argument('--lambda_rec', type=float, default=10, help='weight for self-reconstruction loss')
80 |     parser.add_argument('--lambda_rec2', type=float, default=10, help='weight for cycle (cross-domain) reconstruction loss')
81 |     parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
82 |     parser.add_argument('--lambda_fs', type=float, default=5, help='weight for domain feature reconstruction loss')
83 |     parser.add_argument('--ft_num', type=int, default=1024, help='dimension of the domain-specific feature')
84 |
85 |     # Training configuration.
86 |     parser.add_argument('--batch_size', type=int, default=6, help='mini-batch size')
87 |     parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
88 |     parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
89 |     parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for encoder and decoder')
90 |     parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
91 |     parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per generator update')
92 |     parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
93 |     parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
94 |     parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
95 |     # Test configuration.
96 |     parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
97 |     parser.add_argument('--non_conditional', type=str2bool, default=True, help='true: DosGAN (domain feature centroids); false: DosGAN-c (per-image domain features)')
98 |
99 |     # Miscellaneous.
100 |     parser.add_argument('--num_workers', type=int, default=1)
101 |     parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'cls'])
102 |
103 |     # Directories.
104 | parser.add_argument('--train_data_path', type=str, default='data/facescrub_train/') 105 | parser.add_argument('--test_data_path', type=str, default='data/facescrub_test/') 106 | parser.add_argument('--model_dir', type=str, default='dosgan') 107 | parser.add_argument('--cls_save_dir', type=str, default='dosgan_cls/models') 108 | 109 | 110 | # Step size. 111 | parser.add_argument('--log_step', type=int, default=1000) 112 | parser.add_argument('--sample_step', type=int, default=2000) 113 | parser.add_argument('--model_save_step', type=int, default=20000) 114 | parser.add_argument('--lr_update_step', type=int, default=1000) 115 | 116 | config = parser.parse_args() 117 | print(config) 118 | main(config) 119 | 120 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision.models as models 6 | #from cnn_finetune import make_model 7 | class ResidualBlock(nn.Module): 8 | """Residual Block with instance normalization.""" 9 | def __init__(self, dim_in, dim_out): 10 | super(ResidualBlock, self).__init__() 11 | self.main = nn.Sequential( 12 | nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False), 13 | nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False), 16 | nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True)) 17 | 18 | def forward(self, x): 19 | return x + self.main(x) 20 | 21 | 22 | class ResnetEncoder(nn.Module): 23 | def __init__(self, input_nc=3, output_nc=3, n_blocks=3): 24 | assert(n_blocks >= 0) 25 | super(ResnetEncoder, self).__init__() 26 | self.input_nc = input_nc 27 | self.output_nc = output_nc 28 | ngf = 64 29 | padding_type ='reflect' 30 | norm_layer = nn.InstanceNorm2d 31 | use_bias = False 32 | 33 | model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3, 34 | bias=use_bias), 35 | norm_layer(ngf, affine=True, track_running_stats=True), 36 | nn.ReLU(True)] 37 | 38 | n_downsampling = 2 39 | for i in range(n_downsampling): 40 | mult = 2**i 41 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=4, 42 | stride=2, padding=1, bias=use_bias), 43 | norm_layer(ngf * mult * 2, affine=True, track_running_stats=True), 44 | nn.ReLU(True)] 45 | mult = 2**n_downsampling 46 | 47 | for i in range(n_blocks): 48 | model += [ResidualBlock(dim_in=ngf * mult, dim_out=ngf * mult)] 49 | 50 | self.model = nn.Sequential(*model) 51 | 52 | def forward(self, input): 53 | return self.model(input) 54 | class ResnetDecoder(nn.Module): 55 | def __init__(self, input_nc=3, output_nc=3, n_blocks=3, ft_num=16, image_size=128): 56 | assert(n_blocks >= 0) 57 | super(ResnetDecoder, self).__init__() 58 | self.input_nc = input_nc 59 | self.output_nc = output_nc 60 | ngf = 64 61 | ngf_o = ngf*2 62 | padding_type ='reflect' 63 | norm_layer = nn.InstanceNorm2d 64 | use_bias = False 65 | 66 | model = [ ] 67 | n_downsampling = 2 68 | mult = 2**n_downsampling 69 | model_2 = [ ] 70 | model_2 += [nn.Linear(ft_num, ngf * mult * int(image_size / np.power(2, n_downsampling)) * int(image_size / np.power(2, n_downsampling)))] 71 | model_2 += [nn.ReLU(True)] 72 | 73 | model += [nn.Conv2d(ngf * mult, ngf * mult, kernel_size=3, 74 | stride=1, padding=1, bias=use_bias), 75 | norm_layer(ngf * mult, affine=True, track_running_stats=True 
),
76 |                   nn.ReLU(True)]
77 |         for i in range(n_blocks):
78 |             model += [ResidualBlock(dim_in=ngf * mult, dim_out=ngf * mult)]
79 |
80 |         for i in range(n_downsampling):
81 |             mult = 2**(n_downsampling - i)
82 |             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
83 |                                          kernel_size=3, stride=2,
84 |                                          padding=1, output_padding=1,
85 |                                          bias=use_bias),
86 |                       norm_layer(int(ngf * mult / 2), affine=True, track_running_stats=True),
87 |                       nn.ReLU(True)]
88 |         model += [nn.Conv2d(ngf, 3, kernel_size=7, stride=1, padding=3, bias=False)]
89 |
90 |         model += [nn.Tanh()]
91 |         self.model = nn.Sequential(*model)
92 |         self.model_2 = nn.Sequential(*model_2)
93 |
94 |     def forward(self, input1, input2):
95 |         out_2 = self.model_2(input2)
96 |         out_2 = out_2.view(input1.size(0), input1.size(1), input1.size(2), input1.size(3))
97 |
98 |         return self.model(input1 + out_2)  # inject the domain feature by adding its mapped tensor to the content feature
99 |
100 |
101 | class Classifier(nn.Module):
102 |     """Domain classifier; outputs the domain-specific feature and domain logits."""
103 |     def __init__(self, image_size=128, conv_dim=64, c_dim=2, repeat_num=6, ft_num=16, n_blocks=3):
104 |         super(Classifier, self).__init__()
105 |         layers = []
106 |         layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
107 |         layers.append(nn.LeakyReLU(0.01))
108 |
109 |         curr_dim = conv_dim
110 |         for i in range(1, repeat_num):
111 |             layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
112 |             layers.append(nn.LeakyReLU(0.01))
113 |             curr_dim = curr_dim * 2
114 |         for i in range(n_blocks):
115 |             layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
116 |         kernel_size = int(image_size / np.power(2, repeat_num))
117 |         self.main = nn.Sequential(*layers)
118 |         self.conv1 = nn.Sequential(*[nn.Conv2d(curr_dim, ft_num, kernel_size=kernel_size), nn.LeakyReLU(0.01)])
119 |
120 |         self.conv2 = nn.Conv2d(ft_num, c_dim, kernel_size=1, bias=False)
121 |
122 |     def forward(self, x):
123 |         h = self.main(x)
124 |         out_src = self.conv1(h)        # domain-specific feature
125 |         out_cls = self.conv2(out_src)  # domain logits
126 |         return out_src.view(out_src.size(0), out_src.size(1)), out_cls.view(out_cls.size(0), out_cls.size(1))
127 |
128 | class Discriminator(nn.Module):
129 |     """PatchGAN discriminator that also regresses the domain-specific feature."""
130 |     def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6, ft_num=16):
131 |         super(Discriminator, self).__init__()
132 |         layers = []
133 |         layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
134 |         layers.append(nn.LeakyReLU(0.01))
135 |
136 |         curr_dim = conv_dim
137 |         for i in range(1, repeat_num):
138 |             layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
139 |             layers.append(nn.LeakyReLU(0.01))
140 |             curr_dim = curr_dim * 2
141 |
142 |         kernel_size = int(image_size / np.power(2, repeat_num))
143 |         self.main = nn.Sequential(*layers)
144 |         self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
145 |         self.conv2 = nn.Conv2d(curr_dim, ft_num, kernel_size=kernel_size, bias=False)
146 |
147 |     def forward(self, x):
148 |         h = self.main(x)
149 |         out_src = self.conv1(h)  # patch-wise real/fake scores
150 |         out_cls = self.conv2(h)  # predicted domain-specific feature
151 |         return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
152 |
153 |
--------------------------------------------------------------------------------
/solver_dosgan.py:
--------------------------------------------------------------------------------
1 | from model import *
2 | from torch.autograd import Variable
3 | from torchvision.utils import save_image
4 | import torch
5 | import torch.nn.functional as F
6
| import numpy as np 7 | import os 8 | import time 9 | import datetime 10 | import itertools 11 | def accuracy(output, target, topk=(1,)): 12 | """Computes the precision@k for the specified values of k""" 13 | if len(output[0]) < topk[1]: 14 | topk = (1, len(output[0])) 15 | maxk = max(topk) 16 | batch_size = target.size(0) 17 | 18 | _, pred = output.topk(maxk, 1, True, True) 19 | pred = pred.t() 20 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 21 | 22 | res = [] 23 | for k in topk: 24 | correct_k = correct[:k].view(-1).float().sum(0) 25 | res.append(correct_k.mul_(100.0 / batch_size)) 26 | return res 27 | 28 | class Solver(object): 29 | 30 | def __init__(self, data_loader, data_loader_test, config): 31 | 32 | # Data loader. 33 | self.data_loader = data_loader 34 | self.data_loader_test = data_loader_test 35 | 36 | # Model configurations and loss weights. 37 | self.ft_num = config.ft_num 38 | self.c_dim = config.c_dim 39 | self.d_conv_dim = config.d_conv_dim 40 | self.d_repeat_num = config.d_repeat_num 41 | self.n_blocks = config.n_blocks 42 | self.lambda_rec = config.lambda_rec 43 | self.lambda_rec2 = config.lambda_rec2 44 | self.lambda_gp = config.lambda_gp 45 | self.lambda_fs = config.lambda_fs 46 | 47 | # Training configurations. 48 | self.batch_size = config.batch_size 49 | self.num_iters = config.num_iters 50 | self.num_iters_decay = config.num_iters_decay 51 | self.g_lr = config.g_lr 52 | self.d_lr = config.d_lr 53 | self.n_critic = config.n_critic 54 | self.beta1 = config.beta1 55 | self.beta2 = config.beta2 56 | self.resume_iters = config.resume_iters 57 | self.image_size = config.image_size 58 | 59 | # Test configurations. 60 | self.test_iters = config.test_iters 61 | self.non_conditional = config.non_conditional 62 | 63 | # Miscellaneous. 64 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 65 | 66 | # Directories. 67 | self.log_dir = config.log_dir 68 | self.sample_dir = config.sample_dir 69 | self.model_save_dir = config.model_save_dir 70 | self.cls_save_dir = config.cls_save_dir 71 | self.result_dir = config.result_dir 72 | 73 | # Step size. 74 | self.log_step = config.log_step 75 | self.sample_step = config.sample_step 76 | self.model_save_step = config.model_save_step 77 | self.lr_update_step = config.lr_update_step 78 | 79 | # Build the model. 
80 | self.build_model() 81 | 82 | def build_model(self): 83 | """Initializing networks.""" 84 | 85 | self.encoder = ResnetEncoder() 86 | self.decoder = ResnetDecoder(ft_num=self.ft_num,image_size=self.image_size) 87 | self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num, ft_num = self.ft_num) 88 | self.C = Classifier(image_size=self.image_size, c_dim = self.c_dim, ft_num = self.ft_num, n_blocks = self.n_blocks) 89 | 90 | self.g_optimizer = torch.optim.Adam(itertools.chain(self.encoder.parameters(), self.decoder.parameters()), self.g_lr, [self.beta1, self.beta2]) 91 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) 92 | self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.d_lr, [self.beta1, self.beta2]) 93 | self.encoder.to(self.device) 94 | self.decoder.to(self.device) 95 | self.D.to(self.device) 96 | self.C.to(self.device) 97 | 98 | 99 | def restore_model(self, resume_iters): 100 | """Restore the trained networks.""" 101 | 102 | print('Loading the trained models from step {}...'.format(resume_iters)) 103 | encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(resume_iters)) 104 | decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(resume_iters)) 105 | D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters)) 106 | self.encoder.load_state_dict(torch.load(encoder_path, map_location=lambda storage, loc: storage)) 107 | self.decoder.load_state_dict(torch.load(decoder_path, map_location=lambda storage, loc: storage)) 108 | self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage)) 109 | 110 | 111 | 112 | def update_lr(self, g_lr, d_lr): 113 | """Decay learning rates.""" 114 | for param_group in self.g_optimizer.param_groups: 115 | param_group['lr'] = g_lr 116 | for param_group in self.d_optimizer.param_groups: 117 | param_group['lr'] = d_lr 118 | 119 | def reset_grad(self): 120 | """Reset the gradient buffers.""" 121 | self.g_optimizer.zero_grad() 122 | self.d_optimizer.zero_grad() 123 | 124 | def denorm(self, x): 125 | """Convert the range from [-1, 1] to [0, 1].""" 126 | out = (x + 1) / 2 127 | return out.clamp_(0, 1) 128 | 129 | def gradient_penalty(self, y, x): 130 | """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.""" 131 | weight = torch.ones(y.size()).to(self.device) 132 | dydx = torch.autograd.grad(outputs=y, 133 | inputs=x, 134 | grad_outputs=weight, 135 | retain_graph=True, 136 | create_graph=True, 137 | only_inputs=True)[0] 138 | 139 | dydx = dydx.view(dydx.size(0), -1) 140 | dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1)) 141 | return torch.mean((dydx_l2norm-1)**2) 142 | 143 | 144 | 145 | 146 | def classification_loss(self, logit, target): 147 | return F.cross_entropy(logit, target) 148 | 149 | 150 | 151 | def train(self): 152 | # Load pre-trained classification network 153 | cls_iter = 160000 154 | C_path = os.path.join(self.cls_save_dir, '{}-C.ckpt'.format(cls_iter)) 155 | self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage)) 156 | 157 | # Set data loader. 158 | data_loader = self.data_loader 159 | 160 | # Set learning rate 161 | g_lr = self.g_lr 162 | d_lr = self.d_lr 163 | 164 | # Start training from scratch or resume training. 
165 | start_iters = 0 166 | if self.resume_iters: 167 | start_iters = self.resume_iters 168 | self.restore_model(self.resume_iters) 169 | 170 | 171 | empty = torch.FloatTensor(1, 3, self.image_size, self.image_size).to(self.device) 172 | empty.fill_(1) 173 | # Calculate domain feature centroid of each domain 174 | domain_sf_num = torch.FloatTensor(self.c_dim, 1).to(self.device) 175 | domain_sf_num.fill_(0.00000001) 176 | domain_sf = torch.FloatTensor(self.c_dim, self.ft_num).to(self.device) 177 | domain_sf.fill_(0) 178 | with torch.no_grad(): 179 | for indx, (x_real, label_org) in enumerate(data_loader): 180 | x_real = x_real.to(self.device) 181 | label_org = label_org.to(self.device) 182 | 183 | x_ds, x_cls = self.C(x_real) 184 | for j in range(label_org.size(0)): 185 | domain_sf[label_org[j], :] = (domain_sf[label_org[j], :] + x_ds[j] / domain_sf_num[label_org[j], :]) * ( 186 | domain_sf_num[label_org[j], :] / (domain_sf_num[label_org[j], :] + 1)) 187 | domain_sf_num[label_org[j], :] += 1 188 | 189 | start_time = time.time() 190 | # Start training. 191 | for i in range(start_iters, self.num_iters): 192 | 193 | # Fetch real images and labels. 194 | try: 195 | x_real, label_org = next(data_iter) 196 | except: 197 | data_iter = iter(data_loader) 198 | x_real, label_org = next(data_iter) 199 | 200 | x_real = x_real.to(self.device) 201 | label_org = label_org.to(self.device) 202 | 203 | x_ds, x_cls = self.C(x_real) #obtain domain feature for each real image 204 | 205 | #obtain domain feature centroid for each real image 206 | x_ds_mean = torch.FloatTensor(label_org.size(0), self.ft_num).to(self.device) 207 | for j in range(label_org.size(0)): 208 | x_ds_mean[j] = domain_sf[label_org[j]:label_org[j] + 1, :] 209 | 210 | # random target 211 | rand_idx = torch.randperm(label_org.size(0)) 212 | 213 | trg_dst = x_ds_mean[rand_idx] 214 | trg_ds = trg_dst.clone() 215 | 216 | # =================================================================================== # 217 | # 2. Train the discriminator # 218 | # =================================================================================== # 219 | 220 | # Compute loss with real images. 221 | out_src, out_cls = self.D(x_real) 222 | d_loss_real = - torch.mean(out_src) 223 | d_loss_dsrec = torch.mean( 224 | torch.abs(x_ds.detach() - out_cls)) 225 | 226 | # Compute loss with fake images. 227 | x_fake = self.decoder(self.encoder(x_real), trg_ds) 228 | out_src, out_cls = self.D(x_fake.detach()) 229 | d_loss_fake = torch.mean(out_src) 230 | 231 | # Compute loss for gradient penalty. 232 | alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device) 233 | x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True) 234 | out_src, _ = self.D(x_hat) 235 | d_loss_gp = self.gradient_penalty(out_src, x_hat) 236 | 237 | # Backward and optimize. 238 | d_loss = d_loss_real + d_loss_fake + self.lambda_fs * d_loss_dsrec + self.lambda_gp * d_loss_gp 239 | self.reset_grad() 240 | d_loss.backward() 241 | self.d_optimizer.step() 242 | 243 | # Logging. 244 | loss = {} 245 | loss['D/loss_real'] = d_loss_real.item() 246 | loss['D/loss_fake'] = d_loss_fake.item() 247 | loss['D/loss_dsrec'] = d_loss_dsrec.item() 248 | loss['D/loss_gp'] = d_loss_gp.item() 249 | 250 | # =================================================================================== # 251 | # 3. 
Train the encoder and decoder # 252 | # =================================================================================== # 253 | 254 | if (i + 1) % self.n_critic == 0: 255 | # Original-to-target domain. 256 | x_di = self.encoder(x_real) 257 | 258 | x_fake = self.decoder(x_di, trg_ds) 259 | x_reconst1 = self.decoder(x_di, x_ds) 260 | out_src, out_cls = self.D(x_fake) 261 | g_loss_fake = - torch.mean(out_src) 262 | g_loss_dsrec = torch.mean( 263 | torch.abs(trg_ds.detach() - out_cls)) 264 | 265 | # Target-to-original domain. 266 | x_fake_di = self.encoder(x_fake) 267 | 268 | x_reconst2 = self.decoder(x_fake_di, x_ds) 269 | 270 | g_loss_rec = torch.mean(torch.abs(x_real - x_reconst1)) 271 | 272 | g_loss_rec2 = torch.mean(torch.abs(x_real - x_reconst2)) 273 | 274 | # Backward and optimize. 275 | g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_rec2 * g_loss_rec2 + self.lambda_fs * g_loss_dsrec 276 | self.reset_grad() 277 | g_loss.backward() 278 | self.g_optimizer.step() 279 | 280 | # Logging. 281 | loss['G/loss_fake'] = g_loss_fake.item() 282 | loss['G/loss_rec'] = g_loss_rec.item() 283 | loss['G/loss_rec2'] = g_loss_rec2.item() 284 | loss['G/loss_dsrec'] = g_loss_dsrec.item() 285 | 286 | # =================================================================================== # 287 | # 4. Miscellaneous # 288 | # =================================================================================== # 289 | 290 | # Print out training information. 291 | if (i + 1) % self.log_step == 0: 292 | et = time.time() - start_time 293 | et = str(datetime.timedelta(seconds=et))[:-7] 294 | log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters) 295 | for tag, value in loss.items(): 296 | log += ", {}: {:.4f}".format(tag, value) 297 | print(log) 298 | 299 | 300 | # Translate fixed images for debugging. 301 | if (i) % self.sample_step == 0: 302 | with torch.no_grad(): 303 | out_A2B_results = [empty] 304 | 305 | for idx1 in range(label_org.size(0)): 306 | out_A2B_results.append(x_real[idx1:idx1 + 1]) 307 | 308 | for idx2 in range(label_org.size(0)): 309 | out_A2B_results.append(x_real[idx2:idx2 + 1]) 310 | 311 | for idx1 in range(label_org.size(0)): 312 | x_fake = self.decoder(self.encoder(x_real[idx2:idx2 + 1]), x_ds_mean[idx1:idx1 + 1]) 313 | out_A2B_results.append(x_fake) 314 | results_concat = torch.cat(out_A2B_results) 315 | x_AB_results_path = os.path.join(self.sample_dir, '{}_x_AB_results.jpg'.format(i + 1)) 316 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0) + 1, 317 | padding=0) 318 | print('Saved real and fake images into {}...'.format(x_AB_results_path)) 319 | 320 | 321 | # Save model checkpoints. 322 | if (i + 1) % self.model_save_step == 0: 323 | encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(i + 1)) 324 | decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(i + 1)) 325 | D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i + 1)) 326 | torch.save(self.encoder.state_dict(), encoder_path) 327 | torch.save(self.decoder.state_dict(), decoder_path) 328 | torch.save(self.D.state_dict(), D_path) 329 | print('Saved model checkpoints into {}...'.format(self.model_save_dir)) 330 | 331 | # Decay learning rates. 
332 | if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay): 333 | g_lr -= (self.g_lr / float(self.num_iters_decay)) 334 | d_lr -= (self.d_lr / float(self.num_iters_decay)) 335 | self.update_lr(g_lr, d_lr) 336 | print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) 337 | def train_conditional(self): 338 | # Load pre-trained classification network 339 | cls_iter = 160000 340 | C_path = os.path.join(self.cls_save_dir, '{}-C.ckpt'.format(cls_iter)) 341 | self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage)) 342 | 343 | # Set data loader. 344 | data_loader = self.data_loader 345 | 346 | # Set learning rate 347 | g_lr = self.g_lr 348 | d_lr = self.d_lr 349 | 350 | # Start training from scratch or resume training. 351 | start_iters = 0 352 | if self.resume_iters: 353 | start_iters = self.resume_iters 354 | self.restore_model(self.resume_iters) 355 | 356 | 357 | empty = torch.FloatTensor(1, 3, self.image_size, self.image_size).to(self.device) 358 | empty.fill_(1) 359 | 360 | 361 | start_time = time.time() 362 | # Start training. 363 | for i in range(start_iters, self.num_iters): 364 | 365 | # Fetch real images and labels. 366 | try: 367 | x_real, label_org = next(data_iter) 368 | except: 369 | data_iter = iter(data_loader) 370 | x_real, label_org = next(data_iter) 371 | 372 | x_real = x_real.to(self.device) 373 | label_org = label_org.to(self.device) 374 | 375 | x_ds, x_cls = self.C(x_real) # obtain domain feature for each real image 376 | 377 | # random target 378 | rand_idx = torch.randperm(label_org.size(0)) 379 | 380 | 381 | trg_dst = x_ds[rand_idx] 382 | trg_ds = trg_dst.clone() 383 | 384 | 385 | 386 | 387 | # =================================================================================== # 388 | # 2. Train the discriminator # 389 | # =================================================================================== # 390 | 391 | # Compute loss with real images. 392 | out_src, out_cls = self.D(x_real) 393 | 394 | d_loss_real = - torch.mean(out_src) 395 | 396 | d_loss_dsrec = torch.mean(torch.abs(x_ds.detach() - out_cls)) 397 | 398 | # Compute loss with fake images. 399 | x_fake = self.decoder(self.encoder(x_real), trg_ds) 400 | 401 | out_src, out_cls = self.D(x_fake.detach()) 402 | 403 | 404 | d_loss_fake = torch.mean(out_src) 405 | 406 | # Compute loss for gradient penalty. 407 | alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device) 408 | x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True) 409 | out_src, _ = self.D(x_hat) 410 | d_loss_gp = self.gradient_penalty(out_src, x_hat) 411 | 412 | # Backward and optimize. 413 | d_loss = d_loss_real + d_loss_fake + self.lambda_fs * d_loss_dsrec + self.lambda_gp * d_loss_gp 414 | self.reset_grad() 415 | d_loss.backward() 416 | self.d_optimizer.step() 417 | 418 | # Logging. 419 | loss = {} 420 | loss['D/loss_real'] = d_loss_real.item() 421 | loss['D/loss_fake'] = d_loss_fake.item() 422 | loss['D/loss_dsrec'] = d_loss_dsrec.item() 423 | loss['D/loss_gp'] = d_loss_gp.item() 424 | 425 | # =================================================================================== # 426 | # 3. Train the encoder and decoder # 427 | # =================================================================================== # 428 | 429 | if (i + 1) % self.n_critic == 0: 430 | # Original-to-target domain. 
431 | x_di = self.encoder(x_real) 432 | 433 | x_fake = self.decoder(x_di, trg_ds) 434 | x_reconst1 = self.decoder(x_di, x_ds) 435 | 436 | 437 | out_src, out_cls = self.D(x_fake) 438 | 439 | g_loss_fake = - torch.mean(out_src) 440 | g_loss_dsrec = torch.mean(torch.abs(trg_ds.detach() - out_cls)) 441 | 442 | # Target-to-original domain. 443 | x_fake_di = self.encoder(x_fake) 444 | 445 | x_reconst2 = self.decoder(x_fake_di, x_ds) 446 | 447 | g_loss_rec = torch.mean(torch.abs(x_real - x_reconst1)) 448 | 449 | g_loss_rec2 = torch.mean(torch.abs(x_real - x_reconst2)) 450 | 451 | # Backward and optimize. 452 | g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_rec2 * g_loss_rec2 + self.lambda_fs * g_loss_dsrec 453 | self.reset_grad() 454 | g_loss.backward() 455 | self.g_optimizer.step() 456 | 457 | # Logging. 458 | loss['G/loss_fake'] = g_loss_fake.item() 459 | loss['G/loss_rec'] = g_loss_rec.item() 460 | loss['G/loss_rec2'] = g_loss_rec2.item() 461 | loss['G/loss_dsrec'] = g_loss_dsrec.item() 462 | 463 | # =================================================================================== # 464 | # 4. Miscellaneous # 465 | # =================================================================================== # 466 | 467 | # Print out training information. 468 | if (i + 1) % self.log_step == 0: 469 | et = time.time() - start_time 470 | et = str(datetime.timedelta(seconds=et))[:-7] 471 | log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters) 472 | for tag, value in loss.items(): 473 | log += ", {}: {:.4f}".format(tag, value) 474 | print(log) 475 | 476 | 477 | # Translate fixed images for debugging. 478 | if (i) % self.sample_step == 0: 479 | with torch.no_grad(): 480 | out_A2B_results = [empty] 481 | 482 | for idx1 in range(label_org.size(0)): 483 | out_A2B_results.append(x_real[idx1:idx1 + 1]) 484 | 485 | for idx2 in range(label_org.size(0)): 486 | out_A2B_results.append(x_real[idx2:idx2 + 1]) 487 | 488 | for idx1 in range(label_org.size(0)): 489 | x_fake = self.decoder(self.encoder(x_real[idx2:idx2 + 1]), x_ds[idx1:idx1 + 1]) 490 | out_A2B_results.append(x_fake) 491 | results_concat = torch.cat(out_A2B_results) 492 | x_AB_results_path = os.path.join(self.sample_dir, '{}_x_AB_results.jpg'.format(i + 1)) 493 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0) + 1, 494 | padding=0) 495 | print('Saved real and fake images into {}...'.format(x_AB_results_path)) 496 | 497 | 498 | # Save model checkpoints. 499 | if (i + 1) % self.model_save_step == 0: 500 | encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(i + 1)) 501 | decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(i + 1)) 502 | D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i + 1)) 503 | torch.save(self.encoder.state_dict(), encoder_path) 504 | torch.save(self.decoder.state_dict(), decoder_path) 505 | torch.save(self.D.state_dict(), D_path) 506 | print('Saved model checkpoints into {}...'.format(self.model_save_dir)) 507 | 508 | # Decay learning rates. 509 | if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay): 510 | g_lr -= (self.g_lr / float(self.num_iters_decay)) 511 | d_lr -= (self.d_lr / float(self.num_iters_decay)) 512 | self.update_lr(g_lr, d_lr) 513 | print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) 514 | def cls(self): 515 | """Train a domain classifier""" 516 | # Set data loader. 
517 |         data_loader = self.data_loader
518 |
519 |
520 |         # Start training from scratch or resume training.
521 |         start_iters = 0
522 |
523 |         # Start training.
524 |         start_time = time.time()
525 |
526 |         for i in range(start_iters, self.num_iters):
527 |
528 |             try:
529 |                 x_real, label_org = next(data_iter)
530 |
531 |             except:  # data_iter is not created yet, or the iterator is exhausted
532 |                 data_iter = iter(data_loader)
533 |                 x_real, label_org = next(data_iter)
534 |
535 |             x_real = x_real.to(self.device)  # Input images.
536 |             label_org = label_org.to(self.device)  # Labels for computing classification loss.
537 |
538 |
539 |             # =================================================================================== #
540 |             #                               Train the classifier                                  #
541 |             # =================================================================================== #
542 |
543 |             out_src, out_cls = self.C(x_real)
544 |             d_loss_cls = self.classification_loss(out_cls, label_org)
545 |
546 |             # Backward and optimize.
547 |             d_loss = d_loss_cls
548 |             self.c_optimizer.zero_grad()
549 |             d_loss.backward()
550 |             self.c_optimizer.step()
551 |
552 |
553 |             # Logging.
554 |             loss = {}
555 |
556 |             loss['D/loss_cls'] = d_loss_cls.item()
557 |
558 |
559 |
560 |             # Print out training information.
561 |             if (i+1) % self.log_step == 0:
562 |                 et = time.time() - start_time
563 |                 et = str(datetime.timedelta(seconds=et))[:-7]
564 |                 prec1, prec5 = accuracy(out_cls.data, label_org.data, topk=(1, 5))
565 |                 loss['prec1'] = prec1
566 |                 loss['prec5'] = prec5
567 |                 log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
568 |                 for tag, value in loss.items():
569 |                     log += ", {}: {:.4f}".format(tag, value)
570 |                 print(log)
571 |
572 |
573 |
574 |
575 |             # Save model checkpoints.
576 |             if (i+1) % self.model_save_step == 0:
577 |                 C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(i+1))
578 |                 torch.save(self.C.state_dict(), C_path)
579 |                 print('Saved model checkpoints into {}...'.format(self.model_save_dir))
580 |
581 |
582 |     def label2onehot(self, labels, dim):
583 |         """Convert label indices to one-hot vectors."""
584 |         batch_size = labels.size(0)
585 |         out = torch.zeros(batch_size, dim)
586 |         out[np.arange(batch_size), labels.long()] = 1
587 |         return out
588 |     def create_labels(self, c_org, c_dim=5):
589 |         """Generate one-hot target domain labels for every domain (for debugging and testing)."""
590 |         # Collect all domain indices.
591 |         domain_indices = []
592 |         for i in range(c_dim):
593 |             domain_indices.append(i)
594 |
595 |         c_trg_list = []
596 |         for i in range(c_dim):
597 |             c_trg = c_org.clone()
598 |             if i in domain_indices:  # Set the target domain to 1 and the rest to 0.
599 |                 c_trg[:, i] = 1
600 |                 for j in domain_indices:
601 |                     if j != i:
602 |                         c_trg[:, j] = 0
603 |
604 |             c_trg_list.append(c_trg.to(self.device))
605 |         return c_trg_list
606 |     def test(self):
607 |         """Translate images with the trained DosGAN."""
608 |         # Load the trained networks.
609 |         cls_iter = 160000  # checkpoint step of the pre-trained classifier (see --mode cls)
610 |         C_path = os.path.join(self.cls_save_dir, '{}-C.ckpt'.format(cls_iter))
611 |         self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage))
612 |         self.restore_model(self.test_iters)
613 |
614 |         # Set data loader.
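        # Two test modes follow: with --non_conditional true (DosGAN), domain
        # feature centroids are recomputed over the training set and each test
        # image is translated toward every domain centroid; with
        # --non_conditional false (DosGAN-c), per-image domain features are
        # swapped within a test batch instead.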
615 | data_loader = self.data_loader 616 | data_loader_test = self.data_loader_test 617 | step = 0 618 | empty = torch.FloatTensor(1, 3,self.image_size,self.image_size).to(self.device) 619 | empty.fill_(1) 620 | domain_sf_num = torch.FloatTensor(self.c_dim, 1).to(self.device) 621 | domain_sf_num.fill_(0.00000001) 622 | domain_sf = torch.FloatTensor(self.c_dim, self.ft_num).to(self.device) 623 | domain_sf.fill_(0) 624 | with torch.no_grad(): 625 | if self.non_conditional: # non_conditional testing 626 | for indx, (x_real, label_org) in enumerate(data_loader): 627 | x_real = x_real.to(self.device) # Input images. 628 | label_org = label_org.to(self.device) 629 | 630 | x_ds, x_cls = self.C(x_real) 631 | for j in range(label_org.size(0)): 632 | domain_sf[label_org[j],:] = (domain_sf[label_org[j],:] + x_ds[j]/domain_sf_num[label_org[j],:])*(domain_sf_num[label_org[j],:]/(domain_sf_num[label_org[j],:]+1)) 633 | domain_sf_num[label_org[j],:] += 1 634 | step = step +1 635 | 636 | for indx, (x_real, label_org) in enumerate(data_loader_test): 637 | x_real = x_real.to(self.device) # Input images. 638 | 639 | x_ds, x_cls = self.C(x_real) 640 | c_org = self.label2onehot(label_org, self.c_dim) 641 | 642 | 643 | c_org = c_org.to(self.device) 644 | label_org = label_org.to(self.device) 645 | 646 | c_fixed_list = self.create_labels(c_org, self.c_dim) 647 | 648 | x_fake_list = [x_real] 649 | for c_fixed in c_fixed_list: 650 | _, out_pred_fixed = torch.max(c_fixed.data, 1) 651 | x_ds_m = x_ds.clone() 652 | for k in range(label_org.size(0)): 653 | x_ds_m[k,:] = domain_sf[out_pred_fixed[k],:] 654 | x_fake = self.decoder(self.encoder(x_real), x_ds_m) 655 | x_fake_list.append(x_fake) 656 | 657 | x_concat = torch.cat(x_fake_list, dim=3) 658 | sample_path = os.path.join(self.result_dir, '{}-images.jpg'.format(indx+1)) 659 | save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0) 660 | print('Saved real and fake images into {}...'.format(sample_path)) 661 | else: # conditional image translation testing 662 | for indx, (x_real, label_org) in enumerate(data_loader_test): 663 | x_real = x_real.to(self.device) # Input images. 
664 |                     label_org = label_org.to(self.device)
665 |                     x_ds, x_cls = self.C(x_real)
666 |
667 |
668 |                     out_A2B_results = [empty]
669 |
670 |                     for j in range(label_org.size(0)):  # first grid row: the real inputs
671 |                         out_A2B_results.append(x_real[j:j+1])
672 |
673 |                     for i in range(label_org.size(0)):  # one grid row per source image
674 |                         out_A2B_results.append(x_real[i:i+1])
675 |
676 |                         for j in range(label_org.size(0)):
677 |                             x_fake = self.decoder(self.encoder(x_real[i:i+1]), x_ds[j:j+1])
678 |                             out_A2B_results.append(x_fake)
679 |                     results_concat = torch.cat(out_A2B_results)
680 |                     x_AB_results_path = os.path.join(self.result_dir, '{}_x_AB_results.jpg'.format(indx+1))
681 |                     save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0) + 1, padding=0)
682 |                     print('Saved real and fake images into {}...'.format(x_AB_results_path))
683 |
684 |
--------------------------------------------------------------------------------
/split2train_val.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | import sys
3 | import os
4 | import shutil
5 |
6 | def split2train_val(root_dir, train_dir, val_dir):
7 |     try:
8 |         os.mkdir(train_dir)
9 |         os.mkdir(val_dir)
10 |     except OSError:
11 |         print("train_dir and val_dir already exist!")
12 |
13 |     person_names = os.listdir(root_dir)
14 |     index = 0
15 |
16 |     size_list = list()
17 |
18 |     for person_name in person_names:
19 |
20 |         os.mkdir(os.path.join(train_dir, person_name))
21 |         os.mkdir(os.path.join(val_dir, person_name))
22 |
23 |         index += 1
24 |         img_names = os.listdir(os.path.join(root_dir, person_name))
25 |         n_face = len(img_names)
26 |         n_test = int(n_face/10)  # hold out ~10% of each person's images for validation
27 |         n_train = n_face - n_test
28 |
29 |         print(len(person_names), str(index), person_name, str(n_train), str(n_test))
30 |
31 |         for i in range(len(img_names)):
32 |             img_name = img_names[i]
33 |             source_path = os.path.join(root_dir, person_name, img_name)
34 |             if i < n_train:
35 |                 target_path = os.path.join(train_dir, person_name, img_name)
36 |             else:
37 |                 target_path = os.path.join(val_dir, person_name, img_name)
38 |
39 |             shutil.copy(source_path, target_path)
40 |
41 |
42 | if __name__ == "__main__":
43 |     root_dir = sys.argv[1]
44 |     train_dir = sys.argv[2]
45 |     val_dir = sys.argv[3]
46 |     split2train_val(root_dir, train_dir, val_dir)
47 |
--------------------------------------------------------------------------------
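For reference, the domain-feature centroids that `Solver.train` (and the non-conditional branch of `Solver.test`) build with an incremental running mean reduce to a plain per-domain average: each update `(m + x/n) * (n / (n + 1))` is algebraically `(m*n + x) / (n + 1)`. A condensed editorial sketch of the same computation; the function name and signature are illustrative and not part of the repository:

```python
import torch

def domain_centroids(C, data_loader, c_dim, ft_num, device):
    """Average the classifier's domain-specific features per domain."""
    total = torch.zeros(c_dim, ft_num, device=device)
    count = torch.zeros(c_dim, 1, device=device)
    with torch.no_grad():
        for images, labels in data_loader:
            feats, _ = C(images.to(device))  # (batch, ft_num) domain features
            for f, y in zip(feats, labels):
                total[y] += f
                count[y] += 1
    return total / count.clamp(min=1)        # one ft_num-dim centroid per domain
```

Called as, for example, `domain_centroids(solver.C, data_loader_train, config.c_dim, config.ft_num, solver.device)`, this yields the same `c_dim x ft_num` tensor as the `domain_sf` accumulation in `solver_dosgan.py`, up to floating-point rounding.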