├── LICENSE
├── README.md
├── UGATIT.py
├── assets
│   ├── ablation.png
│   ├── discriminator.png
│   ├── generator.png
│   ├── kid.png
│   ├── teaser.png
│   └── user_study.png
├── dataset.py
├── dataset
│   └── YOUR_DATASET_NAME
│       ├── testA
│       │   └── female_2321.jpg
│       ├── testB
│       │   └── 3414.png
│       ├── trainA
│       │   └── female_222.jpg
│       └── trainB
│           └── 0006.png
├── main.py
├── networks.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Hyeonwoo Kang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## U-GAT-IT — Official PyTorch Implementation
2 | ### : Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation
3 | 
4 | 
5 | <img src="./assets/teaser.png">
6 | 
7 | 
8 | ### [Paper](https://arxiv.org/abs/1907.10830) | [Official Tensorflow code](https://github.com/taki0112/UGATIT)
9 | The results in the paper were produced with the **Tensorflow code**.
10 | 
11 | 
12 | > **U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation**<br>
13 | > **Junho Kim (NCSOFT)**, Minjae Kim (NCSOFT), Hyeonwoo Kang (NCSOFT), Kwanghee Lee (Boeing Korea)
14 | >
15 | > **Abstract** *We propose a novel method for unsupervised image-to-image translation, which incorporates a new attention module and a new learnable normalization function in an end-to-end manner. The attention module guides our model to focus on more important regions distinguishing between source and target domains based on the attention map obtained by the auxiliary classifier. Unlike previous attention-based methods which cannot handle the geometric changes between domains, our model can translate both images requiring holistic changes and images requiring large shape changes. Moreover, our new AdaLIN (Adaptive Layer-Instance Normalization) function helps our attention-guided model to flexibly control the amount of change in shape and texture by learned parameters depending on datasets. Experimental results show the superiority of the proposed method compared to the existing state-of-the-art models with a fixed network architecture and hyper-parameters.*
16 | 
17 | ## Usage
18 | ```
19 | ├── dataset
20 |    └── YOUR_DATASET_NAME
21 |        ├── trainA
22 |        │   ├── xxx.jpg (name, format doesn't matter)
23 |        │   ├── yyy.png
24 |        │   └── ...
25 |        ├── trainB
26 |        │   ├── zzz.jpg
27 |        │   ├── www.png
28 |        │   └── ...
29 |        ├── testA
30 |        │   ├── aaa.jpg
31 |        │   ├── bbb.png
32 |        │   └── ...
33 |        └── testB
34 |            ├── ccc.jpg
35 |            ├── ddd.png
36 |            └── ...
37 | ```
38 | 
39 | ### Train
40 | ```
41 | > python main.py --dataset selfie2anime
42 | ```
43 | * If GPU memory is **not sufficient**, set `--light` to True (e.g. `python main.py --dataset selfie2anime --light True`)
44 | 
45 | ### Test
46 | ```
47 | > python main.py --dataset selfie2anime --phase test
48 | ```
49 | 
50 | ## Architecture
51 | 
52 | <img src="./assets/generator.png">
53 | 
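The decoder of the generator shown above normalizes its residual blocks with AdaLIN, a per-channel learned blend of Instance Normalization and Layer Normalization whose `gamma` and `beta` are predicted by fully connected layers from the attended feature map. Below is a minimal, self-contained sketch of the mixing rule; it is an illustration rather than the repository module (`adaILN` in `networks.py` reduces the statistics one dimension at a time, whereas here they are computed jointly for brevity):

```
import torch
import torch.nn as nn


class AdaLIN(nn.Module):
    """Illustrative sketch of Adaptive Layer-Instance Normalization."""
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        # rho blends Instance Norm (rho -> 1) and Layer Norm (rho -> 0);
        # in the repository it is kept in [0, 1] by RhoClipper after each step.
        self.rho = nn.Parameter(torch.full((1, num_features, 1, 1), 0.9))

    def forward(self, x, gamma, beta):
        # Instance Norm statistics: per sample and channel, over H and W.
        in_mean = x.mean(dim=(2, 3), keepdim=True)
        in_var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
        out_in = (x - in_mean) / torch.sqrt(in_var + self.eps)

        # Layer Norm statistics: per sample, over C, H and W.
        ln_mean = x.mean(dim=(1, 2, 3), keepdim=True)
        ln_var = x.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
        out_ln = (x - ln_mean) / torch.sqrt(ln_var + self.eps)

        # Learned interpolation, then the style-dependent affine transform.
        out = self.rho * out_in + (1.0 - self.rho) * out_ln
        return out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)


# Example: norm = AdaLIN(256)
#          y = norm(torch.randn(2, 256, 64, 64), torch.ones(2, 256), torch.zeros(2, 256))
```

During training the `rho` parameters are clipped to [0, 1] by `RhoClipper` after every generator update, so each channel can move between Instance-Norm-like and Layer-Norm-like behaviour.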
54 | 
55 | ---
56 | 
57 | 
58 | <img src="./assets/discriminator.png">
59 | 
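Both the generator and the global/local discriminators shown above share the same CAM-style attention block: the feature map is globally average- and max-pooled, an auxiliary classifier produces the CAM logits, and its weights are reused to re-weight the feature channels before a 1×1 convolution fuses the two branches into the attended features and the heatmap saved during training. A condensed sketch of that block, mirroring the `forward` methods in `networks.py` (the class name and example line are illustrative, not part of the repository):

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class CAMAttention(nn.Module):
    """Illustrative sketch of the CAM attention block used by the generator and discriminators."""
    def __init__(self, channels):
        super().__init__()
        self.gap_fc = nn.Linear(channels, 1, bias=False)  # auxiliary classifier, average-pool branch
        self.gmp_fc = nn.Linear(channels, 1, bias=False)  # auxiliary classifier, max-pool branch
        self.conv1x1 = nn.Conv2d(channels * 2, channels, kernel_size=1)

    def forward(self, x):
        # Average-pool branch: classify the pooled vector, then reuse the FC weights
        # to re-weight every channel of the feature map.
        gap_logit = self.gap_fc(F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1))
        gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        # Max-pool branch, same idea.
        gmp_logit = self.gmp_fc(F.adaptive_max_pool2d(x, 1).view(x.shape[0], -1))
        gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], dim=1)    # trained with the CAM loss
        x = F.relu(self.conv1x1(torch.cat([gap, gmp], dim=1)))  # fuse both attended branches
        heatmap = torch.sum(x, dim=1, keepdim=True)             # attention map visualized as *_heatmap
        return x, cam_logit, heatmap


# Example: attn = CAMAttention(256); feat, logit, heat = attn(torch.randn(1, 256, 64, 64))
```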
60 | 
61 | ## Results
62 | ### Ablation study
63 | 
64 | <img src="./assets/ablation.png">
65 | 
66 | 
67 | ### User study
68 | 
69 | <img src="./assets/user_study.png">
70 | 
71 | 
72 | ### Comparison
73 | 
74 | <img src="./assets/kid.png">
75 | 
76 | 77 | ## Citation 78 | If you find this code useful for your research, please cite our paper: 79 | 80 | ``` 81 | @misc{kim2019ugatit, 82 | title={U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation}, 83 | author={Junho Kim and Minjae Kim and Hyeonwoo Kang and Kwanghee Lee}, 84 | year={2019}, 85 | eprint={1907.10830}, 86 | archivePrefix={arXiv}, 87 | primaryClass={cs.CV} 88 | } 89 | ``` 90 | 91 | ## Author 92 | [Junho Kim](http://bit.ly/jhkim_ai), Minjae Kim, Hyeonwoo Kang, Kwanghee Lee 93 | -------------------------------------------------------------------------------- /UGATIT.py: -------------------------------------------------------------------------------- 1 | import time, itertools 2 | from dataset import ImageFolder 3 | from torchvision import transforms 4 | from torch.utils.data import DataLoader 5 | from networks import * 6 | from utils import * 7 | from glob import glob 8 | 9 | class UGATIT(object) : 10 | def __init__(self, args): 11 | self.light = args.light 12 | 13 | if self.light : 14 | self.model_name = 'UGATIT_light' 15 | else : 16 | self.model_name = 'UGATIT' 17 | 18 | self.result_dir = args.result_dir 19 | self.dataset = args.dataset 20 | 21 | self.iteration = args.iteration 22 | self.decay_flag = args.decay_flag 23 | 24 | self.batch_size = args.batch_size 25 | self.print_freq = args.print_freq 26 | self.save_freq = args.save_freq 27 | 28 | self.lr = args.lr 29 | self.weight_decay = args.weight_decay 30 | self.ch = args.ch 31 | 32 | """ Weight """ 33 | self.adv_weight = args.adv_weight 34 | self.cycle_weight = args.cycle_weight 35 | self.identity_weight = args.identity_weight 36 | self.cam_weight = args.cam_weight 37 | 38 | """ Generator """ 39 | self.n_res = args.n_res 40 | 41 | """ Discriminator """ 42 | self.n_dis = args.n_dis 43 | 44 | self.img_size = args.img_size 45 | self.img_ch = args.img_ch 46 | 47 | self.device = args.device 48 | self.benchmark_flag = args.benchmark_flag 49 | self.resume = args.resume 50 | 51 | if torch.backends.cudnn.enabled and self.benchmark_flag: 52 | print('set benchmark !') 53 | torch.backends.cudnn.benchmark = True 54 | 55 | print() 56 | 57 | print("##### Information #####") 58 | print("# light : ", self.light) 59 | print("# dataset : ", self.dataset) 60 | print("# batch_size : ", self.batch_size) 61 | print("# iteration per epoch : ", self.iteration) 62 | 63 | print() 64 | 65 | print("##### Generator #####") 66 | print("# residual blocks : ", self.n_res) 67 | 68 | print() 69 | 70 | print("##### Discriminator #####") 71 | print("# discriminator layer : ", self.n_dis) 72 | 73 | print() 74 | 75 | print("##### Weight #####") 76 | print("# adv_weight : ", self.adv_weight) 77 | print("# cycle_weight : ", self.cycle_weight) 78 | print("# identity_weight : ", self.identity_weight) 79 | print("# cam_weight : ", self.cam_weight) 80 | 81 | ################################################################################## 82 | # Model 83 | ################################################################################## 84 | 85 | def build_model(self): 86 | """ DataLoader """ 87 | train_transform = transforms.Compose([ 88 | transforms.RandomHorizontalFlip(), 89 | transforms.Resize((self.img_size + 30, self.img_size+30)), 90 | transforms.RandomCrop(self.img_size), 91 | transforms.ToTensor(), 92 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 93 | ]) 94 | test_transform = transforms.Compose([ 95 | transforms.Resize((self.img_size, self.img_size)), 96 | 
transforms.ToTensor(), 97 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 98 | ]) 99 | 100 | self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform) 101 | self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform) 102 | self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform) 103 | self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform) 104 | self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True) 105 | self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True) 106 | self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False) 107 | self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False) 108 | 109 | """ Define Generator, Discriminator """ 110 | self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch, n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device) 111 | self.genB2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch, n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device) 112 | self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device) 113 | self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device) 114 | self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device) 115 | self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device) 116 | 117 | """ Define Loss """ 118 | self.L1_loss = nn.L1Loss().to(self.device) 119 | self.MSE_loss = nn.MSELoss().to(self.device) 120 | self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device) 121 | 122 | """ Trainer """ 123 | self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay) 124 | self.D_optim = torch.optim.Adam(itertools.chain(self.disGA.parameters(), self.disGB.parameters(), self.disLA.parameters(), self.disLB.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay) 125 | 126 | """ Define Rho clipper to constraint the value of rho in AdaILN and ILN""" 127 | self.Rho_clipper = RhoClipper(0, 1) 128 | 129 | def train(self): 130 | self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train() 131 | 132 | start_iter = 1 133 | if self.resume: 134 | model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt')) 135 | if not len(model_list) == 0: 136 | model_list.sort() 137 | start_iter = int(model_list[-1].split('_')[-1].split('.')[0]) 138 | self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter) 139 | print(" [*] Load SUCCESS") 140 | if self.decay_flag and start_iter > (self.iteration // 2): 141 | self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2) 142 | self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2) 143 | 144 | # training loop 145 | print('training start !') 146 | start_time = time.time() 147 | for step in range(start_iter, self.iteration + 1): 148 | if self.decay_flag and step > (self.iteration // 2): 149 | self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) 150 | self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) 151 | 152 | try: 153 | real_A, _ = trainA_iter.next() 154 | except: 155 | trainA_iter = 
iter(self.trainA_loader) 156 | real_A, _ = trainA_iter.next() 157 | 158 | try: 159 | real_B, _ = trainB_iter.next() 160 | except: 161 | trainB_iter = iter(self.trainB_loader) 162 | real_B, _ = trainB_iter.next() 163 | 164 | real_A, real_B = real_A.to(self.device), real_B.to(self.device) 165 | 166 | # Update D 167 | self.D_optim.zero_grad() 168 | 169 | fake_A2B, _, _ = self.genA2B(real_A) 170 | fake_B2A, _, _ = self.genB2A(real_B) 171 | 172 | real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A) 173 | real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A) 174 | real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B) 175 | real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B) 176 | 177 | fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A) 178 | fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A) 179 | fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B) 180 | fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B) 181 | 182 | D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device)) 183 | D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device)) 184 | D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(self.device)) + self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device)) 185 | D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(self.device)) + self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device)) 186 | D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(self.device)) + self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device)) 187 | D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(self.device)) + self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device)) 188 | D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(self.device)) + self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device)) 189 | D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(self.device)) + self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device)) 190 | 191 | D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA) 192 | D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB) 193 | 194 | Discriminator_loss = D_loss_A + D_loss_B 195 | Discriminator_loss.backward() 196 | self.D_optim.step() 197 | 198 | # Update G 199 | self.G_optim.zero_grad() 200 | 201 | fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A) 202 | fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B) 203 | 204 | fake_A2B2A, _, _ = self.genB2A(fake_A2B) 205 | fake_B2A2B, _, _ = self.genA2B(fake_B2A) 206 | 207 | fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A) 208 | fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B) 209 | 210 | fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A) 211 | fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A) 212 | fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B) 213 | fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B) 214 | 215 | G_ad_loss_GA = self.MSE_loss(fake_GA_logit, 
torch.ones_like(fake_GA_logit).to(self.device)) 216 | G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device)) 217 | G_ad_loss_LA = self.MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device)) 218 | G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device)) 219 | G_ad_loss_GB = self.MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device)) 220 | G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device)) 221 | G_ad_loss_LB = self.MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device)) 222 | G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device)) 223 | 224 | G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A) 225 | G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B) 226 | 227 | G_identity_loss_A = self.L1_loss(fake_A2A, real_A) 228 | G_identity_loss_B = self.L1_loss(fake_B2B, real_B) 229 | 230 | G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device)) 231 | G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device)) 232 | 233 | G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A 234 | G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B 235 | 236 | Generator_loss = G_loss_A + G_loss_B 237 | Generator_loss.backward() 238 | self.G_optim.step() 239 | 240 | # clip parameter of AdaILN and ILN, applied after optimizer step 241 | self.genA2B.apply(self.Rho_clipper) 242 | self.genB2A.apply(self.Rho_clipper) 243 | 244 | print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (step, self.iteration, time.time() - start_time, Discriminator_loss, Generator_loss)) 245 | if step % self.print_freq == 0: 246 | train_sample_num = 5 247 | test_sample_num = 5 248 | A2B = np.zeros((self.img_size * 7, 0, 3)) 249 | B2A = np.zeros((self.img_size * 7, 0, 3)) 250 | 251 | self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(), self.disGB.eval(), self.disLA.eval(), self.disLB.eval() 252 | for _ in range(train_sample_num): 253 | try: 254 | real_A, _ = trainA_iter.next() 255 | except: 256 | trainA_iter = iter(self.trainA_loader) 257 | real_A, _ = trainA_iter.next() 258 | 259 | try: 260 | real_B, _ = trainB_iter.next() 261 | except: 262 | trainB_iter = iter(self.trainB_loader) 263 | real_B, _ = trainB_iter.next() 264 | real_A, real_B = real_A.to(self.device), real_B.to(self.device) 265 | 266 | fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A) 267 | fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B) 268 | 269 | fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B) 270 | fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A) 271 | 272 | fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A) 273 | fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B) 274 | 275 | A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))), 276 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size), 277 | 
RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), 278 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size), 279 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), 280 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size), 281 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1) 282 | 283 | B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))), 284 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size), 285 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), 286 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size), 287 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), 288 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size), 289 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1) 290 | 291 | for _ in range(test_sample_num): 292 | try: 293 | real_A, _ = testA_iter.next() 294 | except: 295 | testA_iter = iter(self.testA_loader) 296 | real_A, _ = testA_iter.next() 297 | 298 | try: 299 | real_B, _ = testB_iter.next() 300 | except: 301 | testB_iter = iter(self.testB_loader) 302 | real_B, _ = testB_iter.next() 303 | real_A, real_B = real_A.to(self.device), real_B.to(self.device) 304 | 305 | fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A) 306 | fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B) 307 | 308 | fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B) 309 | fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A) 310 | 311 | fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A) 312 | fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B) 313 | 314 | A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))), 315 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size), 316 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), 317 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size), 318 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), 319 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size), 320 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1) 321 | 322 | B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))), 323 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size), 324 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), 325 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size), 326 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), 327 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size), 328 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1) 329 | 330 | cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'A2B_%07d.png' % step), A2B * 255.0) 331 | cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'B2A_%07d.png' % step), B2A * 255.0) 332 | self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train() 333 | 334 | if step % self.save_freq == 0: 335 | self.save(os.path.join(self.result_dir, self.dataset, 'model'), step) 336 | 337 | if step % 1000 == 0: 338 | params = {} 339 | params['genA2B'] = self.genA2B.state_dict() 340 | params['genB2A'] = self.genB2A.state_dict() 341 | params['disGA'] = self.disGA.state_dict() 342 | params['disGB'] = self.disGB.state_dict() 343 | params['disLA'] = self.disLA.state_dict() 344 | params['disLB'] = self.disLB.state_dict() 345 | torch.save(params, os.path.join(self.result_dir, self.dataset + '_params_latest.pt')) 346 | 347 | def save(self, dir, step): 348 | params = {} 349 | params['genA2B'] = self.genA2B.state_dict() 350 | params['genB2A'] = self.genB2A.state_dict() 351 | params['disGA'] = self.disGA.state_dict() 352 | params['disGB'] = self.disGB.state_dict() 353 | 
params['disLA'] = self.disLA.state_dict() 354 | params['disLB'] = self.disLB.state_dict() 355 | torch.save(params, os.path.join(dir, self.dataset + '_params_%07d.pt' % step)) 356 | 357 | def load(self, dir, step): 358 | params = torch.load(os.path.join(dir, self.dataset + '_params_%07d.pt' % step)) 359 | self.genA2B.load_state_dict(params['genA2B']) 360 | self.genB2A.load_state_dict(params['genB2A']) 361 | self.disGA.load_state_dict(params['disGA']) 362 | self.disGB.load_state_dict(params['disGB']) 363 | self.disLA.load_state_dict(params['disLA']) 364 | self.disLB.load_state_dict(params['disLB']) 365 | 366 | def test(self): 367 | model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt')) 368 | if not len(model_list) == 0: 369 | model_list.sort() 370 | iter = int(model_list[-1].split('_')[-1].split('.')[0]) 371 | self.load(os.path.join(self.result_dir, self.dataset, 'model'), iter) 372 | print(" [*] Load SUCCESS") 373 | else: 374 | print(" [*] Load FAILURE") 375 | return 376 | 377 | self.genA2B.eval(), self.genB2A.eval() 378 | for n, (real_A, _) in enumerate(self.testA_loader): 379 | real_A = real_A.to(self.device) 380 | 381 | fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A) 382 | 383 | fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B) 384 | 385 | fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A) 386 | 387 | A2B = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))), 388 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size), 389 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), 390 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size), 391 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), 392 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size), 393 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0) 394 | 395 | cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'A2B_%d.png' % (n + 1)), A2B * 255.0) 396 | 397 | for n, (real_B, _) in enumerate(self.testB_loader): 398 | real_B = real_B.to(self.device) 399 | 400 | fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B) 401 | 402 | fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A) 403 | 404 | fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B) 405 | 406 | B2A = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))), 407 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size), 408 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), 409 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size), 410 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), 411 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size), 412 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0) 413 | 414 | cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'B2A_%d.png' % (n + 1)), B2A * 255.0) 415 | -------------------------------------------------------------------------------- /assets/ablation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/ablation.png -------------------------------------------------------------------------------- /assets/discriminator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/discriminator.png -------------------------------------------------------------------------------- /assets/generator.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/generator.png -------------------------------------------------------------------------------- /assets/kid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/kid.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/teaser.png -------------------------------------------------------------------------------- /assets/user_study.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/user_study.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | 5 | import os 6 | import os.path 7 | 8 | 9 | def has_file_allowed_extension(filename, extensions): 10 | """Checks if a file is an allowed extension. 11 | 12 | Args: 13 | filename (string): path to a file 14 | 15 | Returns: 16 | bool: True if the filename ends with a known image extension 17 | """ 18 | filename_lower = filename.lower() 19 | return any(filename_lower.endswith(ext) for ext in extensions) 20 | 21 | 22 | def find_classes(dir): 23 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 24 | classes.sort() 25 | class_to_idx = {classes[i]: i for i in range(len(classes))} 26 | return classes, class_to_idx 27 | 28 | 29 | def make_dataset(dir, extensions): 30 | images = [] 31 | for root, _, fnames in sorted(os.walk(dir)): 32 | for fname in sorted(fnames): 33 | if has_file_allowed_extension(fname, extensions): 34 | path = os.path.join(root, fname) 35 | item = (path, 0) 36 | images.append(item) 37 | 38 | return images 39 | 40 | 41 | class DatasetFolder(data.Dataset): 42 | def __init__(self, root, loader, extensions, transform=None, target_transform=None): 43 | # classes, class_to_idx = find_classes(root) 44 | samples = make_dataset(root, extensions) 45 | if len(samples) == 0: 46 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" 47 | "Supported extensions are: " + ",".join(extensions))) 48 | 49 | self.root = root 50 | self.loader = loader 51 | self.extensions = extensions 52 | self.samples = samples 53 | 54 | self.transform = transform 55 | self.target_transform = target_transform 56 | 57 | def __getitem__(self, index): 58 | """ 59 | Args: 60 | index (int): Index 61 | 62 | Returns: 63 | tuple: (sample, target) where target is class_index of the target class. 
64 | """ 65 | path, target = self.samples[index] 66 | sample = self.loader(path) 67 | if self.transform is not None: 68 | sample = self.transform(sample) 69 | if self.target_transform is not None: 70 | target = self.target_transform(target) 71 | 72 | return sample, target 73 | 74 | def __len__(self): 75 | return len(self.samples) 76 | 77 | def __repr__(self): 78 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 79 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 80 | fmt_str += ' Root Location: {}\n'.format(self.root) 81 | tmp = ' Transforms (if any): ' 82 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 83 | tmp = ' Target Transforms (if any): ' 84 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 85 | return fmt_str 86 | 87 | 88 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 89 | 90 | 91 | def pil_loader(path): 92 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 93 | with open(path, 'rb') as f: 94 | img = Image.open(f) 95 | return img.convert('RGB') 96 | 97 | 98 | def default_loader(path): 99 | return pil_loader(path) 100 | 101 | 102 | class ImageFolder(DatasetFolder): 103 | def __init__(self, root, transform=None, target_transform=None, 104 | loader=default_loader): 105 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 106 | transform=transform, 107 | target_transform=target_transform) 108 | self.imgs = self.samples 109 | -------------------------------------------------------------------------------- /dataset/YOUR_DATASET_NAME/testA/female_2321.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/dataset/YOUR_DATASET_NAME/testA/female_2321.jpg -------------------------------------------------------------------------------- /dataset/YOUR_DATASET_NAME/testB/3414.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/dataset/YOUR_DATASET_NAME/testB/3414.png -------------------------------------------------------------------------------- /dataset/YOUR_DATASET_NAME/trainA/female_222.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/dataset/YOUR_DATASET_NAME/trainA/female_222.jpg -------------------------------------------------------------------------------- /dataset/YOUR_DATASET_NAME/trainB/0006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/dataset/YOUR_DATASET_NAME/trainB/0006.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from UGATIT import UGATIT 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and configuration""" 6 | 7 | def parse_args(): 8 | desc = "Pytorch implementation of U-GAT-IT" 9 | parser = argparse.ArgumentParser(description=desc) 10 | parser.add_argument('--phase', type=str, default='train', help='[train / test]') 11 | parser.add_argument('--light', type=str2bool, default=False, 
help='[U-GAT-IT full version / U-GAT-IT light version]') 12 | parser.add_argument('--dataset', type=str, default='YOUR_DATASET_NAME', help='dataset_name') 13 | 14 | parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations') 15 | parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size') 16 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image print freq') 17 | parser.add_argument('--save_freq', type=int, default=100000, help='The number of model save freq') 18 | parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag') 19 | 20 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate') 21 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='The weight decay') 22 | parser.add_argument('--adv_weight', type=int, default=1, help='Weight for GAN') 23 | parser.add_argument('--cycle_weight', type=int, default=10, help='Weight for Cycle') 24 | parser.add_argument('--identity_weight', type=int, default=10, help='Weight for Identity') 25 | parser.add_argument('--cam_weight', type=int, default=1000, help='Weight for CAM') 26 | 27 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer') 28 | parser.add_argument('--n_res', type=int, default=4, help='The number of resblock') 29 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer') 30 | 31 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 32 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel') 33 | 34 | parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the results') 35 | parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'], help='Set gpu mode; [cpu, cuda]') 36 | parser.add_argument('--benchmark_flag', type=str2bool, default=False) 37 | parser.add_argument('--resume', type=str2bool, default=False) 38 | 39 | return check_args(parser.parse_args()) 40 | 41 | """checking arguments""" 42 | def check_args(args): 43 | # --result_dir 44 | check_folder(os.path.join(args.result_dir, args.dataset, 'model')) 45 | check_folder(os.path.join(args.result_dir, args.dataset, 'img')) 46 | check_folder(os.path.join(args.result_dir, args.dataset, 'test')) 47 | 48 | # --epoch 49 | try: 50 | assert args.epoch >= 1 51 | except: 52 | print('number of epochs must be larger than or equal to one') 53 | 54 | # --batch_size 55 | try: 56 | assert args.batch_size >= 1 57 | except: 58 | print('batch size must be larger than or equal to one') 59 | return args 60 | 61 | """main""" 62 | def main(): 63 | # parse arguments 64 | args = parse_args() 65 | if args is None: 66 | exit() 67 | 68 | # open session 69 | gan = UGATIT(args) 70 | 71 | # build graph 72 | gan.build_model() 73 | 74 | if args.phase == 'train' : 75 | gan.train() 76 | print(" [*] Training finished!") 77 | 78 | if args.phase == 'test' : 79 | gan.test() 80 | print(" [*] Test finished!") 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class ResnetGenerator(nn.Module): 7 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, img_size=256, light=False): 8 
| assert(n_blocks >= 0) 9 | super(ResnetGenerator, self).__init__() 10 | self.input_nc = input_nc 11 | self.output_nc = output_nc 12 | self.ngf = ngf 13 | self.n_blocks = n_blocks 14 | self.img_size = img_size 15 | self.light = light 16 | 17 | DownBlock = [] 18 | DownBlock += [nn.ReflectionPad2d(3), 19 | nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False), 20 | nn.InstanceNorm2d(ngf), 21 | nn.ReLU(True)] 22 | 23 | # Down-Sampling 24 | n_downsampling = 2 25 | for i in range(n_downsampling): 26 | mult = 2**i 27 | DownBlock += [nn.ReflectionPad2d(1), 28 | nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False), 29 | nn.InstanceNorm2d(ngf * mult * 2), 30 | nn.ReLU(True)] 31 | 32 | # Down-Sampling Bottleneck 33 | mult = 2**n_downsampling 34 | for i in range(n_blocks): 35 | DownBlock += [ResnetBlock(ngf * mult, use_bias=False)] 36 | 37 | # Class Activation Map 38 | self.gap_fc = nn.Linear(ngf * mult, 1, bias=False) 39 | self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False) 40 | self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True) 41 | self.relu = nn.ReLU(True) 42 | 43 | # Gamma, Beta block 44 | if self.light: 45 | FC = [nn.Linear(ngf * mult, ngf * mult, bias=False), 46 | nn.ReLU(True), 47 | nn.Linear(ngf * mult, ngf * mult, bias=False), 48 | nn.ReLU(True)] 49 | else: 50 | FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False), 51 | nn.ReLU(True), 52 | nn.Linear(ngf * mult, ngf * mult, bias=False), 53 | nn.ReLU(True)] 54 | self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False) 55 | self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False) 56 | 57 | # Up-Sampling Bottleneck 58 | for i in range(n_blocks): 59 | setattr(self, 'UpBlock1_' + str(i+1), ResnetAdaILNBlock(ngf * mult, use_bias=False)) 60 | 61 | # Up-Sampling 62 | UpBlock2 = [] 63 | for i in range(n_downsampling): 64 | mult = 2**(n_downsampling - i) 65 | UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'), 66 | nn.ReflectionPad2d(1), 67 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=False), 68 | ILN(int(ngf * mult / 2)), 69 | nn.ReLU(True)] 70 | 71 | UpBlock2 += [nn.ReflectionPad2d(3), 72 | nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=False), 73 | nn.Tanh()] 74 | 75 | self.DownBlock = nn.Sequential(*DownBlock) 76 | self.FC = nn.Sequential(*FC) 77 | self.UpBlock2 = nn.Sequential(*UpBlock2) 78 | 79 | def forward(self, input): 80 | x = self.DownBlock(input) 81 | 82 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) 83 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) 84 | gap_weight = list(self.gap_fc.parameters())[0] 85 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3) 86 | 87 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) 88 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) 89 | gmp_weight = list(self.gmp_fc.parameters())[0] 90 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) 91 | 92 | cam_logit = torch.cat([gap_logit, gmp_logit], 1) 93 | x = torch.cat([gap, gmp], 1) 94 | x = self.relu(self.conv1x1(x)) 95 | 96 | heatmap = torch.sum(x, dim=1, keepdim=True) 97 | 98 | if self.light: 99 | x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1) 100 | x_ = self.FC(x_.view(x_.shape[0], -1)) 101 | else: 102 | x_ = self.FC(x.view(x.shape[0], -1)) 103 | gamma, beta = self.gamma(x_), self.beta(x_) 104 | 105 | 106 | for i in range(self.n_blocks): 107 | x = getattr(self, 'UpBlock1_' + str(i+1))(x, gamma, beta) 108 | out = self.UpBlock2(x) 109 | 110 | return 
out, cam_logit, heatmap 111 | 112 | 113 | class ResnetBlock(nn.Module): 114 | def __init__(self, dim, use_bias): 115 | super(ResnetBlock, self).__init__() 116 | conv_block = [] 117 | conv_block += [nn.ReflectionPad2d(1), 118 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias), 119 | nn.InstanceNorm2d(dim), 120 | nn.ReLU(True)] 121 | 122 | conv_block += [nn.ReflectionPad2d(1), 123 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias), 124 | nn.InstanceNorm2d(dim)] 125 | 126 | self.conv_block = nn.Sequential(*conv_block) 127 | 128 | def forward(self, x): 129 | out = x + self.conv_block(x) 130 | return out 131 | 132 | 133 | class ResnetAdaILNBlock(nn.Module): 134 | def __init__(self, dim, use_bias): 135 | super(ResnetAdaILNBlock, self).__init__() 136 | self.pad1 = nn.ReflectionPad2d(1) 137 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias) 138 | self.norm1 = adaILN(dim) 139 | self.relu1 = nn.ReLU(True) 140 | 141 | self.pad2 = nn.ReflectionPad2d(1) 142 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias) 143 | self.norm2 = adaILN(dim) 144 | 145 | def forward(self, x, gamma, beta): 146 | out = self.pad1(x) 147 | out = self.conv1(out) 148 | out = self.norm1(out, gamma, beta) 149 | out = self.relu1(out) 150 | out = self.pad2(out) 151 | out = self.conv2(out) 152 | out = self.norm2(out, gamma, beta) 153 | 154 | return out 155 | 156 | 157 | class adaILN(nn.Module): 158 | def __init__(self, num_features, eps=1e-5): 159 | super(adaILN, self).__init__() 160 | self.eps = eps 161 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) 162 | self.rho.data.fill_(0.9) 163 | 164 | def forward(self, input, gamma, beta): 165 | in_mean, in_var = torch.mean(torch.mean(input, dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(input, dim=2, keepdim=True), dim=3, keepdim=True) 166 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) 167 | ln_mean, ln_var = torch.mean(torch.mean(torch.mean(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(torch.var(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True) 168 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) 169 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln 170 | out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3) 171 | 172 | return out 173 | 174 | 175 | class ILN(nn.Module): 176 | def __init__(self, num_features, eps=1e-5): 177 | super(ILN, self).__init__() 178 | self.eps = eps 179 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) 180 | self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1)) 181 | self.beta = Parameter(torch.Tensor(1, num_features, 1, 1)) 182 | self.rho.data.fill_(0.0) 183 | self.gamma.data.fill_(1.0) 184 | self.beta.data.fill_(0.0) 185 | 186 | def forward(self, input): 187 | in_mean, in_var = torch.mean(torch.mean(input, dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(input, dim=2, keepdim=True), dim=3, keepdim=True) 188 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) 189 | ln_mean, ln_var = torch.mean(torch.mean(torch.mean(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(torch.var(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True) 190 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) 191 | out = self.rho.expand(input.shape[0], -1, -1, 
-1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln 192 | out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1) 193 | 194 | return out 195 | 196 | 197 | class Discriminator(nn.Module): 198 | def __init__(self, input_nc, ndf=64, n_layers=5): 199 | super(Discriminator, self).__init__() 200 | model = [nn.ReflectionPad2d(1), 201 | nn.utils.spectral_norm( 202 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)), 203 | nn.LeakyReLU(0.2, True)] 204 | 205 | for i in range(1, n_layers - 2): 206 | mult = 2 ** (i - 1) 207 | model += [nn.ReflectionPad2d(1), 208 | nn.utils.spectral_norm( 209 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)), 210 | nn.LeakyReLU(0.2, True)] 211 | 212 | mult = 2 ** (n_layers - 2 - 1) 213 | model += [nn.ReflectionPad2d(1), 214 | nn.utils.spectral_norm( 215 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)), 216 | nn.LeakyReLU(0.2, True)] 217 | 218 | # Class Activation Map 219 | mult = 2 ** (n_layers - 2) 220 | self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False)) 221 | self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False)) 222 | self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True) 223 | self.leaky_relu = nn.LeakyReLU(0.2, True) 224 | 225 | self.pad = nn.ReflectionPad2d(1) 226 | self.conv = nn.utils.spectral_norm( 227 | nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False)) 228 | 229 | self.model = nn.Sequential(*model) 230 | 231 | def forward(self, input): 232 | x = self.model(input) 233 | 234 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) 235 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) 236 | gap_weight = list(self.gap_fc.parameters())[0] 237 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3) 238 | 239 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) 240 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) 241 | gmp_weight = list(self.gmp_fc.parameters())[0] 242 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) 243 | 244 | cam_logit = torch.cat([gap_logit, gmp_logit], 1) 245 | x = torch.cat([gap, gmp], 1) 246 | x = self.leaky_relu(self.conv1x1(x)) 247 | 248 | heatmap = torch.sum(x, dim=1, keepdim=True) 249 | 250 | x = self.pad(x) 251 | out = self.conv(x) 252 | 253 | return out, cam_logit, heatmap 254 | 255 | 256 | class RhoClipper(object): 257 | 258 | def __init__(self, min, max): 259 | self.clip_min = min 260 | self.clip_max = max 261 | assert min < max 262 | 263 | def __call__(self, module): 264 | 265 | if hasattr(module, 'rho'): 266 | w = module.rho.data 267 | w = w.clamp(self.clip_min, self.clip_max) 268 | module.rho.data = w 269 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from scipy import misc 2 | import os, cv2, torch 3 | import numpy as np 4 | 5 | def load_test_data(image_path, size=256): 6 | img = misc.imread(image_path, mode='RGB') 7 | img = misc.imresize(img, [size, size]) 8 | img = np.expand_dims(img, axis=0) 9 | img = preprocessing(img) 10 | 11 | return img 12 | 13 | def preprocessing(x): 14 | x = x/127.5 - 1 # -1 ~ 1 15 | return x 16 | 17 | def save_images(images, size, image_path): 18 | return imsave(inverse_transform(images), size, image_path) 19 | 20 | def inverse_transform(images): 21 | return (images+1.) 
/ 2 22 | 23 | def imsave(images, size, path): 24 | return misc.imsave(path, merge(images, size)) 25 | 26 | def merge(images, size): 27 | h, w = images.shape[1], images.shape[2] 28 | img = np.zeros((h * size[0], w * size[1], 3)) 29 | for idx, image in enumerate(images): 30 | i = idx % size[1] 31 | j = idx // size[1] 32 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image 33 | 34 | return img 35 | 36 | def check_folder(log_dir): 37 | if not os.path.exists(log_dir): 38 | os.makedirs(log_dir) 39 | return log_dir 40 | 41 | def str2bool(x): 42 | return x.lower() in ('true',) 43 | 44 | def cam(x, size = 256): 45 | x = x - np.min(x) 46 | cam_img = x / np.max(x) 47 | cam_img = np.uint8(255 * cam_img) 48 | cam_img = cv2.resize(cam_img, (size, size)) 49 | cam_img = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET) 50 | return cam_img / 255.0 51 | 52 | def imagenet_norm(x): 53 | mean = [0.485, 0.456, 0.406] 54 | std = [0.229, 0.224, 0.225] 55 | mean = torch.FloatTensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device) 56 | std = torch.FloatTensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device) 57 | return (x - mean) / std 58 | 59 | def denorm(x): 60 | return x * 0.5 + 0.5 61 | 62 | def tensor2numpy(x): 63 | return x.detach().cpu().numpy().transpose(1,2,0) 64 | 65 | def RGB2BGR(x): 66 | return cv2.cvtColor(x, cv2.COLOR_RGB2BGR) --------------------------------------------------------------------------------
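The visualization helpers in `utils.py` above (`denorm`, `tensor2numpy`, `RGB2BGR`, `cam`) are chained in `UGATIT.py` to turn generator outputs into the saved preview images. A minimal sketch of that chain, with random tensors standing in for real generator outputs:

```
import cv2
import numpy as np
import torch
from utils import cam, denorm, tensor2numpy, RGB2BGR

# Dummy stand-ins for generator outputs: an image normalized to [-1, 1] and a raw heatmap.
fake_img = torch.rand(3, 256, 256) * 2.0 - 1.0
heatmap = torch.rand(1, 64, 64)

row = np.concatenate((RGB2BGR(tensor2numpy(denorm(fake_img))),  # back to [0, 1], HWC, BGR for OpenCV
                      cam(tensor2numpy(heatmap), 256)), 0)      # colour-mapped, resized attention map
cv2.imwrite('preview.png', row * 255.0)
```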