├── README.md
├── examples
├── dosgan_identity.jpg
├── dosganc_identity.jpg
├── facescrub_intra.png
└── season.jpg
├── main_dosgan.py
├── model.py
├── solver_dosgan.py
└── split2train_val.py
/README.md:
--------------------------------------------------------------------------------
1 | # DosGAN-PyTorch
2 | PyTorch Implementation of [Exploring Explicit Domain Supervision for Latent Space Disentanglement in Unpaired Image-to-Image Translation](https://arxiv.org/abs/1902.03782).
3 |
4 | ![facescrub examples](examples/facescrub_intra.png)
5 |
6 |
7 | # Dependencies:
8 | Python 2.7
9 |
10 | PyTorch 0.4.0
11 |
12 | # Usage:
13 | ### Multiple identity translation
14 | 1. Download the FaceScrub dataset following http://www.vintage.winklerbros.net/facescrub.html and save it to `root_dir`.
15 |
16 | 2. Split the data into training and testing sets under `train_dir` and `val_dir`:
17 |
18 | `$ python split2train_val.py root_dir train_dir val_dir`
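
   The script copies roughly nine tenths of each identity's images into `train_dir` and reserves the remaining tenth for `val_dir`.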
19 |
20 | 3. Train a classifier for domain feature extraction and save it to `dosgan_cls`:
21 |
22 | `$ python main_dosgan.py --mode cls --model_dir dosgan_cls --train_data_path train_dir --test_data_path val_dir`
23 |
24 | 4. Train DosGAN:
25 |
26 | `$ python main_dosgan.py --mode train --model_dir dosgan --cls_save_dir dosgan_cls/models --train_data_path train_dir --test_data_path val_dir`
27 |
28 | 5. Train DosGAN-c:
29 |
30 | `$ python main_dosgan.py --mode train --model_dir dosgan_c --cls_save_dir dosgan_cls/models --non_conditional false --train_data_path train_dir --test_data_path val_dir`
31 |
32 | 6. Test DosGAN:
33 |
34 | `$ python main_dosgan.py --mode test --model_dir dosgan --cls_save_dir dosgan_cls/models --train_data_path train_dir --test_data_path val_dir`
35 |
36 | 7. Test DosGAN-c:
37 |
38 | `$ python main_dosgan.py --mode test --model_dir dosgan_c --cls_save_dir dosgan_cls/models --non_conditional false --train_data_path train_dir --test_data_path val_dir`
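
Checkpoints, debug samples, and test outputs are written to `<model_dir>/models`, `<model_dir>/samples`, and `<model_dir>/results`, which `main_dosgan.py` creates automatically.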
39 | ### Other multiple domain translation
40 | 1. For other kinds of datasets, arrange the training and test sets as follows (a short `ImageFolder` sketch follows the listings):
41 |
42 | data
43 | ├── YOUR_DATASET_train_dir
44 | ├── domain1
45 | | ├── 1.jpg
46 | | ├── 2.jpg
47 | | └── ...
48 | ├── domain2
49 | | ├── 1.jpg
50 | | ├── 2.jpg
51 | | └── ...
52 | ├── domain3
53 | | ├── 1.jpg
54 | | ├── 2.jpg
55 | | └── ...
56 | ...
57 |
58 | data
59 | ├── YOUR_DATASET_val_dir
60 | ├── domain1
61 | | ├── 1.jpg
62 | | ├── 2.jpg
63 | | └── ...
64 | ├── domain2
65 | | ├── 1.jpg
66 | | ├── 2.jpg
67 | | └── ...
68 | ├── domain3
69 | | ├── 1.jpg
70 | | ├── 2.jpg
71 | | └── ...
72 | ...
73 |
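Both directories are read with torchvision's `ImageFolder`, which turns each subfolder into one class, i.e. one translation domain. A minimal sketch of how the layout is interpreted (the dataset path is a placeholder for your own):

```python
from torchvision import datasets

# Every subfolder of the root becomes one class/domain.
train_set = datasets.ImageFolder('data/YOUR_DATASET_train_dir')
print(train_set.classes)       # e.g. ['domain1', 'domain2', 'domain3']
print(len(train_set.classes))  # the value to pass as --c_dim
```
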
74 | 2. Taking multiple season translation as an example ([season dataset](https://github.com/AAnoosheh/ComboGAN)), train a classifier for season domain feature extraction and save it to `dosgan_season_cls`:
75 |
76 | `$ python main_dosgan.py --mode cls --model_dir dosgan_season_cls --ft_num 64 --c_dim 4 --image_size 256 --train_data_path season_train_dir --test_data_path season_val_dir`
77 |
78 | 3. Train DosGAN for multiple season translation:
79 |
80 | `$ python main_dosgan.py --mode train --model_dir dosgan_season --cls_save_dir dosgan_season_cls/models --ft_num 64 --c_dim 4 --image_size 256 --lambda_fs 0.15 --num_iters 300000 --train_data_path season_train_dir --test_data_path season_val_dir`
81 |
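To test the season model, the FaceScrub testing command should carry over; a sketch (not part of the original instructions):

`$ python main_dosgan.py --mode test --model_dir dosgan_season --cls_save_dir dosgan_season_cls/models --ft_num 64 --c_dim 4 --image_size 256 --test_iters 300000 --train_data_path season_train_dir --test_data_path season_val_dir`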
82 |
83 | # Results:
84 | ### 1. Multiple identity translation
85 |
86 | **Results of DosGAN**:
87 |
88 | ![DosGAN identity translation results](examples/dosgan_identity.jpg)
89 |
90 | **Results of DosGAN-c**:
91 |
92 | ![DosGAN-c identity translation results](examples/dosganc_identity.jpg)
93 |
94 | ### 2. Multiple season translation:
95 |
96 | ![Multiple season translation results](examples/season.jpg)
97 |
--------------------------------------------------------------------------------
/examples/dosgan_identity.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/DosGAN-PyTorch/2a2b443089f8de7f15ba48fec4f5cd2121214daa/examples/dosgan_identity.jpg
--------------------------------------------------------------------------------
/examples/dosganc_identity.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/DosGAN-PyTorch/2a2b443089f8de7f15ba48fec4f5cd2121214daa/examples/dosganc_identity.jpg
--------------------------------------------------------------------------------
/examples/facescrub_intra.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/DosGAN-PyTorch/2a2b443089f8de7f15ba48fec4f5cd2121214daa/examples/facescrub_intra.png
--------------------------------------------------------------------------------
/examples/season.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/DosGAN-PyTorch/2a2b443089f8de7f15ba48fec4f5cd2121214daa/examples/season.jpg
--------------------------------------------------------------------------------
/main_dosgan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from solver_dosgan import Solver
4 | from torch.backends import cudnn
5 | from torchvision import transforms, datasets
6 | import torch.utils.data as data
7 | def str2bool(v):
8 |     return v.lower() == 'true'
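# Both transforms normalize with mean=std=0.5 per channel, mapping images to
# [-1, 1] to match the decoder's Tanh output (Solver.denorm inverts this).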
9 | def train_trans(config):
10 | return transforms.Compose([
11 | transforms.RandomHorizontalFlip(),
12 | transforms.Resize((config.image_size,config.image_size)),
13 | transforms.ToTensor(),
14 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
15 | ])
16 | def test_trans(config):
17 | return transforms.Compose([
18 | transforms.Resize((config.image_size,config.image_size)),
19 | transforms.ToTensor(),
20 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
21 | ])
22 | def main(config):
23 | # For fast training.
24 | cudnn.benchmark = True
25 |
26 | # Create directories if not exist.
27 | config.log_dir = os.path.join(config.model_dir, 'logs')
28 | config.model_save_dir = os.path.join(config.model_dir, 'models')
29 | config.sample_dir = os.path.join(config.model_dir, 'samples')
30 | config.result_dir = os.path.join(config.model_dir, 'results')
31 |
32 | if not os.path.exists(config.log_dir):
33 | os.makedirs(config.log_dir)
34 | if not os.path.exists(config.model_save_dir):
35 | os.makedirs(config.model_save_dir)
36 | if not os.path.exists(config.sample_dir):
37 | os.makedirs(config.sample_dir)
38 | if not os.path.exists(config.result_dir):
39 | os.makedirs(config.result_dir)
40 |
41 | # Data loader.
42 |
43 | train_dataset = datasets.ImageFolder(config.train_data_path, train_trans(config))
44 |
45 | test_dataset = datasets.ImageFolder(config.test_data_path, test_trans(config))
46 | data_loader_train = data.DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, drop_last=True)
47 | print('train dataset loaded')
48 | data_loader_test = data.DataLoader(dataset=test_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, drop_last=True)
49 | print('test dataset loaded')
50 |
51 |
52 |
53 |
54 |
55 | # Solver for training and testing dosgan.
56 | solver = Solver(data_loader_train, data_loader_test, config)
57 |
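# DosGAN (non-conditional) uses per-domain feature centroids as translation
# targets; DosGAN-c (--non_conditional false) conditions on the domain
# feature of an individual reference image.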
58 | if config.mode == 'train':
59 | if config.non_conditional:
60 | solver.train()
61 | else:
62 | solver.train_conditional()
63 | elif config.mode == 'test':
64 | solver.test()
65 | elif config.mode == 'cls':
66 | solver.cls()
67 |
68 |
69 |
70 | if __name__ == '__main__':
71 | parser = argparse.ArgumentParser()
72 |
73 | # Model configuration.
74 | parser.add_argument('--c_dim', type=int, default=531, help='number of domains')
75 | parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
76 | parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
77 | parser.add_argument('--n_blocks', type=int, default=0, help='number of res conv layers in C')
78 | parser.add_argument('--image_size', type=int, default=128, help='image resolution')
79 | parser.add_argument('--lambda_rec', type=float, default=10, help='weight for self-reconstruction loss')
80 | parser.add_argument('--lambda_rec2', type=float, default=10, help='weight for cross-domain reconstruction loss')
81 | parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
82 | parser.add_argument('--lambda_fs', type=float, default=5, help='weight for domain-feature reconstruction loss')
83 | parser.add_argument('--ft_num', type=int, default=1024, help='dimension of the domain-specific feature')
84 |
85 | # Training configuration.
86 | parser.add_argument('--batch_size', type=int, default=6, help='mini-batch size')
87 | parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
88 | parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
89 | parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for encoder and decoder')
90 | parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
91 | parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each generator update')
92 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
93 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
94 | parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
95 | # Test configuration.
96 | parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
97 | parser.add_argument('--non_conditional', type=str2bool, default=True, help='true: DosGAN (centroid domain features); false: DosGAN-c (conditional)')
98 |
99 | # Miscellaneous.
100 | parser.add_argument('--num_workers', type=int, default=1)
101 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'cls'])
102 |
103 | # Directories.
104 | parser.add_argument('--train_data_path', type=str, default='data/facescrub_train/')
105 | parser.add_argument('--test_data_path', type=str, default='data/facescrub_test/')
106 | parser.add_argument('--model_dir', type=str, default='dosgan')
107 | parser.add_argument('--cls_save_dir', type=str, default='dosgan_cls/models')
108 |
109 |
110 | # Step size.
111 | parser.add_argument('--log_step', type=int, default=1000)
112 | parser.add_argument('--sample_step', type=int, default=2000)
113 | parser.add_argument('--model_save_step', type=int, default=20000)
114 | parser.add_argument('--lr_update_step', type=int, default=1000)
115 |
116 | config = parser.parse_args()
117 | print(config)
118 | main(config)
119 |
120 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import torchvision.models as models
7 | class ResidualBlock(nn.Module):
8 | """Residual Block with instance normalization."""
9 | def __init__(self, dim_in, dim_out):
10 | super(ResidualBlock, self).__init__()
11 | self.main = nn.Sequential(
12 | nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
13 | nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
14 | nn.ReLU(inplace=True),
15 | nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
16 | nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))
17 |
18 | def forward(self, x):
19 | return x + self.main(x)
20 |
21 |
22 | class ResnetEncoder(nn.Module):
23 | def __init__(self, input_nc=3, output_nc=3, n_blocks=3):
24 | assert(n_blocks >= 0)
25 | super(ResnetEncoder, self).__init__()
26 | self.input_nc = input_nc
27 | self.output_nc = output_nc
28 | ngf = 64
29 | padding_type ='reflect'
30 | norm_layer = nn.InstanceNorm2d
31 | use_bias = False
32 |
33 | model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3,
34 | bias=use_bias),
35 | norm_layer(ngf, affine=True, track_running_stats=True),
36 | nn.ReLU(True)]
37 |
38 | n_downsampling = 2
39 | for i in range(n_downsampling):
40 | mult = 2**i
41 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=4,
42 | stride=2, padding=1, bias=use_bias),
43 | norm_layer(ngf * mult * 2, affine=True, track_running_stats=True),
44 | nn.ReLU(True)]
45 | mult = 2**n_downsampling
46 |
47 | for i in range(n_blocks):
48 | model += [ResidualBlock(dim_in=ngf * mult, dim_out=ngf * mult)]
49 |
50 | self.model = nn.Sequential(*model)
51 |
52 | def forward(self, input):
53 | return self.model(input)
54 | class ResnetDecoder(nn.Module):
55 | def __init__(self, input_nc=3, output_nc=3, n_blocks=3, ft_num=16, image_size=128):
56 | assert(n_blocks >= 0)
57 | super(ResnetDecoder, self).__init__()
58 | self.input_nc = input_nc
59 | self.output_nc = output_nc
60 | ngf = 64
61 | ngf_o = ngf*2
62 | padding_type ='reflect'
63 | norm_layer = nn.InstanceNorm2d
64 | use_bias = False
65 |
66 | model = [ ]
67 | n_downsampling = 2
68 | mult = 2**n_downsampling
69 | model_2 = [ ]
70 | model_2 += [nn.Linear(ft_num, ngf * mult * int(image_size / np.power(2, n_downsampling)) * int(image_size / np.power(2, n_downsampling)))]
71 | model_2 += [nn.ReLU(True)]
72 |
73 | model += [nn.Conv2d(ngf * mult, ngf * mult, kernel_size=3,
74 | stride=1, padding=1, bias=use_bias),
75 | norm_layer(ngf * mult, affine=True, track_running_stats=True ),
76 | nn.ReLU(True)]
77 | for i in range(n_blocks):
78 | model += [ResidualBlock(dim_in=ngf * mult, dim_out=ngf * mult)]
79 |
80 | for i in range(n_downsampling):
81 | mult = 2**(n_downsampling - i)
82 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
83 | kernel_size=3, stride=2,
84 | padding=1, output_padding=1,
85 | bias=use_bias),
86 | norm_layer(int(ngf * mult / 2), affine=True, track_running_stats=True),
87 | nn.ReLU(True)]
88 | model += [nn.Conv2d(ngf, 3, kernel_size=7, stride=1, padding=3, bias=False)]
89 |
90 | model += [nn.Tanh()]
91 | self.model = nn.Sequential(*model)
92 | self.model_2 = nn.Sequential(*model_2)
93 |
94 | def forward(self, input1, input2):
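# input1: content feature map from the encoder; input2: domain feature
# vector. model_2 projects input2 to input1's spatial shape so the two
# can be fused by element-wise addition before decoding.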
95 | out_2 = self.model_2(input2)
96 | out_2 = out_2.view(input1.size(0), input1.size(1), input1.size(2), input1.size(3))
97 |
98 | return self.model(input1 + out_2)
99 |
100 |
101 | class Classifier(nn.Module):
102 |     """Domain classifier: extracts a ft_num-dim domain feature (conv1) and predicts the domain label (conv2)."""
103 | def __init__(self, image_size=128, conv_dim=64, c_dim=2, repeat_num=6, ft_num = 16, n_blocks = 3):
104 | super(Classifier, self).__init__()
105 | layers = []
106 | layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
107 | layers.append(nn.LeakyReLU(0.01))
108 |
109 | curr_dim = conv_dim
110 | for i in range(1, repeat_num):
111 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
112 | layers.append(nn.LeakyReLU(0.01))
113 | curr_dim = curr_dim * 2
114 | for i in range(n_blocks):
115 | layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
116 | kernel_size = int(image_size / np.power(2, repeat_num))
117 | self.main = nn.Sequential(*layers)
118 | self.conv1 = nn.Sequential(*[nn.Conv2d(curr_dim, ft_num, kernel_size= kernel_size), nn.LeakyReLU(0.01)])
119 |
120 | self.conv2 = nn.Conv2d(ft_num, c_dim, kernel_size=1, bias=False)
121 |
122 | def forward(self, x):
123 | h = self.main(x)
124 | out_src = self.conv1(h)
125 | out_cls = self.conv2(out_src)
126 | return out_src.view(out_src.size(0), out_src.size(1)), out_cls.view(out_cls.size(0), out_cls.size(1))
127 |
128 | class Discriminator(nn.Module):
129 | """Discriminator network with PatchGAN."""
130 | def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6, ft_num = 16):
131 | super(Discriminator, self).__init__()
132 | layers = []
133 | layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
134 | layers.append(nn.LeakyReLU(0.01))
135 |
136 | curr_dim = conv_dim
137 | for i in range(1, repeat_num):
138 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
139 | layers.append(nn.LeakyReLU(0.01))
140 | curr_dim = curr_dim * 2
141 |
142 | kernel_size = int(image_size / np.power(2, repeat_num))
143 | self.main = nn.Sequential(*layers)
144 | self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
145 | self.conv2 = nn.Conv2d(curr_dim, ft_num, kernel_size=kernel_size, bias=False)
146 |
147 | def forward(self, x):
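# out_src: PatchGAN real/fake score map. out_cls: a predicted domain
# feature, trained to regress the classifier's feature of the input.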
148 | h = self.main(x)
149 | out_src = self.conv1(h)
150 | out_cls = self.conv2(h)
151 | return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
152 |
153 |
--------------------------------------------------------------------------------
/solver_dosgan.py:
--------------------------------------------------------------------------------
1 | from model import *
2 | from torch.autograd import Variable
3 | from torchvision.utils import save_image
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import os
8 | import time
9 | import datetime
10 | import itertools
11 | def accuracy(output, target, topk=(1,)):
12 |     """Compute the precision@k for the specified values of k."""
13 |     # Cap each k at the number of classes so topk() cannot overflow.
14 |     topk = tuple(min(k, output.size(1)) for k in topk)
15 |     maxk = max(topk)
16 |     batch_size = target.size(0)
17 |
18 |     _, pred = output.topk(maxk, 1, True, True)
19 |     pred = pred.t()
20 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
21 |
22 |     res = []
23 |     for k in topk:
24 |         correct_k = correct[:k].contiguous().view(-1).float().sum(0)
25 |         res.append(correct_k.mul_(100.0 / batch_size))
26 |     return res
27 |
28 | class Solver(object):
29 |
30 | def __init__(self, data_loader, data_loader_test, config):
31 |
32 | # Data loader.
33 | self.data_loader = data_loader
34 | self.data_loader_test = data_loader_test
35 |
36 | # Model configurations and loss weights.
37 | self.ft_num = config.ft_num
38 | self.c_dim = config.c_dim
39 | self.d_conv_dim = config.d_conv_dim
40 | self.d_repeat_num = config.d_repeat_num
41 | self.n_blocks = config.n_blocks
42 | self.lambda_rec = config.lambda_rec
43 | self.lambda_rec2 = config.lambda_rec2
44 | self.lambda_gp = config.lambda_gp
45 | self.lambda_fs = config.lambda_fs
46 |
47 | # Training configurations.
48 | self.batch_size = config.batch_size
49 | self.num_iters = config.num_iters
50 | self.num_iters_decay = config.num_iters_decay
51 | self.g_lr = config.g_lr
52 | self.d_lr = config.d_lr
53 | self.n_critic = config.n_critic
54 | self.beta1 = config.beta1
55 | self.beta2 = config.beta2
56 | self.resume_iters = config.resume_iters
57 | self.image_size = config.image_size
58 |
59 | # Test configurations.
60 | self.test_iters = config.test_iters
61 | self.non_conditional = config.non_conditional
62 |
63 | # Miscellaneous.
64 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65 |
66 | # Directories.
67 | self.log_dir = config.log_dir
68 | self.sample_dir = config.sample_dir
69 | self.model_save_dir = config.model_save_dir
70 | self.cls_save_dir = config.cls_save_dir
71 | self.result_dir = config.result_dir
72 |
73 | # Step size.
74 | self.log_step = config.log_step
75 | self.sample_step = config.sample_step
76 | self.model_save_step = config.model_save_step
77 | self.lr_update_step = config.lr_update_step
78 |
79 | # Build the model.
80 | self.build_model()
81 |
82 | def build_model(self):
83 | """Initializing networks."""
84 |
85 | self.encoder = ResnetEncoder()
86 | self.decoder = ResnetDecoder(ft_num=self.ft_num,image_size=self.image_size)
87 | self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num, ft_num = self.ft_num)
88 | self.C = Classifier(image_size=self.image_size, c_dim = self.c_dim, ft_num = self.ft_num, n_blocks = self.n_blocks)
89 |
90 | self.g_optimizer = torch.optim.Adam(itertools.chain(self.encoder.parameters(), self.decoder.parameters()), self.g_lr, [self.beta1, self.beta2])
91 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
92 | self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.d_lr, [self.beta1, self.beta2])
93 | self.encoder.to(self.device)
94 | self.decoder.to(self.device)
95 | self.D.to(self.device)
96 | self.C.to(self.device)
97 |
98 |
99 | def restore_model(self, resume_iters):
100 | """Restore the trained networks."""
101 |
102 | print('Loading the trained models from step {}...'.format(resume_iters))
103 | encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(resume_iters))
104 | decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(resume_iters))
105 | D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
106 | self.encoder.load_state_dict(torch.load(encoder_path, map_location=lambda storage, loc: storage))
107 | self.decoder.load_state_dict(torch.load(decoder_path, map_location=lambda storage, loc: storage))
108 | self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
109 |
110 |
111 |
112 | def update_lr(self, g_lr, d_lr):
113 | """Decay learning rates."""
114 | for param_group in self.g_optimizer.param_groups:
115 | param_group['lr'] = g_lr
116 | for param_group in self.d_optimizer.param_groups:
117 | param_group['lr'] = d_lr
118 |
119 | def reset_grad(self):
120 | """Reset the gradient buffers."""
121 | self.g_optimizer.zero_grad()
122 | self.d_optimizer.zero_grad()
123 |
124 | def denorm(self, x):
125 | """Convert the range from [-1, 1] to [0, 1]."""
126 | out = (x + 1) / 2
127 | return out.clamp_(0, 1)
128 |
129 | def gradient_penalty(self, y, x):
130 | """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
131 | weight = torch.ones(y.size()).to(self.device)
132 | dydx = torch.autograd.grad(outputs=y,
133 | inputs=x,
134 | grad_outputs=weight,
135 | retain_graph=True,
136 | create_graph=True,
137 | only_inputs=True)[0]
138 |
139 | dydx = dydx.view(dydx.size(0), -1)
140 | dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
141 | return torch.mean((dydx_l2norm-1)**2)
142 |
143 |
144 |
145 |
146 | def classification_loss(self, logit, target):
147 | return F.cross_entropy(logit, target)
148 |
149 |
150 |
151 | def train(self):
152 | # Load pre-trained classification network
153 | cls_iter = 160000
154 | C_path = os.path.join(self.cls_save_dir, '{}-C.ckpt'.format(cls_iter))
155 | self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage))
156 |
157 | # Set data loader.
158 | data_loader = self.data_loader
159 |
160 | # Set learning rate
161 | g_lr = self.g_lr
162 | d_lr = self.d_lr
163 |
164 | # Start training from scratch or resume training.
165 | start_iters = 0
166 | if self.resume_iters:
167 | start_iters = self.resume_iters
168 | self.restore_model(self.resume_iters)
169 |
170 |
171 | empty = torch.FloatTensor(1, 3, self.image_size, self.image_size).to(self.device)
172 | empty.fill_(1)
173 | # Calculate domain feature centroid of each domain
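# Each centroid is an incremental mean over a domain's training images:
#   mean_new = (mean_old + x/n) * (n/(n+1)) == (n*mean_old + x) / (n+1),
# where the counter n starts at a tiny epsilon to avoid division by zero.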
174 | domain_sf_num = torch.FloatTensor(self.c_dim, 1).to(self.device)
175 | domain_sf_num.fill_(0.00000001)
176 | domain_sf = torch.FloatTensor(self.c_dim, self.ft_num).to(self.device)
177 | domain_sf.fill_(0)
178 | with torch.no_grad():
179 | for indx, (x_real, label_org) in enumerate(data_loader):
180 | x_real = x_real.to(self.device)
181 | label_org = label_org.to(self.device)
182 |
183 | x_ds, x_cls = self.C(x_real)
184 | for j in range(label_org.size(0)):
185 | domain_sf[label_org[j], :] = (domain_sf[label_org[j], :] + x_ds[j] / domain_sf_num[label_org[j], :]) * (
186 | domain_sf_num[label_org[j], :] / (domain_sf_num[label_org[j], :] + 1))
187 | domain_sf_num[label_org[j], :] += 1
188 |
189 | start_time = time.time()
190 | # Start training.
191 | for i in range(start_iters, self.num_iters):
192 |
193 | # Fetch real images and labels.
194 | try:
195 | x_real, label_org = next(data_iter)
196 | except:
197 | data_iter = iter(data_loader)
198 | x_real, label_org = next(data_iter)
199 |
200 | x_real = x_real.to(self.device)
201 | label_org = label_org.to(self.device)
202 |
203 | x_ds, x_cls = self.C(x_real) #obtain domain feature for each real image
204 |
205 | #obtain domain feature centroid for each real image
206 | x_ds_mean = torch.FloatTensor(label_org.size(0), self.ft_num).to(self.device)
207 | for j in range(label_org.size(0)):
208 | x_ds_mean[j] = domain_sf[label_org[j]:label_org[j] + 1, :]
209 |
210 | # random target
211 | rand_idx = torch.randperm(label_org.size(0))
212 |
213 | trg_dst = x_ds_mean[rand_idx]
214 | trg_ds = trg_dst.clone()
215 |
216 | # =================================================================================== #
217 | # 2. Train the discriminator #
218 | # =================================================================================== #
219 |
220 | # Compute loss with real images.
221 | out_src, out_cls = self.D(x_real)
222 | d_loss_real = - torch.mean(out_src)
223 | d_loss_dsrec = torch.mean(
224 | torch.abs(x_ds.detach() - out_cls))
225 |
226 | # Compute loss with fake images.
227 | x_fake = self.decoder(self.encoder(x_real), trg_ds)
228 | out_src, out_cls = self.D(x_fake.detach())
229 | d_loss_fake = torch.mean(out_src)
230 |
231 | # Compute loss for gradient penalty.
232 | alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
233 | x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
234 | out_src, _ = self.D(x_hat)
235 | d_loss_gp = self.gradient_penalty(out_src, x_hat)
236 |
237 | # Backward and optimize.
238 | d_loss = d_loss_real + d_loss_fake + self.lambda_fs * d_loss_dsrec + self.lambda_gp * d_loss_gp
239 | self.reset_grad()
240 | d_loss.backward()
241 | self.d_optimizer.step()
242 |
243 | # Logging.
244 | loss = {}
245 | loss['D/loss_real'] = d_loss_real.item()
246 | loss['D/loss_fake'] = d_loss_fake.item()
247 | loss['D/loss_dsrec'] = d_loss_dsrec.item()
248 | loss['D/loss_gp'] = d_loss_gp.item()
249 |
250 | # =================================================================================== #
251 | # 3. Train the encoder and decoder #
252 | # =================================================================================== #
253 |
254 | if (i + 1) % self.n_critic == 0:
255 | # Original-to-target domain.
256 | x_di = self.encoder(x_real)
257 |
258 | x_fake = self.decoder(x_di, trg_ds)
259 | x_reconst1 = self.decoder(x_di, x_ds)
260 | out_src, out_cls = self.D(x_fake)
261 | g_loss_fake = - torch.mean(out_src)
262 | g_loss_dsrec = torch.mean(
263 | torch.abs(trg_ds.detach() - out_cls))
264 |
265 | # Target-to-original domain.
266 | x_fake_di = self.encoder(x_fake)
267 |
268 | x_reconst2 = self.decoder(x_fake_di, x_ds)
269 |
270 | g_loss_rec = torch.mean(torch.abs(x_real - x_reconst1))
271 |
272 | g_loss_rec2 = torch.mean(torch.abs(x_real - x_reconst2))
273 |
274 | # Backward and optimize.
275 | g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_rec2 * g_loss_rec2 + self.lambda_fs * g_loss_dsrec
276 | self.reset_grad()
277 | g_loss.backward()
278 | self.g_optimizer.step()
279 |
280 | # Logging.
281 | loss['G/loss_fake'] = g_loss_fake.item()
282 | loss['G/loss_rec'] = g_loss_rec.item()
283 | loss['G/loss_rec2'] = g_loss_rec2.item()
284 | loss['G/loss_dsrec'] = g_loss_dsrec.item()
285 |
286 | # =================================================================================== #
287 | # 4. Miscellaneous #
288 | # =================================================================================== #
289 |
290 | # Print out training information.
291 | if (i + 1) % self.log_step == 0:
292 | et = time.time() - start_time
293 | et = str(datetime.timedelta(seconds=et))[:-7]
294 | log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)
295 | for tag, value in loss.items():
296 | log += ", {}: {:.4f}".format(tag, value)
297 | print(log)
298 |
299 |
300 | # Translate fixed images for debugging.
301 | if (i) % self.sample_step == 0:
302 | with torch.no_grad():
303 | out_A2B_results = [empty]
304 |
305 | for idx1 in range(label_org.size(0)):
306 |     out_A2B_results.append(x_real[idx1:idx1 + 1])  # top row: real target images
307 |
308 | # Each row: a source image, then its translation toward every domain centroid.
309 | for idx2 in range(label_org.size(0)):
310 |     out_A2B_results.append(x_real[idx2:idx2 + 1])
311 |     for idx1 in range(label_org.size(0)):
312 |         x_fake = self.decoder(self.encoder(x_real[idx2:idx2 + 1]), x_ds_mean[idx1:idx1 + 1])
313 |         out_A2B_results.append(x_fake)
314 | results_concat = torch.cat(out_A2B_results)
315 | x_AB_results_path = os.path.join(self.sample_dir, '{}_x_AB_results.jpg'.format(i + 1))
316 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0) + 1,
317 | padding=0)
318 | print('Saved real and fake images into {}...'.format(x_AB_results_path))
319 |
320 |
321 | # Save model checkpoints.
322 | if (i + 1) % self.model_save_step == 0:
323 | encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(i + 1))
324 | decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(i + 1))
325 | D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i + 1))
326 | torch.save(self.encoder.state_dict(), encoder_path)
327 | torch.save(self.decoder.state_dict(), decoder_path)
328 | torch.save(self.D.state_dict(), D_path)
329 | print('Saved model checkpoints into {}...'.format(self.model_save_dir))
330 |
331 | # Decay learning rates.
332 | if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay):
333 | g_lr -= (self.g_lr / float(self.num_iters_decay))
334 | d_lr -= (self.d_lr / float(self.num_iters_decay))
335 | self.update_lr(g_lr, d_lr)
336 | print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
337 | def train_conditional(self):
338 | # Load pre-trained classification network
339 | cls_iter = 160000
340 | C_path = os.path.join(self.cls_save_dir, '{}-C.ckpt'.format(cls_iter))
341 | self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage))
342 |
343 | # Set data loader.
344 | data_loader = self.data_loader
345 |
346 | # Set learning rate
347 | g_lr = self.g_lr
348 | d_lr = self.d_lr
349 |
350 | # Start training from scratch or resume training.
351 | start_iters = 0
352 | if self.resume_iters:
353 | start_iters = self.resume_iters
354 | self.restore_model(self.resume_iters)
355 |
356 |
357 | empty = torch.FloatTensor(1, 3, self.image_size, self.image_size).to(self.device)
358 | empty.fill_(1)
359 |
360 |
361 | start_time = time.time()
362 | # Start training.
363 | for i in range(start_iters, self.num_iters):
364 |
365 | # Fetch real images and labels.
366 | try:
367 | x_real, label_org = next(data_iter)
368 | except:
369 | data_iter = iter(data_loader)
370 | x_real, label_org = next(data_iter)
371 |
372 | x_real = x_real.to(self.device)
373 | label_org = label_org.to(self.device)
374 |
375 | x_ds, x_cls = self.C(x_real) # obtain domain feature for each real image
376 |
377 | # random target
378 | rand_idx = torch.randperm(label_org.size(0))
379 |
380 |
381 | trg_dst = x_ds[rand_idx]
382 | trg_ds = trg_dst.clone()
383 |
384 |
385 |
386 |
387 | # =================================================================================== #
388 | # 2. Train the discriminator #
389 | # =================================================================================== #
390 |
391 | # Compute loss with real images.
392 | out_src, out_cls = self.D(x_real)
393 |
394 | d_loss_real = - torch.mean(out_src)
395 |
396 | d_loss_dsrec = torch.mean(torch.abs(x_ds.detach() - out_cls))
397 |
398 | # Compute loss with fake images.
399 | x_fake = self.decoder(self.encoder(x_real), trg_ds)
400 |
401 | out_src, out_cls = self.D(x_fake.detach())
402 |
403 |
404 | d_loss_fake = torch.mean(out_src)
405 |
406 | # Compute loss for gradient penalty.
407 | alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
408 | x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
409 | out_src, _ = self.D(x_hat)
410 | d_loss_gp = self.gradient_penalty(out_src, x_hat)
411 |
412 | # Backward and optimize.
413 | d_loss = d_loss_real + d_loss_fake + self.lambda_fs * d_loss_dsrec + self.lambda_gp * d_loss_gp
414 | self.reset_grad()
415 | d_loss.backward()
416 | self.d_optimizer.step()
417 |
418 | # Logging.
419 | loss = {}
420 | loss['D/loss_real'] = d_loss_real.item()
421 | loss['D/loss_fake'] = d_loss_fake.item()
422 | loss['D/loss_dsrec'] = d_loss_dsrec.item()
423 | loss['D/loss_gp'] = d_loss_gp.item()
424 |
425 | # =================================================================================== #
426 | # 3. Train the encoder and decoder #
427 | # =================================================================================== #
428 |
429 | if (i + 1) % self.n_critic == 0:
430 | # Original-to-target domain.
431 | x_di = self.encoder(x_real)
432 |
433 | x_fake = self.decoder(x_di, trg_ds)
434 | x_reconst1 = self.decoder(x_di, x_ds)
435 |
436 |
437 | out_src, out_cls = self.D(x_fake)
438 |
439 | g_loss_fake = - torch.mean(out_src)
440 | g_loss_dsrec = torch.mean(torch.abs(trg_ds.detach() - out_cls))
441 |
442 | # Target-to-original domain.
443 | x_fake_di = self.encoder(x_fake)
444 |
445 | x_reconst2 = self.decoder(x_fake_di, x_ds)
446 |
447 | g_loss_rec = torch.mean(torch.abs(x_real - x_reconst1))
448 |
449 | g_loss_rec2 = torch.mean(torch.abs(x_real - x_reconst2))
450 |
451 | # Backward and optimize.
452 | g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_rec2 * g_loss_rec2 + self.lambda_fs * g_loss_dsrec
453 | self.reset_grad()
454 | g_loss.backward()
455 | self.g_optimizer.step()
456 |
457 | # Logging.
458 | loss['G/loss_fake'] = g_loss_fake.item()
459 | loss['G/loss_rec'] = g_loss_rec.item()
460 | loss['G/loss_rec2'] = g_loss_rec2.item()
461 | loss['G/loss_dsrec'] = g_loss_dsrec.item()
462 |
463 | # =================================================================================== #
464 | # 4. Miscellaneous #
465 | # =================================================================================== #
466 |
467 | # Print out training information.
468 | if (i + 1) % self.log_step == 0:
469 | et = time.time() - start_time
470 | et = str(datetime.timedelta(seconds=et))[:-7]
471 | log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)
472 | for tag, value in loss.items():
473 | log += ", {}: {:.4f}".format(tag, value)
474 | print(log)
475 |
476 |
477 | # Translate fixed images for debugging.
478 | if (i) % self.sample_step == 0:
479 | with torch.no_grad():
480 | out_A2B_results = [empty]
481 |
482 | for idx1 in range(label_org.size(0)):
483 |     out_A2B_results.append(x_real[idx1:idx1 + 1])  # top row: real reference images
484 |
485 | # Each row: a source image, then its translation conditioned on each reference's domain feature.
486 | for idx2 in range(label_org.size(0)):
487 |     out_A2B_results.append(x_real[idx2:idx2 + 1])
488 |     for idx1 in range(label_org.size(0)):
489 |         x_fake = self.decoder(self.encoder(x_real[idx2:idx2 + 1]), x_ds[idx1:idx1 + 1])
490 |         out_A2B_results.append(x_fake)
491 | results_concat = torch.cat(out_A2B_results)
492 | x_AB_results_path = os.path.join(self.sample_dir, '{}_x_AB_results.jpg'.format(i + 1))
493 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0) + 1,
494 | padding=0)
495 | print('Saved real and fake images into {}...'.format(x_AB_results_path))
496 |
497 |
498 | # Save model checkpoints.
499 | if (i + 1) % self.model_save_step == 0:
500 | encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(i + 1))
501 | decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(i + 1))
502 | D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i + 1))
503 | torch.save(self.encoder.state_dict(), encoder_path)
504 | torch.save(self.decoder.state_dict(), decoder_path)
505 | torch.save(self.D.state_dict(), D_path)
506 | print('Saved model checkpoints into {}...'.format(self.model_save_dir))
507 |
508 | # Decay learning rates.
509 | if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay):
510 | g_lr -= (self.g_lr / float(self.num_iters_decay))
511 | d_lr -= (self.d_lr / float(self.num_iters_decay))
512 | self.update_lr(g_lr, d_lr)
513 | print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
514 | def cls(self):
515 | """Train a domain classifier"""
516 | # Set data loader.
517 | data_loader = self.data_loader
518 |
519 |
520 | # Start training from scratch or resume training.
521 | start_iters = 0
522 |
523 | # Start training.
524 | start_time = time.time()
525 |
526 | for i in range(start_iters, self.num_iters):
527 |
528 | try:
529 | x_real, label_org = next(data_iter)
530 |
531 | except:
532 | data_iter = iter(data_loader)
533 | x_real, label_org = next(data_iter)
534 |
535 | x_real = x_real.to(self.device) # Input images.
536 | label_org = label_org.to(self.device) # Labels for computing classification loss.
537 |
538 |
539 | # =================================================================================== #
540 | # Train the classifier #
541 | # =================================================================================== #
542 |
543 | out_src, out_cls = self.C(x_real)
544 | d_loss_cls = self.classification_loss(out_cls, label_org)
545 |
546 | # Backward and optimize.
547 | d_loss = d_loss_cls
548 | self.c_optimizer.zero_grad()
549 | d_loss.backward()
550 | self.c_optimizer.step()
551 |
552 |
553 | # Logging.
554 | loss = {}
555 |
556 | loss['D/loss_cls'] = d_loss_cls.item()
557 |
558 |
559 |
560 | # Print out training information.
561 | if (i+1) % self.log_step == 0:
562 | et = time.time() - start_time
563 | et = str(datetime.timedelta(seconds=et))[:-7]
564 | prec1, prec5 = accuracy(out_cls.data, label_org.data, topk=(1, 5))
565 | loss['prec1'] = prec1.item()
566 | loss['prec5'] = prec5.item()
567 | log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
568 | for tag, value in loss.items():
569 | log += ", {}: {:.4f}".format(tag, value)
570 | print(log)
571 |
572 |
573 |
574 |
575 | # Save model checkpoints.
576 | if (i+1) % self.model_save_step == 0:
577 | C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(i+1))
578 | torch.save(self.C.state_dict(), C_path)
579 | print('Saved model checkpoints into {}...'.format(self.model_save_dir))
580 |
581 |
582 | def label2onehot(self, labels, dim):
583 | """Convert label indices to one-hot vectors."""
584 | batch_size = labels.size(0)
585 | out = torch.zeros(batch_size, dim)
586 | out[np.arange(batch_size), labels.long()] = 1
587 | return out
588 | def create_labels(self, c_org, c_dim=5):
589 |     """Generate a one-hot target label for every domain (for testing)."""
590 |     c_trg_list = []
591 |     for i in range(c_dim):
592 |         c_trg = c_org.clone()
593 |         c_trg[:, i] = 1          # set the target domain to 1
594 |         for j in range(c_dim):
595 |             if j != i:
596 |                 c_trg[:, j] = 0  # and every other domain to 0
597 |         c_trg_list.append(c_trg.to(self.device))
598 |     return c_trg_list
606 | def test(self):
607 | """Translate images with trained DosGAN."""
608 | # Load the trained networks.
609 | cls_iter = 160000
610 | C_path = os.path.join(self.cls_save_dir, '{}-C.ckpt'.format(cls_iter))
611 | self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage))
612 | self.restore_model(self.test_iters)
613 |
614 | # Set data loader.
615 | data_loader = self.data_loader
616 | data_loader_test = self.data_loader_test
617 | step = 0
618 | empty = torch.FloatTensor(1, 3,self.image_size,self.image_size).to(self.device)
619 | empty.fill_(1)
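# Per-domain feature centroids, recomputed over the training set with the
# same incremental mean used in train().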
620 | domain_sf_num = torch.FloatTensor(self.c_dim, 1).to(self.device)
621 | domain_sf_num.fill_(0.00000001)
622 | domain_sf = torch.FloatTensor(self.c_dim, self.ft_num).to(self.device)
623 | domain_sf.fill_(0)
624 | with torch.no_grad():
625 | if self.non_conditional: # non_conditional testing
626 | for indx, (x_real, label_org) in enumerate(data_loader):
627 | x_real = x_real.to(self.device) # Input images.
628 | label_org = label_org.to(self.device)
629 |
630 | x_ds, x_cls = self.C(x_real)
631 | for j in range(label_org.size(0)):
632 | domain_sf[label_org[j],:] = (domain_sf[label_org[j],:] + x_ds[j]/domain_sf_num[label_org[j],:])*(domain_sf_num[label_org[j],:]/(domain_sf_num[label_org[j],:]+1))
633 | domain_sf_num[label_org[j],:] += 1
634 | step = step +1
635 |
636 | for indx, (x_real, label_org) in enumerate(data_loader_test):
637 | x_real = x_real.to(self.device) # Input images.
638 |
639 | x_ds, x_cls = self.C(x_real)
640 | c_org = self.label2onehot(label_org, self.c_dim)
641 |
642 |
643 | c_org = c_org.to(self.device)
644 | label_org = label_org.to(self.device)
645 |
646 | c_fixed_list = self.create_labels(c_org, self.c_dim)
647 |
648 | x_fake_list = [x_real]
649 | for c_fixed in c_fixed_list:
650 | _, out_pred_fixed = torch.max(c_fixed.data, 1)
651 | x_ds_m = x_ds.clone()
652 | for k in range(label_org.size(0)):
653 | x_ds_m[k,:] = domain_sf[out_pred_fixed[k],:]
654 | x_fake = self.decoder(self.encoder(x_real), x_ds_m)
655 | x_fake_list.append(x_fake)
656 |
657 | x_concat = torch.cat(x_fake_list, dim=3)
658 | sample_path = os.path.join(self.result_dir, '{}-images.jpg'.format(indx+1))
659 | save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
660 | print('Saved real and fake images into {}...'.format(sample_path))
661 | else: # conditional image translation testing
662 | for indx, (x_real, label_org) in enumerate(data_loader_test):
663 | x_real = x_real.to(self.device) # Input images.
664 | label_org = label_org.to(self.device)
665 | x_ds, x_cls = self.C(x_real)
666 |
667 |
668 | out_A2B_results = [empty]
669 |
670 | for j in range(label_org.size(0)):
671 |     out_A2B_results.append(x_real[j:j+1])  # top row: real reference images
672 |
673 | # Each row: a source image, then its translation toward each reference's domain feature.
674 | for i in range(label_org.size(0)):
675 |     out_A2B_results.append(x_real[i:i+1])
676 |     for j in range(label_org.size(0)):
677 |         x_fake = self.decoder(self.encoder(x_real[i:i+1]), x_ds[j:j+1])
678 |         out_A2B_results.append(x_fake)
679 | results_concat = torch.cat(out_A2B_results)
680 | x_AB_results_path = os.path.join(self.result_dir, '{}_x_AB_results.jpg'.format(indx+1))
681 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0)+1,padding=0)
682 | print('Saved real and fake images into {}...'.format(x_AB_results_path))
683 |
684 |
--------------------------------------------------------------------------------
/split2train_val.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | import sys
3 | import os
4 | import shutil
5 |
6 | def split2train_val(root_dir, train_dir, val_dir):
7 | try:
8 | os.mkdir(train_dir)
9 | os.mkdir(val_dir)
10 |     except OSError:
11 |         print("train_dir and val_dir already exist!")
12 |
13 | person_names = os.listdir(root_dir)
14 | index = 0
15 |
16 | size_list = list()
17 |
18 | for person_name in person_names:
19 |
20 | os.mkdir(os.path.join(train_dir, person_name))
21 | os.mkdir(os.path.join(val_dir, person_name))
22 |
23 | index += 1
24 | img_names = os.listdir(os.path.join(root_dir, person_name))
25 | n_face = len(img_names)
26 | n_test = int(n_face/10)
27 | n_train = n_face - n_test
28 |
29 | print (len(person_names), str(index), person_name, str(n_train), str(n_test))
30 |
31 | for i in range(len(img_names)):
32 | img_name = img_names[i]
33 | source_path = os.path.join(root_dir, person_name, img_name)
34 | if i < n_train:
35 | target_path = os.path.join(train_dir, person_name, img_name)
36 | else:
37 | target_path = os.path.join(val_dir, person_name, img_name)
38 |
39 | shutil.copy(source_path, target_path)
40 |
41 |
42 | if __name__ == "__main__":
43 | root_dir = sys.argv[1]
44 | train_dir = sys.argv[2]
45 | val_dir = sys.argv[3]
46 | split2train_val(root_dir, train_dir, val_dir)
47 |
--------------------------------------------------------------------------------