├── DiagramsDomainIntersection.png
├── README.md
├── eval.py
├── models
│   └── models.py
├── preprocess.py
├── train
│   └── train.py
└── utils.py

/DiagramsDomainIntersection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sagiebenaim/DomainIntersectionDifference/13a492d72bbeb1471b158a6488505d50ed718b49/DiagramsDomainIntersection.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Domain Intersection and Domain Difference ([arXiv](https://arxiv.org/abs/1908.11628))
2 | 
3 | PyTorch implementation of "Domain Intersection and Domain Difference" (ICCV 2019).
4 | 
5 | ## Prerequisites
6 | - Python 2.7 / 3.6
7 | - PyTorch 0.4
8 | - [requests](http://docs.python-requests.org/en/master/)
9 | - [argparse](https://docs.python.org/2/howto/argparse.html)
10 | - [Pillow](https://pillow.readthedocs.io/en/5.3.x/)
11 | 
12 | ### Download and Prepare the Data
13 | Download the CelebA dataset. Create a `celeba` directory and place the `img_align_celeba` folder and `list_attr_celeba.txt` inside.
14 | 
15 | You can use the provided script ```preprocess.py``` to split CelebA into two domains, A and B, based on the attribute configuration of your choosing.
16 | For example, you can run the script using the following command:
17 | ```
18 | python preprocess.py --root ./celeba/img_align_celeba --attributes ./celeba/list_attr_celeba.txt --dest ./smile_glasses
19 | ```
20 | You can also use your own custom dataset, as long as it adheres to the following format:
21 | ```
22 | root/
23 |     trainA/
24 |     trainB/
25 |     testA/
26 |     testB/
27 | ```
28 | You can then run the preprocessing in the following manner:
29 | ```
30 | python preprocess.py --root ./custom_dataset --dest ./custom_train --folders --config smile_glasses
31 | ```
32 | 
33 | ### To Train
34 | Run ```train.py```. You can use the following example to run:
35 | ```
36 | python train.py --root ./smile_glasses --out ./smile_glasses_out
37 | ```
38 | 
39 | ### To Resume Training
40 | Run ```train.py``` with the ```--load``` flag. You can use the following example to run:
41 | ```
42 | python train.py --root ./smile_glasses --out ./smile_glasses_out --load ./smile_glasses
43 | ```
44 | 
45 | ### To Evaluate
46 | Run ```eval.py```. Note that ```eval.py``` translates a fixed list of five test images, so ```--num_display``` should not exceed 5. You can use the following example to run:
47 | ```
48 | python eval.py --root ./smile_glasses --out ./smile_glasses_eval --sep 25 --num_display 5
49 | ```
50 | 
51 | ## Figure
52 | Figure from the paper describing the method:
53 | 
54 | ![Domain intersection and domain difference diagram](DiagramsDomainIntersection.png)
55 | 
56 | 
57 | ## Reference
58 | If you found this code useful, please cite the following paper:
59 | ```
60 | @inproceedings{Benaim2019DomainIntersectionDifference,
61 |   title={Domain Intersection and Domain Difference},
62 |   author={Sagie Benaim and Michael Khaitov and Tomer Galanti and Lior Wolf},
63 |   booktitle={ICCV},
64 |   year={2019}
65 | }
66 | ```
67 | 
68 | ## Acknowledgements
69 | 
70 | The implementation is based on the architecture of [Content Disentanglement](https://github.com/oripress/ContentDisentanglement).
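
### Programmatic Use
The trained encoders and decoder can also be driven directly from Python. Below is a minimal, illustrative sketch, not part of the released scripts: the checkpoint path, the two image filenames and the output name are placeholders, and the defaults `--resize 128 --sep 25 --crop 178` are assumed. It transplants the A-specific code of one image onto the common code of another, exactly as `utils.save_chosen_imgs` does for B -> A translation:
```python
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

from models.models import E_common, E_separate_A, E_separate_B, Decoder
from utils import load_model_for_eval

sep, resize = 25, 128
size = resize // 64  # spatial size of the latent feature maps

# Build the networks and load a checkpoint produced by train.py.
e_common = E_common(sep, size)
e_sep_A = E_separate_A(sep, size)
e_sep_B = E_separate_B(sep, size)
decoder = Decoder(size)
load_model_for_eval('./smile_glasses_out/checkpoint',
                    e_common, e_sep_A, e_sep_B, decoder)
for m in (e_common, e_sep_A, e_sep_B, decoder):
    m.eval()

# Same preprocessing as utils.get_test_dataset.
tf = transforms.Compose([
    transforms.CenterCrop(178),
    transforms.Resize(resize),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
img_A = tf(Image.open('face_from_A.jpg').convert('RGB')).unsqueeze(0)
img_B = tf(Image.open('face_from_B.jpg').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    common_B = e_common(img_B)    # content shared by both domains
    separate_A = e_sep_A(img_A)   # A-specific part to transplant
    zero = torch.zeros(1, sep * size * size)  # empty B-specific slot
    out = decoder(torch.cat([common_B, separate_A, zero], dim=1))

save_image(out, 'translated.png', normalize=True)
```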
71 | 
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import torch
5 | 
6 | from models.models import E_common, E_separate_A, E_separate_B, Decoder
7 | from utils import load_model_for_eval, save_chosen_imgs
8 | 
9 | 
10 | def eval(args):
11 |     e_common = E_common(args.sep, args.resize // 64)
12 |     e_separate_A = E_separate_A(args.sep, args.resize // 64)
13 |     e_separate_B = E_separate_B(args.sep, args.resize // 64)
14 |     decoder = Decoder(args.resize // 64)
15 | 
16 |     if torch.cuda.is_available():
17 |         e_common = e_common.cuda()
18 |         e_separate_A = e_separate_A.cuda()
19 |         e_separate_B = e_separate_B.cuda()
20 |         decoder = decoder.cuda()
21 | 
22 |     _iter = 0
23 |     if args.load != '':
24 |         save_file = os.path.join(args.load, 'checkpoint')
25 |         _iter = load_model_for_eval(save_file, e_common, e_separate_A, e_separate_B, decoder)
26 | 
27 |     e_common = e_common.eval()
28 |     e_separate_A = e_separate_A.eval()
29 |     e_separate_B = e_separate_B.eval()
30 |     decoder = decoder.eval()
31 | 
32 |     if not os.path.exists(args.out) and args.out != "":
33 |         os.mkdir(args.out)
34 | 
35 |     save_chosen_imgs(args, e_common, e_separate_A, e_separate_B, decoder, _iter, [0, 1, 2, 3, 4], [0, 1, 2, 3, 4],
36 |                      False)
37 | 
38 | 
39 | if __name__ == '__main__':
40 |     parser = argparse.ArgumentParser()
41 |     parser.add_argument('--root', default='')
42 |     parser.add_argument('--load', default='')
43 |     parser.add_argument('--out', default='')
44 |     parser.add_argument('--resize', type=int, default=128)
45 |     parser.add_argument('--crop', type=int, default=178)
46 |     parser.add_argument('--sep', type=int, default=25)
47 |     parser.add_argument('--bs', type=int, default=64)
48 |     parser.add_argument('--num_display', type=int, default=5)
49 | 
50 |     args = parser.parse_args()
51 | 
52 |     eval(args)
53 | 
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 | from torch.nn import Parameter
7 | 
8 | 
9 | class E_common(nn.Module):
10 |     def __init__(self, sep, size, dim=512):
11 |         super(E_common, self).__init__()
12 |         self.sep = sep
13 |         self.size = size
14 |         self.dim = dim
15 |         self.layer1 = []
16 |         self.layer2 = []
17 |         self.layer3 = []
18 |         self.layer4 = []
19 |         self.layer5 = []
20 |         self.layer6 = []
21 |         self.z_dim_size = (self.dim - 2 * self.sep) * self.size * self.size
22 | 
23 |         self.layer1.append(SpectralNorm(nn.Conv2d(3, 32, 4, 2, 1)))
24 |         self.layer1.append(nn.InstanceNorm2d(32))
25 |         self.layer1.append(nn.LeakyReLU(0.2, inplace=True))
26 |         self.l1 = nn.Sequential(*self.layer1)
27 | 
28 |         self.layer2.append(SpectralNorm(nn.Conv2d(32, 64, 4, 2, 1)))
29 |         self.layer2.append(nn.InstanceNorm2d(64))
30 |         self.layer2.append(nn.LeakyReLU(0.2, inplace=True))
31 |         self.l2 = nn.Sequential(*self.layer2)
32 | 
33 |         self.layer3.append(SpectralNorm(nn.Conv2d(64, 128, 4, 2, 1)))
34 |         self.layer3.append(nn.InstanceNorm2d(128))
35 |         self.layer3.append(nn.LeakyReLU(0.2, inplace=True))
36 |         self.l3 = nn.Sequential(*self.layer3)
37 | 
38 |         self.layer4.append(SpectralNorm(nn.Conv2d(128, 256, 4, 2, 1)))
39 |         self.layer4.append(nn.InstanceNorm2d(256))
40 |         self.layer4.append(nn.LeakyReLU(0.2, inplace=True))
41 |         self.l4 = nn.Sequential(*self.layer4)
42 | 
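        # Layers 5 and 6 below shave off `sep` channels each, so the flattened
        # common code has (dim - 2 * sep) * size * size entries; the decoder
        # later concatenates it with the two sep-channel separate codes to
        # rebuild a full dim-channel latent.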
43 |         self.layer5.append(SpectralNorm(nn.Conv2d(256, (512 - self.sep), 4, 2, 1)))
44 |         self.layer5.append(nn.InstanceNorm2d(512 - self.sep))
45 |         self.layer5.append(nn.LeakyReLU(0.2, inplace=True))
46 |         self.l5 = nn.Sequential(*self.layer5)
47 | 
48 |         self.layer6.append(SpectralNorm(nn.Conv2d((512 - self.sep), (512 - 2 * self.sep), 4, 2, 1)))
49 |         self.layer6.append(nn.InstanceNorm2d(512 - 2 * self.sep))
50 |         self.layer6.append(nn.LeakyReLU(0.2, inplace=True))
51 |         self.l6 = nn.Sequential(*self.layer6)
52 | 
53 |     def forward(self, net):
54 |         out = self.l1(net)
55 |         out = self.l2(out)
56 |         out = self.l3(out)
57 |         out = self.l4(out)
58 |         out = self.l5(out)
59 |         out = self.l6(out)
60 |         out = out.view(-1, self.z_dim_size)
61 | 
62 |         return out
63 | 
64 | 
65 | class E_separate_A(nn.Module):
66 |     def __init__(self, sep, size):
67 |         super(E_separate_A, self).__init__()
68 |         self.sep = sep
69 |         self.size = size
70 |         self.layer1 = []
71 |         self.layer2 = []
72 |         self.layer3 = []
73 |         self.layer4 = []
74 |         self.layer5 = []
75 |         self.layer6 = []
76 | 
77 |         self.layer1.append(SpectralNorm(nn.Conv2d(3, 32, 4, 2, 1)))
78 |         self.layer1.append(nn.InstanceNorm2d(32))
79 |         self.layer1.append(nn.LeakyReLU(0.2, inplace=True))
80 |         self.l1 = nn.Sequential(*self.layer1)
81 | 
82 |         self.layer2.append(SpectralNorm(nn.Conv2d(32, 64, 4, 2, 1)))
83 |         self.layer2.append(nn.InstanceNorm2d(64))
84 |         self.layer2.append(nn.LeakyReLU(0.2, inplace=True))
85 |         self.l2 = nn.Sequential(*self.layer2)
86 | 
87 |         self.layer3.append(SpectralNorm(nn.Conv2d(64, 128, 4, 2, 1)))
88 |         self.layer3.append(nn.InstanceNorm2d(128))
89 |         self.layer3.append(nn.LeakyReLU(0.2, inplace=True))
90 |         self.l3 = nn.Sequential(*self.layer3)
91 | 
92 |         self.layer4.append(SpectralNorm(nn.Conv2d(128, 256, 4, 2, 1)))
93 |         self.layer4.append(nn.InstanceNorm2d(256))
94 |         self.layer4.append(nn.LeakyReLU(0.2, inplace=True))
95 |         self.l4 = nn.Sequential(*self.layer4)
96 | 
97 |         self.layer5.append(SpectralNorm(nn.Conv2d(256, 512, 4, 2, 1)))
98 |         self.layer5.append(nn.InstanceNorm2d(512))
99 |         self.layer5.append(nn.LeakyReLU(0.2, inplace=True))
100 |         self.l5 = nn.Sequential(*self.layer5)
101 | 
102 |         self.layer6.append(SpectralNorm(nn.Conv2d(512, self.sep, 4, 2, 1)))
103 |         self.layer6.append(nn.InstanceNorm2d(self.sep))
104 |         self.layer6.append(nn.LeakyReLU(0.2, inplace=True))
105 |         self.l6 = nn.Sequential(*self.layer6)
106 | 
107 |     def forward(self, net):
108 |         out = self.l1(net)
109 |         out = self.l2(out)
110 |         out = self.l3(out)
111 |         out = self.l4(out)
112 |         out = self.l5(out)
113 |         out = self.l6(out)
114 |         out = out.view(-1, self.sep * self.size * self.size)
115 |         return out
116 | 
117 | 
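# Shape check for the defaults (resize=128, so size=2, and sep=25): a
# 3 x 128 x 128 input is halved by six stride-2 convolutions down to
# sep x 2 x 2, then flattened to a vector of sep * size * size = 100 entries.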
118 | class E_separate_B(nn.Module):
119 |     def __init__(self, sep, size):
120 |         super(E_separate_B, self).__init__()
121 |         self.sep = sep
122 |         self.size = size
123 |         self.layer1 = []
124 |         self.layer2 = []
125 |         self.layer3 = []
126 |         self.layer4 = []
127 |         self.layer5 = []
128 |         self.layer6 = []
129 | 
130 |         self.layer1.append(SpectralNorm(nn.Conv2d(3, 32, 4, 2, 1)))
131 |         self.layer1.append(nn.InstanceNorm2d(32))
132 |         self.layer1.append(nn.LeakyReLU(0.2, inplace=True))
133 |         self.l1 = nn.Sequential(*self.layer1)
134 | 
135 |         self.layer2.append(SpectralNorm(nn.Conv2d(32, 64, 4, 2, 1)))
136 |         self.layer2.append(nn.InstanceNorm2d(64))
137 |         self.layer2.append(nn.LeakyReLU(0.2, inplace=True))
138 |         self.l2 = nn.Sequential(*self.layer2)
139 | 
140 |         self.layer3.append(SpectralNorm(nn.Conv2d(64, 128, 4, 2, 1)))
141 |         self.layer3.append(nn.InstanceNorm2d(128))
142 |         self.layer3.append(nn.LeakyReLU(0.2, inplace=True))
143 |         self.l3 = nn.Sequential(*self.layer3)
144 | 
145 |         self.layer4.append(SpectralNorm(nn.Conv2d(128, 256, 4, 2, 1)))
146 |         self.layer4.append(nn.InstanceNorm2d(256))
147 |         self.layer4.append(nn.LeakyReLU(0.2, inplace=True))
148 |         self.l4 = nn.Sequential(*self.layer4)
149 | 
150 |         self.layer5.append(SpectralNorm(nn.Conv2d(256, 512, 4, 2, 1)))
151 |         self.layer5.append(nn.InstanceNorm2d(512))
152 |         self.layer5.append(nn.LeakyReLU(0.2, inplace=True))
153 |         self.l5 = nn.Sequential(*self.layer5)
154 | 
155 |         self.layer6.append(SpectralNorm(nn.Conv2d(512, self.sep, 4, 2, 1)))
156 |         self.layer6.append(nn.InstanceNorm2d(self.sep))
157 |         self.layer6.append(nn.LeakyReLU(0.2, inplace=True))
158 |         self.l6 = nn.Sequential(*self.layer6)
159 | 
160 |     def forward(self, net):
161 |         out = self.l1(net)
162 |         out = self.l2(out)
163 |         out = self.l3(out)
164 |         out = self.l4(out)
165 |         out = self.l5(out)
166 |         out = self.l6(out)
167 |         out = out.view(-1, self.sep * self.size * self.size)
168 |         return out
169 | 
170 | 
171 | class Decoder(nn.Module):
172 |     def __init__(self, size, dim=512):
173 |         super(Decoder, self).__init__()
174 |         self.size = size
175 |         self.dim = dim
176 | 
177 |         self.layer1 = []
178 |         self.layer2 = []
179 |         self.layer3 = []
180 |         self.layer4 = []
181 |         self.layer5 = []
182 |         self.layer6 = []
183 | 
184 |         self.layer1.append(SpectralNorm(nn.ConvTranspose2d(512, 512, 4, 2, 1)))
185 |         self.layer1.append(nn.InstanceNorm2d(512))
186 |         self.layer1.append(nn.ReLU(inplace=True))
187 |         self.l1 = nn.Sequential(*self.layer1)
188 | 
189 |         self.layer2.append(SpectralNorm(nn.ConvTranspose2d(512, 256, 4, 2, 1)))
190 |         self.layer2.append(nn.InstanceNorm2d(256))
191 |         self.layer2.append(nn.ReLU(inplace=True))
192 |         self.l2 = nn.Sequential(*self.layer2)
193 | 
194 |         self.layer3.append(SpectralNorm(nn.ConvTranspose2d(256, 128, 4, 2, 1)))
195 |         self.layer3.append(nn.InstanceNorm2d(128))
196 |         self.layer3.append(nn.ReLU(inplace=True))
197 |         self.l3 = nn.Sequential(*self.layer3)
198 | 
199 |         self.layer4.append(SpectralNorm(nn.ConvTranspose2d(128, 64, 4, 2, 1)))
200 |         self.layer4.append(nn.InstanceNorm2d(64))
201 |         self.layer4.append(nn.ReLU(inplace=True))
202 |         self.l4 = nn.Sequential(*self.layer4)
203 | 
204 |         self.layer5.append(SpectralNorm(nn.ConvTranspose2d(64, 32, 4, 2, 1)))
205 |         self.layer5.append(nn.InstanceNorm2d(32))
206 |         self.layer5.append(nn.ReLU(inplace=True))
207 |         self.l5 = nn.Sequential(*self.layer5)
208 | 
209 |         self.layer6.append(nn.ConvTranspose2d(32, 3, 4, 2, 1))
210 |         self.layer6.append(nn.Tanh())
211 |         self.l6 = nn.Sequential(*self.layer6)
212 | 
213 |     def forward(self, net):
214 |         net = net.view(-1, self.dim, self.size, self.size)
215 |         out = self.l1(net)
216 |         out = self.l2(out)
217 |         out = self.l3(out)
218 |         out = self.l4(out)
219 |         out = self.l5(out)
220 |         out = self.l6(out)
221 |         return out
222 | 
223 | 
224 | class Disc(nn.Module):
225 |     def __init__(self, sep, size, dim=512):
226 |         super(Disc, self).__init__()
227 |         self.sep = sep
228 |         self.size = size
229 |         self.dim = dim
230 | 
231 |         self.classify = nn.Sequential(
232 |             nn.Linear((dim - 2 * self.sep) * self.size * self.size, dim),
233 |             nn.LeakyReLU(0.2, inplace=True),
234 |             nn.Linear(dim, 1),
235 |             nn.Sigmoid()
236 |         )
237 | 
238 |     def forward(self, net):
239 |         # net = net.view(-1, (512 - 2 * self.sep) * self.size * self.size)
240 |         net = net.view(-1, (self.dim - 2 * self.sep) * self.size * self.size)
241 |         net = self.classify(net)
242 |         net = net.view(-1)
243 |         return net
244 | 
245 | 
246 | def l2normalize(v, eps=1e-12):
247 |     return v / (v.norm() + eps)
248 | 
249 | 
250 | class SpectralNorm(nn.Module):
251 |     def __init__(self, module, name='weight', power_iterations=1):
252 |         super(SpectralNorm, self).__init__()
253 |         self.module = module
254 |         self.name = name
255 |         self.power_iterations = power_iterations
256 |         if not self._made_params():
257 |             self._make_params()
258 | 
259 |     def _update_u_v(self):
260 |         u = getattr(self.module, self.name + "_u")
261 |         v = getattr(self.module, self.name + "_v")
262 |         w = getattr(self.module, self.name + "_bar")
263 | 
264 |         height = w.data.shape[0]
265 |         for _ in range(self.power_iterations):
266 |             v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
267 |             u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
268 | 
269 |         # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
270 |         sigma = u.dot(w.view(height, -1).mv(v))
271 |         setattr(self.module, self.name, w / sigma.expand_as(w))
272 | 
273 |     def _made_params(self):
274 |         try:
275 |             u = getattr(self.module, self.name + "_u")
276 |             v = getattr(self.module, self.name + "_v")
277 |             w = getattr(self.module, self.name + "_bar")
278 |             return True
279 |         except AttributeError:
280 |             return False
281 | 
282 |     def _make_params(self):
283 |         w = getattr(self.module, self.name)
284 | 
285 |         height = w.data.shape[0]
286 |         width = w.view(height, -1).data.shape[1]
287 | 
288 |         u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
289 |         v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
290 |         u.data = l2normalize(u.data)
291 |         v.data = l2normalize(v.data)
292 |         w_bar = Parameter(w.data)
293 | 
294 |         del self.module._parameters[self.name]
295 | 
296 |         self.module.register_parameter(self.name + "_u", u)
297 |         self.module.register_parameter(self.name + "_v", v)
298 |         self.module.register_parameter(self.name + "_bar", w_bar)
299 | 
300 |     def forward(self, *args):
301 |         self._update_u_v()
302 |         return self.module.forward(*args)
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | 
5 | #######
6 | # CelebA attributes:
7 | # ------
8 | # 5_o_Clock_Shadow      1
9 | # Arched_Eyebrows       2
10 | # Attractive            3
11 | # Bags_Under_Eyes       4
12 | # Bald                  5
13 | # Bangs                 6
14 | # Big_Lips              7
15 | # Big_Nose              8
16 | # Black_Hair            9
17 | # Blond_Hair            10
18 | # Blurry                11
19 | # Brown_Hair            12
20 | # Bushy_Eyebrows        13
21 | # Chubby                14
22 | # Double_Chin           15
23 | # Eyeglasses            16
24 | # Goatee                17
25 | # Gray_Hair             18
26 | # Heavy_Makeup          19
27 | # High_Cheekbones       20
28 | # Male                  21
29 | # Mouth_Slightly_Open   22
30 | # Mustache              23
31 | # Narrow_Eyes           24
32 | # No_Beard              25
33 | # Oval_Face             26
34 | # Pale_Skin             27
35 | # Pointy_Nose           28
36 | # Receding_Hairline     29
37 | # Rosy_Cheeks           30
38 | # Sideburns             31
39 | # Smiling               32
40 | # Straight_Hair         33
41 | # Wavy_Hair             34
42 | # Wearing_Earrings      35
43 | # Wearing_Hat           36
44 | # Wearing_Lipstick      37
45 | # Wearing_Necklace      38
46 | # Wearing_Necktie       39
47 | # Young                 40
48 | #######
49 | 
50 | def preprocess_celeba(args):
51 |     if not os.path.exists(args.dest):
52 |         os.mkdir(args.dest)
53 | 
54 |     allA = []
55 |     allB = []
56 | 
57 |     with open(args.attributes) as f:
58 |         lines = f.readlines()
59 | 
60 |     if args.config == 'beard_glasses':
61 |         for line in lines[2:]:
62 |             line = line.split()
63 |             if male_no_5_oclock(line) and beard(line) and (not glasses(line)):
64 |                 allA.append(line[0])
65 |             elif male_no_5_oclock(line) and (not beard(line)) and glasses(line):
66 |                 allB.append(line[0])
67 | 
68 |     if args.config == 'beard_smile':
69 |         for line in lines[2:]:
70 |             line = line.split()
71 |             if male_no_5_oclock(line) and beard(line) and (not smile(line)):
72 |                 allA.append(line[0])
73 |             elif male_no_5_oclock(line) and (not beard(line)) and smile(line):
74 |                 allB.append(line[0])
75 | 
76 |     if args.config == "smile_glasses":
77 |         for line in lines[2:]:
78 |             line = line.split()
79 |             if smile(line) and (not glasses(line)):
80 |                 allA.append(line[0])
81 |             elif (not smile(line)) and glasses(line):
82 |                 allB.append(line[0])
83 | 
84 |     if args.config == "male_female":
85 |         for line in lines[2:]:
86 |             line = line.split()
87 |             if int(line[21]) == 1:
88 |                 allA.append(line[0])
89 |             else:
90 |                 allB.append(line[0])
91 | 
92 |     if args.config == "blond_black":
93 |         for line in lines[2:]:
94 |             line = line.split()
95 |             if blonde_hair(line) and (not hat(line)):
96 |                 allA.append(line[0])
97 |             elif black_hair(line) and (not hat(line)):
98 |                 allB.append(line[0])
99 | 
100 |     testA = allA[:args.num_test_imgs]
101 |     testB = allB[:args.num_test_imgs]
102 |     trainA = allA[args.num_test_imgs:]
103 |     trainB = allB[args.num_test_imgs:]
104 | 
105 |     with open(os.path.join(args.dest, 'testA.txt'), 'w') as f:
106 |         for i, _img in enumerate(testA):
107 |             if i == len(testA) - 1:
108 |                 f.write("%s" % os.path.join(args.root, _img))
109 |             else:
110 |                 f.write("%s\n" % os.path.join(args.root, _img))
111 | 
112 |     with open(os.path.join(args.dest, 'testB.txt'), 'w') as f:
113 |         for i, _img in enumerate(testB):
114 |             if i == len(testB) - 1:
115 |                 f.write("%s" % os.path.join(args.root, _img))
116 |             else:
117 |                 f.write("%s\n" % os.path.join(args.root, _img))
118 | 
119 |     with open(os.path.join(args.dest, 'trainA.txt'), 'w') as f:
120 |         for i, _img in enumerate(trainA):
121 |             if i == len(trainA) - 1:
122 |                 f.write("%s" % os.path.join(args.root, _img))
123 |             else:
124 |                 f.write("%s\n" % os.path.join(args.root, _img))
125 | 
126 |     with open(os.path.join(args.dest, 'trainB.txt'), 'w') as f:
127 |         for i, _img in enumerate(trainB):
128 |             if i == len(trainB) - 1:
129 |                 f.write("%s" % os.path.join(args.root, _img))
130 |             else:
131 |                 f.write("%s\n" % os.path.join(args.root, _img))
132 | 
133 | 
134 | def male_no_5_oclock(line):
135 |     return int(line[21]) == 1 and int(line[1]) == -1
136 | 
137 | 
138 | def beard(line):
139 |     return int(line[23]) == 1 or int(line[17]) == 1 or int(line[25]) == -1
140 | 
141 | 
142 | def glasses(line):
143 |     return int(line[16]) == 1
144 | 
145 | 
146 | def smile(line):
147 |     return int(line[32]) == 1
148 | 
149 | 
150 | def blonde_hair(line):
151 |     return int(line[10]) == 1
152 | 
153 | 
154 | def black_hair(line):
155 |     return int(line[9]) == 1
156 | 
157 | 
158 | def hat(line):
159 |     return int(line[36]) == 1
160 | 
161 | 
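# For reference, list_attr_celeba.txt (which the predicates above index into)
# looks roughly like this -- first line: image count, second line: the 40
# attribute names, then one "+1"/"-1" row per image:
#
#     202599
#     5_o_Clock_Shadow Arched_Eyebrows ... Young
#     000001.jpg -1 1 1 -1 ... -1 1
#
# so after line.split(), line[0] is the filename and line[k] is attribute k
# in the table at the top of this file.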
162 | def preprocess_folders(args):
163 |     if not os.path.exists(args.dest):
164 |         os.mkdir(args.dest)
165 | 
166 |     trainA = os.listdir(os.path.join(args.root, 'trainA'))
167 |     trainB = os.listdir(os.path.join(args.root, 'trainB'))
168 |     testA = os.listdir(os.path.join(args.root, 'testA'))
169 |     testB = os.listdir(os.path.join(args.root, 'testB'))
170 | 
171 |     with open(os.path.join(args.dest, 'testA.txt'), 'w') as f:
172 |         for i, _img in enumerate(testA):
173 |             if i == len(testA) - 1:
174 |                 f.write("%s" % os.path.join(args.root, 'testA', _img))
175 |             else:
176 |                 f.write("%s\n" % os.path.join(args.root, 'testA', _img))
177 | 
178 |     with open(os.path.join(args.dest, 'testB.txt'), 'w') as f:
179 |         for i, _img in enumerate(testB):
180 |             if i == len(testB) - 1:
181 |                 f.write("%s" % os.path.join(args.root, 'testB', _img))
182 |             else:
183 |                 f.write("%s\n" % os.path.join(args.root, 'testB', _img))
184 | 
185 |     with open(os.path.join(args.dest, 'trainA.txt'), 'w') as f:
186 |         for i, _img in enumerate(trainA):
187 |             if i == len(trainA) - 1:
188 |                 f.write("%s" % os.path.join(args.root, 'trainA', _img))
189 |             else:
190 |                 f.write("%s\n" % os.path.join(args.root, 'trainA', _img))
191 | 
192 |     with open(os.path.join(args.dest, 'trainB.txt'), 'w') as f:
193 |         for i, _img in enumerate(trainB):
194 |             if i == len(trainB) - 1:
195 |                 f.write("%s" % os.path.join(args.root, 'trainB', _img))
196 |             else:
197 |                 f.write("%s\n" % os.path.join(args.root, 'trainB', _img))
198 | 
199 | 
200 | if __name__ == "__main__":
201 |     parser = argparse.ArgumentParser()
202 |     parser.add_argument("--root", default="",
203 |                         help="path to the celeba folder, or if you're using another dataset this should be the path to the root")
204 |     parser.add_argument("--dest", default="", help="path to the destination folder")
205 |     parser.add_argument("--attributes", default="", help="path to the attributes file")
206 |     parser.add_argument("--num_test_imgs", type=int, default=64, help="number of images in the test set")
207 |     parser.add_argument("--config", default="smile_glasses",
208 |                         help="configs available: beard_glasses, beard_smile, smile_glasses, male_female, blond_black")
209 |     parser.add_argument("--custom", default=32, help="use a custom celeba attribute")
210 |     parser.add_argument("--folders", action="store_true",
211 |                         help="use custom folders, instead of celeba")
212 | 
213 |     args = parser.parse_args()
214 | 
215 |     if not args.folders:
216 |         preprocess_celeba(args)
217 |     else:
218 |         preprocess_folders(args)
219 | 
--------------------------------------------------------------------------------
/train/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | from torch import optim
5 | from torch.autograd import Variable
6 | from torch.utils.data import DataLoader
7 | 
8 | from models.models import *
9 | from utils import Logger
10 | from utils import get_train_dataset
11 | from utils import save_imgs, save_model, load_model, save_stripped_imgs
12 | 
13 | 
14 | def train(args):
15 |     if not os.path.exists(args.out):
16 |         os.makedirs(args.out)
17 | 
18 |     _iter = 0
19 |     domA_train, domB_train = get_train_dataset(args)
20 | 
21 |     size = args.resize // 64
22 |     dim = 512
23 | 
24 |     e_common = E_common(args.sep, size, dim=dim)
25 |     e_separate_A = E_separate_A(args.sep, size)
26 |     e_separate_B = E_separate_B(args.sep, size)
27 |     decoder = Decoder(size, dim=dim)
28 |     disc = Disc(args.sep, size, dim=dim)
29 | 
30 |     A_label = torch.full((args.bs,), 1)
31 |     B_label = torch.full((args.bs,), 0)
32 |     zero_encoding = torch.full((args.bs, args.sep * size * size), 0)
33 |     one_encoding = torch.full((args.bs, args.sep * size * size), 1)
34 | 
35 |     l1 = nn.L1Loss()
36 |     bce = nn.BCELoss()
37 | 
38 |     if torch.cuda.is_available():
39 |         e_common = e_common.cuda()
40 |         e_separate_A = e_separate_A.cuda()
41 |         e_separate_B = e_separate_B.cuda()
42 |         decoder = decoder.cuda()
43 |         disc = disc.cuda()
44 | 
45 |         A_label = A_label.cuda()
46 |         B_label = B_label.cuda()
47 |         zero_encoding = zero_encoding.cuda()
48 |         one_encoding = one_encoding.cuda()
49 | 
50 |         l1 = l1.cuda()
51 |         bce = bce.cuda()
52 | 
53 |     ae_params = list(e_common.parameters()) + list(e_separate_A.parameters()) + list(
54 |         e_separate_B.parameters()) + list(decoder.parameters())
55 |     ae_optimizer = optim.Adam(ae_params, lr=args.lr, betas=(0.5, 0.999))
56 | 
57 |     disc_params = list(disc.parameters())
58 |     disc_optimizer = optim.Adam(disc_params, lr=args.disclr, betas=(0.5, 0.999))
59 | 
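    # One Adam instance updates the three encoders and the decoder jointly,
    # while the discriminator keeps its own optimizer so the two can take
    # alternating steps in the loop below.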
60 |     if args.load != '':
61 |         save_file = os.path.join(args.load, 'checkpoint')
62 |         _iter = load_model(save_file, e_common, e_separate_A, e_separate_B, decoder, ae_optimizer, disc,
63 |                            disc_optimizer)
64 | 
65 |     e_common = e_common.train()
66 |     e_separate_A = e_separate_A.train()
67 |     e_separate_B = e_separate_B.train()
68 |     decoder = decoder.train()
69 |     disc = disc.train()
70 | 
71 |     logger = Logger(args.out)
72 | 
73 |     print('Started training...')
74 |     while True:
75 |         domA_loader = torch.utils.data.DataLoader(domA_train, batch_size=args.bs,
76 |                                                   shuffle=True, num_workers=6)
77 |         domB_loader = torch.utils.data.DataLoader(domB_train, batch_size=args.bs,
78 |                                                   shuffle=True, num_workers=6)
79 |         if _iter >= args.iters:
80 |             break
81 | 
82 |         for domA_img, domB_img in zip(domA_loader, domB_loader):
83 | 
84 |             if domA_img.size(0) != args.bs or domB_img.size(0) != args.bs:
85 |                 break
86 | 
87 |             domA_img = Variable(domA_img)
88 |             domB_img = Variable(domB_img)
89 | 
90 |             if torch.cuda.is_available():
91 |                 domA_img = domA_img.cuda()
92 |                 domB_img = domB_img.cuda()
93 | 
94 |             domA_img = domA_img.view((-1, 3, args.resize, args.resize))
95 |             domB_img = domB_img.view((-1, 3, args.resize, args.resize))
96 | 
97 |             ae_optimizer.zero_grad()
98 | 
99 |             A_common = e_common(domA_img)
100 |             A_separate_A = e_separate_A(domA_img)
101 |             A_separate_B = e_separate_B(domA_img)
102 |             if args.no_flag:
103 |                 A_encoding = torch.cat([A_common, A_separate_A, A_separate_A], dim=1)
104 |             else:
105 |                 A_encoding = torch.cat([A_common, A_separate_A, zero_encoding], dim=1)
106 |             B_common = e_common(domB_img)
107 |             B_separate_A = e_separate_A(domB_img)
108 |             B_separate_B = e_separate_B(domB_img)
109 | 
110 |             if args.one_encoding:
111 |                 B_encoding = torch.cat([B_common, B_separate_B, one_encoding], dim=1)
112 |             elif args.no_flag:
113 |                 B_encoding = torch.cat([B_common, B_separate_B, B_separate_B], dim=1)
114 |             else:
115 |                 B_encoding = torch.cat([B_common, zero_encoding, B_separate_B], dim=1)
116 | 
117 |             A_decoding = decoder(A_encoding)
118 |             B_decoding = decoder(B_encoding)
119 | 
120 |             A_reconstruction_loss = l1(A_decoding, domA_img)
121 |             B_reconstruction_loss = l1(B_decoding, domB_img)
122 | 
123 |             A_separate_B_loss = l1(A_separate_B, zero_encoding)
124 |             B_separate_A_loss = l1(B_separate_A, zero_encoding)
125 | 
126 |             logger.add_value('A_recon', A_reconstruction_loss)
127 |             logger.add_value('B_recon', B_reconstruction_loss)
128 |             logger.add_value('A_sep_B', A_separate_B_loss)
129 |             logger.add_value('B_sep_A', B_separate_A_loss)
130 | 
131 |             loss = 0
132 | 
133 |             if args.reconweight > 0:
134 |                 loss += args.reconweight * (A_reconstruction_loss + B_reconstruction_loss)
135 | 
136 |             if args.zeroweight > 0:
137 |                 loss += args.zeroweight * (A_separate_B_loss + B_separate_A_loss)
138 | 
139 |             if args.discweight > 0:
140 |                 preds_A = disc(A_common)
141 |                 preds_B = disc(B_common)
142 |                 distribution_adversarial_loss = args.discweight * \
143 |                                                 (bce(preds_A, B_label) + bce(preds_B, B_label))
144 |                 logger.add_value('distribution_adversarial', distribution_adversarial_loss)
145 |                 loss += distribution_adversarial_loss
146 | 
147 |             loss.backward()
148 |             torch.nn.utils.clip_grad_norm_(ae_params, 5)
149 |             ae_optimizer.step()
150 | 
151 |             if args.discweight > 0:
152 |                 disc_optimizer.zero_grad()
153 | 
154 |                 A_common = e_common(domA_img)
155 |                 B_common = e_common(domB_img)
156 | 
157 |                 disc_A = disc(A_common)
158 |                 disc_B = disc(B_common)
159 | 
160 |                 loss = bce(disc_A, A_label) + bce(disc_B, B_label)
161 |                 logger.add_value('dist_disc', loss)
162 |                 loss.backward()
163 |                 torch.nn.utils.clip_grad_norm_(disc_params, 5)
164 |                 disc_optimizer.step()
165 | 
166 |             if _iter % args.progress_iter == 0:
167 |                 print('Outfile: %s <<>> Iteration %d' % (args.out, _iter))
168 | 
169 |             if _iter % args.log_iter == 0:
170 |                 logger.log(_iter)
171 | 
172 |             logger.reset()
173 | 
174 |             if _iter % args.display_iter == 0:
175 |                 e_common = e_common.eval()
176 |                 e_separate_A = e_separate_A.eval()
177 |                 e_separate_B = e_separate_B.eval()
178 |                 decoder = decoder.eval()
179 | 
180 |                 save_imgs(args, e_common, e_separate_A, e_separate_B, decoder, _iter, size=size, BtoA=True)
181 |                 save_imgs(args, e_common, e_separate_A, e_separate_B, decoder, _iter, size=size, BtoA=False)
182 |                 save_stripped_imgs(args, e_common, e_separate_A, e_separate_B, decoder, _iter, size=size, A=True)
183 |                 save_stripped_imgs(args, e_common, e_separate_A, e_separate_B, decoder, _iter, size=size, A=False)
184 | 
185 |                 e_common = e_common.train()
186 |                 e_separate_A = e_separate_A.train()
187 |                 e_separate_B = e_separate_B.train()
188 |                 decoder = decoder.train()
189 | 
190 |             if _iter % args.save_iter == 0:
191 |                 save_file = os.path.join(args.out, 'checkpoint')
192 |                 save_model(save_file, e_common, e_separate_A, e_separate_B, decoder, ae_optimizer, disc,
193 |                            disc_optimizer, _iter)
194 | 
195 |             _iter += 1
196 | 
197 | 
198 | if __name__ == '__main__':
199 |     parser = argparse.ArgumentParser()
200 |     parser.add_argument('--root', default='')
201 |     parser.add_argument('--out', default='out')
202 |     parser.add_argument('--lr', type=float, default=0.0002)
203 |     parser.add_argument('--bs', type=int, default=32)
204 |     parser.add_argument('--iters', type=int, default=1250000)
205 |     parser.add_argument('--resize', type=int, default=128)
206 |     parser.add_argument('--crop', type=int, default=178)
207 |     parser.add_argument('--sep', type=int, default=25)
208 |     parser.add_argument('--disclr', type=float, default=0.0002)
209 |     parser.add_argument('--progress_iter', type=int, default=100)
210 |     parser.add_argument('--display_iter', type=int, default=1000)
211 |     parser.add_argument('--log_iter', type=int, default=100)
212 |     parser.add_argument('--save_iter', type=int, default=10000)
213 |     parser.add_argument('--load', default='')
214 |     parser.add_argument('--zeroweight', type=float, default=1.0)
215 |     parser.add_argument('--reconweight', type=float, default=1.0)
216 |     parser.add_argument('--discweight', type=float, default=0.001)
217 |     parser.add_argument('--num_display', type=int, default=12)
218 |     parser.add_argument('--one_encoding', type=int, default=0)
219 |     parser.add_argument('--no_flag', type=int, default=0)
220 | 
221 |     args = parser.parse_args()
222 | 
223 |     train(args)
224 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | 
4 | import torch
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 | import torchvision.utils as vutils
8 | from PIL import Image
9 | 
10 | 
11 | def get_test_dataset(args, crop=None, resize=None):
12 |     if crop is None:
13 |         crop = args.crop
14 | 
15 |     if resize is None:
16 |         resize = args.resize
17 | 
18 |     comp_transform = transforms.Compose([
19 |         transforms.CenterCrop(crop),
20 |         transforms.Resize(resize),
21 |         transforms.ToTensor(),
22 |         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
23 |     ])
24 | 
25 |     domA_test = CustomDataset(os.path.join(args.root, 'testA.txt'), transform=comp_transform)
26 |     domB_test = CustomDataset(os.path.join(args.root, 'testB.txt'), transform=comp_transform)
27 | 
28 |     return domA_test, domB_test
29 | 
30 | 
31 | def get_train_dataset(args, crop=None, resize=None):
32 |     if crop is None:
33 |         crop = args.crop
34 | 
35 |     if resize is None:
36 |         resize = args.resize
37 | 
38 |     comp_transform = transforms.Compose([
39 |         transforms.CenterCrop(crop),
40 |         transforms.Resize(resize),
41 |         transforms.RandomHorizontalFlip(),
42 |         transforms.ToTensor(),
43 |         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
44 |     ])
45 | 
46 |     domA_train = CustomDataset(os.path.join(args.root, 'trainA.txt'), transform=comp_transform)
47 |     domB_train = CustomDataset(os.path.join(args.root, 'trainB.txt'), transform=comp_transform)
48 | 
49 |     return domA_train, domB_train
50 | 
51 | 
52 | def save_imgs(args, e_common, e_separate_A, e_separate_B, decoder, iters, size, BtoA=True, num_offsets=1):
53 |     '''saves images of the translation B -> A or A -> B'''
54 |     test_domA, test_domB = get_test_imgs(args)
55 | 
56 |     for k in range(num_offsets):
57 |         exps = []
58 |         for i in range(k * args.num_display, (k + 1) * args.num_display):
59 |             with torch.no_grad():
60 |                 if i == k * args.num_display:
61 |                     filler = test_domB[i].unsqueeze(0).clone()
62 |                     exps.append(filler.fill_(0))
63 | 
64 |                 if BtoA:
65 |                     exps.append(test_domB[i].unsqueeze(0))
66 |                 else:
67 |                     exps.append(test_domA[i].unsqueeze(0))
68 | 
69 |         if BtoA:
70 |             for i in range(k * args.num_display, (k + 1) * args.num_display):
71 |                 exps.append(test_domA[i].unsqueeze(0))
72 |                 separate_A = e_separate_A(test_domA[i].unsqueeze(0))
73 |                 for j in range(k * args.num_display, (k + 1) * args.num_display):
74 |                     with torch.no_grad():
75 |                         common_B = e_common(test_domB[j].unsqueeze(0))
76 |                         zero_encoding = torch.full((1, args.sep * size * size), 0)
77 |                         if torch.cuda.is_available():
78 |                             zero_encoding = zero_encoding.cuda()
79 | 
80 |                         if args.no_flag:
81 |                             BA_encoding = torch.cat([common_B, separate_A, separate_A], dim=1)
82 |                             BA_decoding = decoder(BA_encoding)
83 |                             exps.append(BA_decoding)
84 |                         else:
85 |                             BA_encoding = torch.cat([common_B, separate_A, zero_encoding], dim=1)
86 |                             BA_decoding = decoder(BA_encoding)
87 |                             exps.append(BA_decoding)
88 |         else:
89 |             for i in range(k * args.num_display, (k + 1) * args.num_display):
90 |                 exps.append(test_domB[i].unsqueeze(0))
91 |                 separate_B = e_separate_B(test_domB[i].unsqueeze(0))
92 |                 for j in range(k * args.num_display, (k + 1) * args.num_display):
93 |                     with torch.no_grad():
94 |                         common_A = e_common(test_domA[j].unsqueeze(0))
95 |                         zero_encoding = torch.full((1, args.sep * size * size), 0)
96 |                         one_encoding = torch.full((1, args.sep * size * size), 1)
97 |                         if torch.cuda.is_available():
98 |                             zero_encoding = zero_encoding.cuda()
99 |                             one_encoding = one_encoding.cuda()
100 | 
101 |                         if args.one_encoding:
102 |                             AB_encoding = torch.cat(
103 |                                 [common_A, separate_B, one_encoding], dim=1)
104 |                         elif args.no_flag:
105 |                             AB_encoding = torch.cat(
106 |                                 [common_A, separate_B, separate_B], dim=1)
107 |                         else:
108 |                             AB_encoding = torch.cat(
109 |                                 [common_A, zero_encoding, separate_B], dim=1)
110 | 
111 |                         AB_decoding = decoder(AB_encoding)
112 |                         exps.append(AB_decoding)
113 | 
114 |         with torch.no_grad():
115 |             exps = torch.cat(exps, 0)
116 | 
117 |         if BtoA:
118 |             vutils.save_image(exps,
119 |                               '%s/experiments_%06d_%d-BtoA.png' % (args.out,
120 |                                                                    iters,
121 |                                                                    k),
122 |                               normalize=True, nrow=args.num_display + 1)
123 |         else:
124 |             vutils.save_image(exps,
125 |                               '%s/experiments_%06d_%d-AtoB.png' % (args.out,
126 |                                                                    iters,
127 |                                                                    k),
128 |                               normalize=True, nrow=args.num_display + 1)
129 | 
130 | 
131 | def get_test_imgs(args, crop=None, resize=None):
132 |     domA_test, domB_test = get_test_dataset(args, crop=crop, resize=resize)
133 | 
134 |     domA_test_loader = torch.utils.data.DataLoader(domA_test, batch_size=64,
135 |                                                    shuffle=False, num_workers=6)
136 |     domB_test_loader = torch.utils.data.DataLoader(domB_test, batch_size=64,
137 |                                                    shuffle=False, num_workers=6)
138 | 
139 |     for domA_img in domA_test_loader:
140 |         if torch.cuda.is_available():
141 |             domA_img = domA_img.cuda()
142 |         domA_img = domA_img[:]
143 |         break
144 | 
145 |     for domB_img in domB_test_loader:
146 |         if torch.cuda.is_available():
147 |             domB_img = domB_img.cuda()
148 |         domB_img = domB_img[:]
149 |         break
150 | 
151 |     return domA_img, domB_img
152 | 
153 | 
154 | def save_model(out_file, e_common, e_separate_A, e_separate_B, decoder, ae_opt, disc, disc_opt, iters):
155 |     state = {
156 |         'e_common': e_common.state_dict(),
157 |         'e_separate_A': e_separate_A.state_dict(),
158 |         'e_separate_B': e_separate_B.state_dict(),
159 |         'decoder': decoder.state_dict(),
160 |         'ae_opt': ae_opt.state_dict(),
161 |         'disc': disc.state_dict(),
162 |         'disc_opt': disc_opt.state_dict(),
163 |         'iters': iters
164 |     }
165 |     torch.save(state, out_file)
166 |     return
167 | 
168 | 
169 | def load_model(load_path, e_common, e_separate_A, e_separate_B, decoder, ae_opt, disc, disc_opt):
170 |     state = torch.load(load_path)
171 |     e_common.load_state_dict(state['e_common'])
172 |     e_separate_A.load_state_dict(state['e_separate_A'])
173 |     e_separate_B.load_state_dict(state['e_separate_B'])
174 |     decoder.load_state_dict(state['decoder'])
175 |     ae_opt.load_state_dict(state['ae_opt'])
176 |     disc.load_state_dict(state['disc'])
177 |     disc_opt.load_state_dict(state['disc_opt'])
178 |     return state['iters']
179 | 
180 | 
181 | def load_model_for_eval(load_path, e_common, e_separate_A, e_separate_B, decoder):
182 |     state = torch.load(load_path)
183 |     e_common.load_state_dict(state['e_common'])
184 |     e_separate_A.load_state_dict(state['e_separate_A'])
185 |     e_separate_B.load_state_dict(state['e_separate_B'])
186 |     decoder.load_state_dict(state['decoder'])
187 |     return state['iters']
188 | 
189 | 
190 | IMG_EXTENSIONS = [
191 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
192 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
193 | ]
194 | 
195 | 
196 | def edges_loader(path, train=True):
197 |     # A|B pairs are stored side by side in a 512x256 image.
198 |     image = Image.open(path).convert('RGB')
199 |     image_A = image.crop((0, 0, 256, 256))
200 |     image_B = image.crop((256, 0, 512, 256))
201 | 
202 |     if train:
203 |         return image_A
204 |     else:
205 |         return image_B
206 | 
207 | 
208 | def default_loader(path):
209 |     return Image.open(path).convert('RGB')
210 | 
211 | 
212 | class Logger():
213 |     def __init__(self, path):
214 |         self.full_path = '%s/log.txt' % path
215 |         self.log_file = open(self.full_path, 'w+')
216 |         self.log_file.close()
217 |         self.map = {}
218 | 
219 |     def add_value(self, tag, value):
220 |         self.map[tag] = value
221 | 
222 |     def log(self, iter):
223 |         self.log_file = open(self.full_path, 'a')
224 |         self.log_file.write('iter: %7d' % iter)
225 |         for k, v in self.map.items():
226 |             self.log_file.write('\t %s: %10.7f' % (k, v))
227 |         self.log_file.write('\n')
228 |         self.log_file.close()
229 | 
230 |     def reset(self):
231 |         self.map = {}
232 | 
233 | 
234 | class CustomDataset(data.Dataset):
235 |     def __init__(self, path, transform=None, return_paths=False,
236 |                  loader=default_loader):
237 |         super(CustomDataset, self).__init__()
238 | 
239 |         with open(path) as f:
240 |             imgs = [s.replace('\n', '') for s in f.readlines()]
241 | 
242 |         if len(imgs) == 0:
243 |             raise (RuntimeError("Found 0 images in: " + path + "\n"
244 |                                 "Supported image extensions are: " +
245 |                                 ",".join(IMG_EXTENSIONS)))
246 | 
247 |         self.imgs = imgs
248 |         self.transform = transform
249 |         self.return_paths = return_paths
250 |         self.loader = loader
251 | 
252 |     def __getitem__(self, index):
253 |         path = self.imgs[index]
254 |         img = self.loader(path)
255 |         if self.transform is not None:
256 |             img = self.transform(img)
257 | 
258 |         if self.return_paths:
259 |             return img, path
260 |         else:
261 |             return img
262 | 
263 |     def __len__(self):
264 |         return len(self.imgs)
265 | 
266 | 
267 | def default_flist_reader(flist):
268 |     """
269 |     flist format: impath label\nimpath label\n ... (same as caffe's filelist)
270 |     """
271 |     imlist = []
272 |     with open(flist, 'r') as rf:
273 |         for line in rf.readlines():
274 |             impath = line.strip()
275 |             imlist.append(impath)
276 | 
277 |     return imlist
278 | 
279 | 
280 | def is_image_file(filename):
281 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
282 | 
283 | 
284 | def save_stripped_imgs(args, e_common, e_separate_A, e_separate_B, decoder, iters, size, A=True):
285 |     test_domA, test_domB = get_test_imgs(args)
286 |     exps = []
287 |     zero_encoding = torch.full((1, args.sep * size * size), 0)
288 |     one_encoding = torch.full((1, args.sep * size * size), 1)
289 |     if torch.cuda.is_available():
290 |         zero_encoding = zero_encoding.cuda()
291 |         one_encoding = one_encoding.cuda()
292 | 
293 |     for i in range(args.num_display):
294 |         if A:
295 |             image = test_domA[i]
296 |         else:
297 |             image = test_domB[i]
298 |         exps.append(image.unsqueeze(0))
299 |         common = e_common(image.unsqueeze(0))
300 |         content_zero_encoding = torch.full(common.size(), 0)
301 |         if torch.cuda.is_available():
302 |             content_zero_encoding = content_zero_encoding.cuda()
303 |         separate_A = e_separate_A(image.unsqueeze(0))
304 |         separate_B = e_separate_B(image.unsqueeze(0))
305 | 
306 |         if args.one_encoding:
307 |             exps.append(decoder(torch.cat([content_zero_encoding, separate_A, zero_encoding], dim=1)))
308 |             exps.append(decoder(torch.cat([content_zero_encoding, separate_B, one_encoding], dim=1)))
309 |         elif args.no_flag:
310 |             exps.append(decoder(torch.cat([content_zero_encoding, separate_A, separate_A], dim=1)))
311 |             exps.append(decoder(torch.cat([content_zero_encoding, separate_B, separate_B], dim=1)))
312 |         else:
313 |             exps.append(decoder(torch.cat([common, zero_encoding, zero_encoding], dim=1)))
314 |             exps.append(decoder(torch.cat([content_zero_encoding, separate_A, zero_encoding], dim=1)))
315 |             exps.append(decoder(torch.cat([content_zero_encoding, zero_encoding, separate_B], dim=1)))
316 | 
317 |     with torch.no_grad():
318 |         exps = torch.cat(exps, 0)
319 | 
320 |     if A:
321 |         vutils.save_image(exps,
322 |                           '%s/experiments_%06d-Astripped.png' % (args.out, iters),
323 |                           normalize=True, nrow=args.num_display)
324 |     else:
325 |         vutils.save_image(exps,
326 |                           '%s/experiments_%06d-Bstripped.png' % (args.out, iters),
327 |                           normalize=True, nrow=args.num_display)
328 | 
329 | 
330 | def save_chosen_imgs(args, e_common, e_separate_A, e_separate_B, decoder, iters, listA, listB, BtoA=True):
331 |     '''saves images of the translation B -> A or A -> B'''
332 |     test_domA, test_domB = get_test_imgs(args)
333 | 
334 |     exps = []
335 |     for i in range(args.num_display):
336 |         with torch.no_grad():
337 |             if i == 0:
338 |                 filler = test_domB[i].unsqueeze(0).clone()
339 |                 exps.append(filler.fill_(0))
340 | 
341 |             if BtoA:
342 |                 exps.append(test_domB[listB[i]].unsqueeze(0))
343 |             else:
344 |                 exps.append(test_domA[listA[i]].unsqueeze(0))
345 | 
346 |     if BtoA:
347 |         for i in listA:
348 |             exps.append(test_domA[i].unsqueeze(0))
349 |             separate_A = e_separate_A(test_domA[i].unsqueeze(0))
350 |             for j in listB:
351 |                 with torch.no_grad():
352 |                     common_B = e_common(test_domB[j].unsqueeze(0))
353 |                     zero_encoding = torch.full((1, args.sep * (args.resize
354 |                                                                // 64) * (args.resize // 64)), 0)
355 |                     if torch.cuda.is_available():
356 |                         zero_encoding = zero_encoding.cuda()
357 | 
358 |                     BA_encoding = torch.cat([common_B, separate_A, zero_encoding], dim=1)
359 |                     BA_decoding = decoder(BA_encoding)
360 |                     exps.append(BA_decoding)
361 |     else:
362 |         for i in listB:
363 |             exps.append(test_domB[i].unsqueeze(0))
364 |             separate_B = e_separate_B(test_domB[i].unsqueeze(0))
365 |             for j in listA:
366 |                 with torch.no_grad():
367 |                     common_A = e_common(test_domA[j].unsqueeze(0))
368 |                     zero_encoding = torch.full((1, args.sep * (args.resize
369 |                                                                // 64) * (args.resize // 64)), 0)
370 |                     if torch.cuda.is_available():
371 |                         zero_encoding = zero_encoding.cuda()
372 | 
373 |                     AB_encoding = torch.cat(
374 |                         [common_A, zero_encoding, separate_B], dim=1)
375 |                     AB_decoding = decoder(AB_encoding)
376 |                     exps.append(AB_decoding)
377 | 
378 |     with torch.no_grad():
379 |         exps = torch.cat(exps, 0)
380 | 
381 |     if BtoA:
382 |         vutils.save_image(exps,
383 |                           '%s/experiments_%06d-BtoA.png' % (args.out, iters),
384 |                           normalize=True, nrow=args.num_display + 1)
385 |     else:
386 |         vutils.save_image(exps,
387 |                           '%s/experiments_%06d-AtoB.png' % (args.out, iters),
388 |                           normalize=True, nrow=args.num_display + 1)
389 | 
390 | 
391 | def interpolate_fixed_common(args, e_common, e_separate_A, e_separate_B, decoder, imgA1, imgA2, imgB1,
392 |                              imgB2, content_img):
393 |     test_domA, test_domB = get_test_imgs(args)
394 |     exps = []
395 |     common = e_common(test_domB[content_img].unsqueeze(0))
396 |     a1 = e_separate_A(test_domA[imgA1].unsqueeze(0))
397 |     a2 = e_separate_A(test_domA[imgA2].unsqueeze(0))
398 |     b1 = e_separate_B(test_domB[imgB1].unsqueeze(0))
399 |     b2 = e_separate_B(test_domB[imgB2].unsqueeze(0))
400 |     with torch.no_grad():
401 |         filler = test_domB[0].unsqueeze(0).clone()
402 |         exps.append(filler.fill_(0))
403 |         exps.append(test_domA[imgA1].unsqueeze(0))
404 |         for i in range(args.num_display - 2):
405 |             exps.append(filler.fill_(0))
406 |         exps.append(test_domA[imgA2].unsqueeze(0))
407 | 
408 |         for i in range(args.num_display):
409 |             if i == 0:
410 |                 exps.append(test_domB[imgB1].unsqueeze(0))
411 |             elif i == args.num_display - 1:
412 |                 exps.append(test_domB[imgB2].unsqueeze(0))
413 |             else:
414 |                 exps.append(filler.fill_(0))
415 | 
416 |             for j in range(args.num_display):
417 |                 cur_sep_A = (float(j) / (args.num_display - 1)) * a2 + \
418 |                             (1 - float(j) / (args.num_display - 1)) * a1
419 |                 cur_sep_B = (float(i) / (args.num_display - 1)) * b2 + \
420 |                             (1 - float(i) / (args.num_display - 1)) * b1
421 |                 encoding = torch.cat([common, cur_sep_A, cur_sep_B], dim=1)
422 |                 decoding = decoder(encoding)
423 |                 exps.append(decoding)
424 | 
425 |     with torch.no_grad():
426 |         exps = torch.cat(exps, 0)
427 | 
428 |     vutils.save_image(exps,
429 |                       '%s/interpolation_fixed_C.png' % (args.out),
430 |                       normalize=True, nrow=args.num_display + 1)
431 | 
432 | 
433 | def interpolate_fixed_A(args, e_common, e_separate_A, e_separate_B, decoder, imgC1, imgC2, imgB1,
434 |                         imgB2, imgA):
435 |     test_domA, test_domB = get_test_imgs(args)
436 |     exps = []
437 |     c1 = e_common(test_domB[imgC1].unsqueeze(0))
438 |     c2 = e_common(test_domB[imgC2].unsqueeze(0))
439 |     a = e_separate_A(test_domA[imgA].unsqueeze(0))
440 |     b1 = e_separate_B(test_domB[imgB1].unsqueeze(0))
441 |     b2 = e_separate_B(test_domB[imgB2].unsqueeze(0))
442 |     with torch.no_grad():
443 |         filler = test_domB[0].unsqueeze(0).clone()
444 |         exps.append(filler.fill_(0))
445 |         exps.append(test_domB[imgC1].unsqueeze(0))
446 |         for i in range(args.num_display - 2):
447 |             exps.append(filler.fill_(0))
448 |         exps.append(test_domB[imgC2].unsqueeze(0))
449 | 
450 |         for i in range(args.num_display):
451 |             if i == 0:
452 |                 exps.append(test_domB[imgB1].unsqueeze(0))
453 |             elif i == args.num_display - 1:
454 |                 exps.append(test_domB[imgB2].unsqueeze(0))
455 |             else:
456 |                 exps.append(filler.fill_(0))
457 | 
458 |             for j in range(args.num_display):
459 |                 cur_common = (float(j) / (args.num_display - 1)) * c2 + \
460 |                              (1 - float(j) / (args.num_display - 1)) * c1
461 |                 cur_sep_B = (float(i) / (args.num_display - 1)) * b2 + \
462 |                             (1 - float(i) / (args.num_display - 1)) * b1
463 |                 encoding = torch.cat([cur_common, a, cur_sep_B], dim=1)
464 |                 decoding = decoder(encoding)
465 |                 exps.append(decoding)
466 | 
467 |     with torch.no_grad():
468 |         exps = torch.cat(exps, 0)
469 | 
470 |     vutils.save_image(exps,
471 |                       '%s/interpolation_fixed_A.png' % (args.out),
472 |                       normalize=True, nrow=args.num_display + 1)
473 | 
474 | 
475 | def interpolate_fixed_B(args, e_common, e_separate_A, e_separate_B, decoder, imgC1, imgC2, imgA1,
476 |                         imgA2, imgB):
477 |     test_domA, test_domB = get_test_imgs(args)
478 |     exps = []
479 |     c1 = e_common(test_domB[imgC1].unsqueeze(0))
480 |     c2 = e_common(test_domB[imgC2].unsqueeze(0))
481 |     a1 = e_separate_A(test_domA[imgA1].unsqueeze(0))
482 |     a2 = e_separate_A(test_domA[imgA2].unsqueeze(0))
483 |     b = e_separate_B(test_domB[imgB].unsqueeze(0))
484 |     with torch.no_grad():
485 |         filler = test_domB[0].unsqueeze(0).clone()
486 |         exps.append(filler.fill_(0))
487 |         exps.append(test_domB[imgC1].unsqueeze(0))
488 |         for i in range(args.num_display - 2):
489 |             exps.append(filler.fill_(0))
490 |         exps.append(test_domB[imgC2].unsqueeze(0))
491 | 
492 |         for i in range(args.num_display):
493 |             if i == 0:
494 |                 exps.append(test_domA[imgA1].unsqueeze(0))
495 |             elif i == args.num_display - 1:
496 |                 exps.append(test_domA[imgA2].unsqueeze(0))
497 |             else:
498 |                 exps.append(filler.fill_(0))
499 | 
500 |             for j in range(args.num_display):
501 |                 cur_common = (float(j) / (args.num_display - 1)) * c2 + \
502 |                              (1 - float(j) / (args.num_display - 1)) * c1
503 |                 cur_sep_A = (float(i) / (args.num_display - 1)) * a2 + \
504 |                             (1 - float(i) / (args.num_display - 1)) * a1
505 |                 encoding = torch.cat([cur_common, cur_sep_A, b], dim=1)
506 |                 decoding = decoder(encoding)
507 |                 exps.append(decoding)
508 | 
509 |     with torch.no_grad():
510 |         exps = torch.cat(exps, 0)
511 | 
512 |     vutils.save_image(exps,
513 |                       '%s/interpolation_fixed_B.png' % (args.out),
514 |                       normalize=True, nrow=args.num_display + 1)
515 | 
516 | 
517 | if __name__ == '__main__':
518 |     pass
519 | 
--------------------------------------------------------------------------------