├── LICENSE
├── README.md
├── UGATIT.py
├── assets
│   ├── ablation.png
│   ├── discriminator.png
│   ├── generator.png
│   ├── kid.png
│   ├── teaser.png
│   └── user_study.png
├── dataset.py
├── dataset
│   └── YOUR_DATASET_NAME
│       ├── testA
│       │   └── female_2321.jpg
│       ├── testB
│       │   └── 3414.png
│       ├── trainA
│       │   └── female_222.jpg
│       └── trainB
│           └── 0006.png
├── main.py
├── networks.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Hyeonwoo Kang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## U-GAT-IT — Official PyTorch Implementation
2 | ### : Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation
3 |
4 |
5 |
6 |
7 |
8 | ### [Paper](https://arxiv.org/abs/1907.10830) | [Official Tensorflow code](https://github.com/taki0112/UGATIT)
9 | The results in the paper were produced with the **TensorFlow code**.
10 |
11 |
12 | > **U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation**
13 | > **Junho Kim (NCSOFT)**, Minjae Kim (NCSOFT), Hyeonwoo Kang (NCSOFT), Kwanghee Lee (Boeing Korea)
14 | >
15 | > **Abstract** *We propose a novel method for unsupervised image-to-image translation, which incorporates a new attention module and a new learnable normalization function in an end-to-end manner. The attention module guides our model to focus on more important regions distinguishing between source and target domains based on the attention map obtained by the auxiliary classifier. Unlike previous attention-based methods which cannot handle the geometric changes between domains, our model can translate both images requiring holistic changes and images requiring large shape changes. Moreover, our new AdaLIN (Adaptive Layer-Instance Normalization) function helps our attention-guided model to flexibly control the amount of change in shape and texture by learned parameters depending on datasets. Experimental results show the superiority of the proposed method compared to the existing state-of-the-art models with a fixed network architecture and hyper-parameters.*
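
In short, the attention module reweights encoder features with CAM-style weights from an auxiliary classifier, and AdaLIN mixes instance- and layer-normalized activations with a learnable ratio ρ clipped to [0, 1] (see `adaILN` and `RhoClipper` in `networks.py`). Following the paper, the operation can be summarized as:

```
AdaLIN(a, γ, β) = γ · (ρ · a_IN + (1 − ρ) · a_LN) + β,   with ρ clipped to [0, 1]
```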
16 |
17 | ## Usage
18 | ```
19 | ├── dataset
20 |    └── YOUR_DATASET_NAME
21 |        ├── trainA
22 |        │   ├── xxx.jpg (name, format doesn't matter)
23 |        │   ├── yyy.png
24 |        │   └── ...
25 |        ├── trainB
26 |        │   ├── zzz.jpg
27 |        │   ├── www.png
28 |        │   └── ...
29 |        ├── testA
30 |        │   ├── aaa.jpg
31 |        │   ├── bbb.png
32 |        │   └── ...
33 |        └── testB
34 |            ├── ccc.jpg
35 |            ├── ddd.png
36 |            └── ...
37 | ```
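
The folder name under `dataset/` is what `--dataset` expects (see `UGATIT.build_model`); for example, with the layout above:
```
> python main.py --dataset YOUR_DATASET_NAME
```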
38 |
39 | ### Train
40 | ```
41 | > python main.py --dataset selfie2anime
42 | ```
43 | * If GPU memory is **not sufficient**, set `--light` to True
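
For example, a light-model run might look like this (the dataset name is illustrative):
```
> python main.py --dataset selfie2anime --light True
```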
44 |
45 | ### Test
46 | ```
47 | > python main.py --dataset selfie2anime --phase test
48 | ```
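
Test outputs are written under the result directory created by `main.py` (default `results/`); for the command above, the translated images land in, e.g.:
```
results/selfie2anime/test/A2B_1.png
results/selfie2anime/test/B2A_1.png
```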
49 |
50 | ## Architecture
51 |
52 |
53 |
54 |
55 | ---
56 |
57 |
58 |
59 |
60 |
61 | ## Results
62 | ### Ablation study
63 |
64 |
65 |
66 |
67 | ### User study
68 |
69 |
70 |
71 |
72 | ### Comparison
73 |
74 |
75 |
76 |
77 | ## Citation
78 | If you find this code useful for your research, please cite our paper:
79 |
80 | ```
81 | @misc{kim2019ugatit,
82 | title={U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation},
83 | author={Junho Kim and Minjae Kim and Hyeonwoo Kang and Kwanghee Lee},
84 | year={2019},
85 | eprint={1907.10830},
86 | archivePrefix={arXiv},
87 | primaryClass={cs.CV}
88 | }
89 | ```
90 |
91 | ## Author
92 | [Junho Kim](http://bit.ly/jhkim_ai), Minjae Kim, Hyeonwoo Kang, Kwanghee Lee
93 |
--------------------------------------------------------------------------------
/UGATIT.py:
--------------------------------------------------------------------------------
1 | import time, itertools
2 | from dataset import ImageFolder
3 | from torchvision import transforms
4 | from torch.utils.data import DataLoader
5 | from networks import *
6 | from utils import *
7 | from glob import glob
8 |
9 | class UGATIT(object) :
10 | def __init__(self, args):
11 | self.light = args.light
12 |
13 | if self.light :
14 | self.model_name = 'UGATIT_light'
15 | else :
16 | self.model_name = 'UGATIT'
17 |
18 | self.result_dir = args.result_dir
19 | self.dataset = args.dataset
20 |
21 | self.iteration = args.iteration
22 | self.decay_flag = args.decay_flag
23 |
24 | self.batch_size = args.batch_size
25 | self.print_freq = args.print_freq
26 | self.save_freq = args.save_freq
27 |
28 | self.lr = args.lr
29 | self.weight_decay = args.weight_decay
30 | self.ch = args.ch
31 |
32 | """ Weight """
33 | self.adv_weight = args.adv_weight
34 | self.cycle_weight = args.cycle_weight
35 | self.identity_weight = args.identity_weight
36 | self.cam_weight = args.cam_weight
37 |
38 | """ Generator """
39 | self.n_res = args.n_res
40 |
41 | """ Discriminator """
42 | self.n_dis = args.n_dis
43 |
44 | self.img_size = args.img_size
45 | self.img_ch = args.img_ch
46 |
47 | self.device = args.device
48 | self.benchmark_flag = args.benchmark_flag
49 | self.resume = args.resume
50 |
51 | if torch.backends.cudnn.enabled and self.benchmark_flag:
52 | print('set benchmark !')
53 | torch.backends.cudnn.benchmark = True
54 |
55 | print()
56 |
57 | print("##### Information #####")
58 | print("# light : ", self.light)
59 | print("# dataset : ", self.dataset)
60 | print("# batch_size : ", self.batch_size)
61 | print("# iteration per epoch : ", self.iteration)
62 |
63 | print()
64 |
65 | print("##### Generator #####")
66 | print("# residual blocks : ", self.n_res)
67 |
68 | print()
69 |
70 | print("##### Discriminator #####")
71 | print("# discriminator layer : ", self.n_dis)
72 |
73 | print()
74 |
75 | print("##### Weight #####")
76 | print("# adv_weight : ", self.adv_weight)
77 | print("# cycle_weight : ", self.cycle_weight)
78 | print("# identity_weight : ", self.identity_weight)
79 | print("# cam_weight : ", self.cam_weight)
80 |
81 | ##################################################################################
82 | # Model
83 | ##################################################################################
84 |
85 | def build_model(self):
86 | """ DataLoader """
87 | train_transform = transforms.Compose([
88 | transforms.RandomHorizontalFlip(),
89 | transforms.Resize((self.img_size + 30, self.img_size+30)),
90 | transforms.RandomCrop(self.img_size),
91 | transforms.ToTensor(),
92 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
93 | ])
94 | test_transform = transforms.Compose([
95 | transforms.Resize((self.img_size, self.img_size)),
96 | transforms.ToTensor(),
97 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
98 | ])
99 |
100 | self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform)
101 | self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform)
102 | self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform)
103 | self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform)
104 | self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True)
105 | self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True)
106 | self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
107 | self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)
108 |
109 | """ Define Generator, Discriminator """
110 | self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch, n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
111 | self.genB2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch, n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
112 | self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
113 | self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
114 | self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
115 | self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
116 |
117 | """ Define Loss """
118 | self.L1_loss = nn.L1Loss().to(self.device)
119 | self.MSE_loss = nn.MSELoss().to(self.device)
120 | self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)
121 |
122 | """ Trainer """
123 | self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
124 | self.D_optim = torch.optim.Adam(itertools.chain(self.disGA.parameters(), self.disGB.parameters(), self.disLA.parameters(), self.disLB.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
125 |
126 | """ Define Rho clipper to constraint the value of rho in AdaILN and ILN"""
127 | self.Rho_clipper = RhoClipper(0, 1)
128 |
129 | def train(self):
130 | self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train()
131 |
132 | start_iter = 1
133 | if self.resume:
134 | model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
135 | if not len(model_list) == 0:
136 | model_list.sort()
137 | start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
138 | self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter)
139 | print(" [*] Load SUCCESS")
140 | if self.decay_flag and start_iter > (self.iteration // 2):
141 | self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
142 | self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
143 |
144 | # training loop
145 | print('training start !')
146 | start_time = time.time()
147 | for step in range(start_iter, self.iteration + 1):
148 | if self.decay_flag and step > (self.iteration // 2):
149 | self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))
150 | self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))
151 |
152 | try:
153 | real_A, _ = next(trainA_iter)
154 | except (NameError, StopIteration):
155 | trainA_iter = iter(self.trainA_loader)
156 | real_A, _ = next(trainA_iter)
157 |
158 | try:
159 | real_B, _ = next(trainB_iter)
160 | except (NameError, StopIteration):
161 | trainB_iter = iter(self.trainB_loader)
162 | real_B, _ = next(trainB_iter)
163 |
164 | real_A, real_B = real_A.to(self.device), real_B.to(self.device)
165 |
166 | # Update D
167 | self.D_optim.zero_grad()
168 |
169 | fake_A2B, _, _ = self.genA2B(real_A)
170 | fake_B2A, _, _ = self.genB2A(real_B)
171 |
172 | real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
173 | real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
174 | real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
175 | real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)
176 |
177 | fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
178 | fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
179 | fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
180 | fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
181 |
182 | D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))
183 | D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))
184 | D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(self.device)) + self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device))
185 | D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(self.device)) + self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device))
186 | D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(self.device)) + self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device))
187 | D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(self.device)) + self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device))
188 | D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(self.device)) + self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device))
189 | D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(self.device)) + self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device))
190 |
191 | D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
192 | D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)
193 |
194 | Discriminator_loss = D_loss_A + D_loss_B
195 | Discriminator_loss.backward()
196 | self.D_optim.step()
197 |
198 | # Update G
199 | self.G_optim.zero_grad()
200 |
201 | fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
202 | fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)
203 |
204 | fake_A2B2A, _, _ = self.genB2A(fake_A2B)
205 | fake_B2A2B, _, _ = self.genA2B(fake_B2A)
206 |
207 | fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
208 | fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)
209 |
210 | fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
211 | fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
212 | fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
213 | fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
214 |
215 | G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
216 | G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device))
217 | G_ad_loss_LA = self.MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device))
218 | G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device))
219 | G_ad_loss_GB = self.MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device))
220 | G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device))
221 | G_ad_loss_LB = self.MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device))
222 | G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device))
223 |
224 | G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
225 | G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)
226 |
227 | G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
228 | G_identity_loss_B = self.L1_loss(fake_B2B, real_B)
229 |
230 | G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
231 | G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))
232 |
233 | G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
234 | G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B
235 |
236 | Generator_loss = G_loss_A + G_loss_B
237 | Generator_loss.backward()
238 | self.G_optim.step()
239 |
240 | # clip parameter of AdaILN and ILN, applied after optimizer step
241 | self.genA2B.apply(self.Rho_clipper)
242 | self.genB2A.apply(self.Rho_clipper)
243 |
244 | print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (step, self.iteration, time.time() - start_time, Discriminator_loss, Generator_loss))
245 | if step % self.print_freq == 0:
246 | train_sample_num = 5
247 | test_sample_num = 5
248 | A2B = np.zeros((self.img_size * 7, 0, 3))
249 | B2A = np.zeros((self.img_size * 7, 0, 3))
250 |
251 | self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
252 | for _ in range(train_sample_num):
253 | try:
254 | real_A, _ = next(trainA_iter)
255 | except (NameError, StopIteration):
256 | trainA_iter = iter(self.trainA_loader)
257 | real_A, _ = next(trainA_iter)
258 |
259 | try:
260 | real_B, _ = next(trainB_iter)
261 | except (NameError, StopIteration):
262 | trainB_iter = iter(self.trainB_loader)
263 | real_B, _ = next(trainB_iter)
264 | real_A, real_B = real_A.to(self.device), real_B.to(self.device)
265 |
266 | fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
267 | fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)
268 |
269 | fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
270 | fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)
271 |
272 | fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
273 | fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)
274 |
275 | A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
276 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
277 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
278 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
279 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
280 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
281 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
282 |
283 | B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
284 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
285 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
286 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
287 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
288 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
289 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)
290 |
291 | for _ in range(test_sample_num):
292 | try:
293 | real_A, _ = next(testA_iter)
294 | except (NameError, StopIteration):
295 | testA_iter = iter(self.testA_loader)
296 | real_A, _ = next(testA_iter)
297 |
298 | try:
299 | real_B, _ = next(testB_iter)
300 | except (NameError, StopIteration):
301 | testB_iter = iter(self.testB_loader)
302 | real_B, _ = next(testB_iter)
303 | real_A, real_B = real_A.to(self.device), real_B.to(self.device)
304 |
305 | fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
306 | fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)
307 |
308 | fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
309 | fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)
310 |
311 | fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
312 | fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)
313 |
314 | A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
315 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
316 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
317 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
318 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
319 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
320 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
321 |
322 | B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
323 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
324 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
325 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
326 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
327 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
328 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)
329 |
330 | cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'A2B_%07d.png' % step), A2B * 255.0)
331 | cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'B2A_%07d.png' % step), B2A * 255.0)
332 | self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train()
333 |
334 | if step % self.save_freq == 0:
335 | self.save(os.path.join(self.result_dir, self.dataset, 'model'), step)
336 |
337 | if step % 1000 == 0:
338 | params = {}
339 | params['genA2B'] = self.genA2B.state_dict()
340 | params['genB2A'] = self.genB2A.state_dict()
341 | params['disGA'] = self.disGA.state_dict()
342 | params['disGB'] = self.disGB.state_dict()
343 | params['disLA'] = self.disLA.state_dict()
344 | params['disLB'] = self.disLB.state_dict()
345 | torch.save(params, os.path.join(self.result_dir, self.dataset + '_params_latest.pt'))
346 |
347 | def save(self, dir, step):
348 | params = {}
349 | params['genA2B'] = self.genA2B.state_dict()
350 | params['genB2A'] = self.genB2A.state_dict()
351 | params['disGA'] = self.disGA.state_dict()
352 | params['disGB'] = self.disGB.state_dict()
353 | params['disLA'] = self.disLA.state_dict()
354 | params['disLB'] = self.disLB.state_dict()
355 | torch.save(params, os.path.join(dir, self.dataset + '_params_%07d.pt' % step))
356 |
357 | def load(self, dir, step):
358 | params = torch.load(os.path.join(dir, self.dataset + '_params_%07d.pt' % step))
359 | self.genA2B.load_state_dict(params['genA2B'])
360 | self.genB2A.load_state_dict(params['genB2A'])
361 | self.disGA.load_state_dict(params['disGA'])
362 | self.disGB.load_state_dict(params['disGB'])
363 | self.disLA.load_state_dict(params['disLA'])
364 | self.disLB.load_state_dict(params['disLB'])
365 |
366 | def test(self):
367 | model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
368 | if not len(model_list) == 0:
369 | model_list.sort()
370 | latest_step = int(model_list[-1].split('_')[-1].split('.')[0])
371 | self.load(os.path.join(self.result_dir, self.dataset, 'model'), latest_step)
372 | print(" [*] Load SUCCESS")
373 | else:
374 | print(" [*] Load FAILURE")
375 | return
376 |
377 | self.genA2B.eval(), self.genB2A.eval()
378 | for n, (real_A, _) in enumerate(self.testA_loader):
379 | real_A = real_A.to(self.device)
380 |
381 | fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
382 |
383 | fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
384 |
385 | fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
386 |
387 | A2B = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
388 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
389 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
390 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
391 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
392 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
393 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)
394 |
395 | cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'A2B_%d.png' % (n + 1)), A2B * 255.0)
396 |
397 | for n, (real_B, _) in enumerate(self.testB_loader):
398 | real_B = real_B.to(self.device)
399 |
400 | fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)
401 |
402 | fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)
403 |
404 | fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)
405 |
406 | B2A = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
407 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
408 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
409 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
410 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
411 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
412 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)
413 |
414 | cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'B2A_%d.png' % (n + 1)), B2A * 255.0)
415 |
--------------------------------------------------------------------------------
/assets/ablation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/ablation.png
--------------------------------------------------------------------------------
/assets/discriminator.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/discriminator.png
--------------------------------------------------------------------------------
/assets/generator.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/generator.png
--------------------------------------------------------------------------------
/assets/kid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/kid.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/teaser.png
--------------------------------------------------------------------------------
/assets/user_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/assets/user_study.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 |
5 | import os
6 | import os.path
7 |
8 |
9 | def has_file_allowed_extension(filename, extensions):
10 | """Checks if a file is an allowed extension.
11 |
12 | Args:
13 | filename (string): path to a file
14 |
15 | Returns:
16 | bool: True if the filename ends with a known image extension
17 | """
18 | filename_lower = filename.lower()
19 | return any(filename_lower.endswith(ext) for ext in extensions)
20 |
21 |
22 | def find_classes(dir):
23 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
24 | classes.sort()
25 | class_to_idx = {classes[i]: i for i in range(len(classes))}
26 | return classes, class_to_idx
27 |
28 |
29 | def make_dataset(dir, extensions):
30 | images = []
31 | for root, _, fnames in sorted(os.walk(dir)):
32 | for fname in sorted(fnames):
33 | if has_file_allowed_extension(fname, extensions):
34 | path = os.path.join(root, fname)
35 | item = (path, 0)
36 | images.append(item)
37 |
38 | return images
39 |
40 |
41 | class DatasetFolder(data.Dataset):
42 | def __init__(self, root, loader, extensions, transform=None, target_transform=None):
43 | # classes, class_to_idx = find_classes(root)
44 | samples = make_dataset(root, extensions)
45 | if len(samples) == 0:
46 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
47 | "Supported extensions are: " + ",".join(extensions)))
48 |
49 | self.root = root
50 | self.loader = loader
51 | self.extensions = extensions
52 | self.samples = samples
53 |
54 | self.transform = transform
55 | self.target_transform = target_transform
56 |
57 | def __getitem__(self, index):
58 | """
59 | Args:
60 | index (int): Index
61 |
62 | Returns:
63 | tuple: (sample, target) where target is class_index of the target class.
64 | """
65 | path, target = self.samples[index]
66 | sample = self.loader(path)
67 | if self.transform is not None:
68 | sample = self.transform(sample)
69 | if self.target_transform is not None:
70 | target = self.target_transform(target)
71 |
72 | return sample, target
73 |
74 | def __len__(self):
75 | return len(self.samples)
76 |
77 | def __repr__(self):
78 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
79 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
80 | fmt_str += ' Root Location: {}\n'.format(self.root)
81 | tmp = ' Transforms (if any): '
82 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
83 | tmp = ' Target Transforms (if any): '
84 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
85 | return fmt_str
86 |
87 |
88 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
89 |
90 |
91 | def pil_loader(path):
92 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
93 | with open(path, 'rb') as f:
94 | img = Image.open(f)
95 | return img.convert('RGB')
96 |
97 |
98 | def default_loader(path):
99 | return pil_loader(path)
100 |
101 |
102 | class ImageFolder(DatasetFolder):
103 | def __init__(self, root, transform=None, target_transform=None,
104 | loader=default_loader):
105 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
106 | transform=transform,
107 | target_transform=target_transform)
108 | self.imgs = self.samples
109 |
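
A minimal usage sketch of this `ImageFolder` (mirroring how `UGATIT.build_model` consumes it; the dataset path below is illustrative):

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import ImageFolder

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
trainA = ImageFolder('dataset/YOUR_DATASET_NAME/trainA', transform)  # flat folder of images
loader = DataLoader(trainA, batch_size=1, shuffle=True)
img, target = next(iter(loader))  # img: (1, 3, 256, 256) in [-1, 1]; target is always 0
```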
--------------------------------------------------------------------------------
/dataset/YOUR_DATASET_NAME/testA/female_2321.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/dataset/YOUR_DATASET_NAME/testA/female_2321.jpg
--------------------------------------------------------------------------------
/dataset/YOUR_DATASET_NAME/testB/3414.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/dataset/YOUR_DATASET_NAME/testB/3414.png
--------------------------------------------------------------------------------
/dataset/YOUR_DATASET_NAME/trainA/female_222.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/dataset/YOUR_DATASET_NAME/trainA/female_222.jpg
--------------------------------------------------------------------------------
/dataset/YOUR_DATASET_NAME/trainB/0006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liaoxy169/UGATIT-pytorch/a68011d6365725a4d82a9640c857d1d606c9b138/dataset/YOUR_DATASET_NAME/trainB/0006.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from UGATIT import UGATIT
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 |
7 | def parse_args():
8 | desc = "Pytorch implementation of U-GAT-IT"
9 | parser = argparse.ArgumentParser(description=desc)
10 | parser.add_argument('--phase', type=str, default='train', help='[train / test]')
11 | parser.add_argument('--light', type=str2bool, default=False, help='[U-GAT-IT full version / U-GAT-IT light version]')
12 | parser.add_argument('--dataset', type=str, default='YOUR_DATASET_NAME', help='dataset_name')
13 |
14 | parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations')
15 | parser.add_argument('--batch_size', type=int, default=1, help='The batch size')
16 | parser.add_argument('--print_freq', type=int, default=1000, help='Image sampling frequency (in iterations)')
17 | parser.add_argument('--save_freq', type=int, default=100000, help='Model save frequency (in iterations)')
18 | parser.add_argument('--decay_flag', type=str2bool, default=True, help='Whether to linearly decay the learning rate over the second half of training')
19 |
20 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
21 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='The weight decay')
22 | parser.add_argument('--adv_weight', type=int, default=1, help='Weight for GAN')
23 | parser.add_argument('--cycle_weight', type=int, default=10, help='Weight for Cycle')
24 | parser.add_argument('--identity_weight', type=int, default=10, help='Weight for Identity')
25 | parser.add_argument('--cam_weight', type=int, default=1000, help='Weight for CAM')
26 |
27 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
28 | parser.add_argument('--n_res', type=int, default=4, help='The number of resblock')
29 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
30 |
31 | parser.add_argument('--img_size', type=int, default=256, help='The size of the input image')
32 | parser.add_argument('--img_ch', type=int, default=3, help='The number of image channels')
33 |
34 | parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the results')
35 | parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'], help='Set gpu mode; [cpu, cuda]')
36 | parser.add_argument('--benchmark_flag', type=str2bool, default=False)
37 | parser.add_argument('--resume', type=str2bool, default=False)
38 |
39 | return check_args(parser.parse_args())
40 |
41 | """checking arguments"""
42 | def check_args(args):
43 | # --result_dir
44 | check_folder(os.path.join(args.result_dir, args.dataset, 'model'))
45 | check_folder(os.path.join(args.result_dir, args.dataset, 'img'))
46 | check_folder(os.path.join(args.result_dir, args.dataset, 'test'))
47 |
48 | # --iteration
49 | try:
50 | assert args.iteration >= 1
51 | except:
52 | print('number of iterations must be larger than or equal to one')
53 |
54 | # --batch_size
55 | try:
56 | assert args.batch_size >= 1
57 | except:
58 | print('batch size must be larger than or equal to one')
59 | return args
60 |
61 | """main"""
62 | def main():
63 | # parse arguments
64 | args = parse_args()
65 | if args is None:
66 | exit()
67 |
68 | # open session
69 | gan = UGATIT(args)
70 |
71 | # build graph
72 | gan.build_model()
73 |
74 | if args.phase == 'train' :
75 | gan.train()
76 | print(" [*] Training finished!")
77 |
78 | if args.phase == 'test' :
79 | gan.test()
80 | print(" [*] Test finished!")
81 |
82 | if __name__ == '__main__':
83 | main()
84 |
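
A minimal sketch of resuming training from the most recent checkpoint saved under `results/<dataset>/model` (flags as defined above; the dataset name is illustrative):

```
> python main.py --dataset selfie2anime --resume True
```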
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 |
5 |
6 | class ResnetGenerator(nn.Module):
7 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, img_size=256, light=False):
8 | assert(n_blocks >= 0)
9 | super(ResnetGenerator, self).__init__()
10 | self.input_nc = input_nc
11 | self.output_nc = output_nc
12 | self.ngf = ngf
13 | self.n_blocks = n_blocks
14 | self.img_size = img_size
15 | self.light = light
16 |
17 | DownBlock = []
18 | DownBlock += [nn.ReflectionPad2d(3),
19 | nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False),
20 | nn.InstanceNorm2d(ngf),
21 | nn.ReLU(True)]
22 |
23 | # Down-Sampling
24 | n_downsampling = 2
25 | for i in range(n_downsampling):
26 | mult = 2**i
27 | DownBlock += [nn.ReflectionPad2d(1),
28 | nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False),
29 | nn.InstanceNorm2d(ngf * mult * 2),
30 | nn.ReLU(True)]
31 |
32 | # Down-Sampling Bottleneck
33 | mult = 2**n_downsampling
34 | for i in range(n_blocks):
35 | DownBlock += [ResnetBlock(ngf * mult, use_bias=False)]
36 |
37 | # Class Activation Map
38 | self.gap_fc = nn.Linear(ngf * mult, 1, bias=False)
39 | self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False)
40 | self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True)
41 | self.relu = nn.ReLU(True)
42 |
43 | # Gamma, Beta block
44 | if self.light:
45 | FC = [nn.Linear(ngf * mult, ngf * mult, bias=False),
46 | nn.ReLU(True),
47 | nn.Linear(ngf * mult, ngf * mult, bias=False),
48 | nn.ReLU(True)]
49 | else:
50 | FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False),
51 | nn.ReLU(True),
52 | nn.Linear(ngf * mult, ngf * mult, bias=False),
53 | nn.ReLU(True)]
54 | self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False)
55 | self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False)
56 |
57 | # Up-Sampling Bottleneck
58 | for i in range(n_blocks):
59 | setattr(self, 'UpBlock1_' + str(i+1), ResnetAdaILNBlock(ngf * mult, use_bias=False))
60 |
61 | # Up-Sampling
62 | UpBlock2 = []
63 | for i in range(n_downsampling):
64 | mult = 2**(n_downsampling - i)
65 | UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'),
66 | nn.ReflectionPad2d(1),
67 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=False),
68 | ILN(int(ngf * mult / 2)),
69 | nn.ReLU(True)]
70 |
71 | UpBlock2 += [nn.ReflectionPad2d(3),
72 | nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=False),
73 | nn.Tanh()]
74 |
75 | self.DownBlock = nn.Sequential(*DownBlock)
76 | self.FC = nn.Sequential(*FC)
77 | self.UpBlock2 = nn.Sequential(*UpBlock2)
78 |
79 | def forward(self, input):
80 | x = self.DownBlock(input)
81 |
82 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
83 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
84 | gap_weight = list(self.gap_fc.parameters())[0]
85 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
86 |
87 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
88 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
89 | gmp_weight = list(self.gmp_fc.parameters())[0]
90 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
91 |
92 | cam_logit = torch.cat([gap_logit, gmp_logit], 1)
93 | x = torch.cat([gap, gmp], 1)
94 | x = self.relu(self.conv1x1(x))
95 |
96 | heatmap = torch.sum(x, dim=1, keepdim=True)
97 |
98 | if self.light:
99 | x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
100 | x_ = self.FC(x_.view(x_.shape[0], -1))
101 | else:
102 | x_ = self.FC(x.view(x.shape[0], -1))
103 | gamma, beta = self.gamma(x_), self.beta(x_)
104 |
105 |
106 | for i in range(self.n_blocks):
107 | x = getattr(self, 'UpBlock1_' + str(i+1))(x, gamma, beta)
108 | out = self.UpBlock2(x)
109 |
110 | return out, cam_logit, heatmap
111 |
112 |
113 | class ResnetBlock(nn.Module):
114 | def __init__(self, dim, use_bias):
115 | super(ResnetBlock, self).__init__()
116 | conv_block = []
117 | conv_block += [nn.ReflectionPad2d(1),
118 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
119 | nn.InstanceNorm2d(dim),
120 | nn.ReLU(True)]
121 |
122 | conv_block += [nn.ReflectionPad2d(1),
123 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
124 | nn.InstanceNorm2d(dim)]
125 |
126 | self.conv_block = nn.Sequential(*conv_block)
127 |
128 | def forward(self, x):
129 | out = x + self.conv_block(x)
130 | return out
131 |
132 |
133 | class ResnetAdaILNBlock(nn.Module):
134 | def __init__(self, dim, use_bias):
135 | super(ResnetAdaILNBlock, self).__init__()
136 | self.pad1 = nn.ReflectionPad2d(1)
137 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
138 | self.norm1 = adaILN(dim)
139 | self.relu1 = nn.ReLU(True)
140 |
141 | self.pad2 = nn.ReflectionPad2d(1)
142 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
143 | self.norm2 = adaILN(dim)
144 |
145 | def forward(self, x, gamma, beta):
146 | out = self.pad1(x)
147 | out = self.conv1(out)
148 | out = self.norm1(out, gamma, beta)
149 | out = self.relu1(out)
150 | out = self.pad2(out)
151 | out = self.conv2(out)
152 | out = self.norm2(out, gamma, beta)
153 |
154 | return out
155 |
156 |
157 | class adaILN(nn.Module):
158 | def __init__(self, num_features, eps=1e-5):
159 | super(adaILN, self).__init__()
160 | self.eps = eps
161 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
162 | self.rho.data.fill_(0.9)
163 |
164 | def forward(self, input, gamma, beta):
165 | in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
166 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
167 | ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
168 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
169 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
170 | out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
171 |
172 | return out
173 |
174 |
175 | class ILN(nn.Module):
176 | def __init__(self, num_features, eps=1e-5):
177 | super(ILN, self).__init__()
178 | self.eps = eps
179 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
180 | self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
181 | self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
182 | self.rho.data.fill_(0.0)
183 | self.gamma.data.fill_(1.0)
184 | self.beta.data.fill_(0.0)
185 |
186 | def forward(self, input):
187 | in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
188 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
189 | ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
190 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
191 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
192 | out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
193 |
194 | return out
195 |
196 |
197 | class Discriminator(nn.Module):
198 | def __init__(self, input_nc, ndf=64, n_layers=5):
199 | super(Discriminator, self).__init__()
200 | model = [nn.ReflectionPad2d(1),
201 | nn.utils.spectral_norm(
202 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
203 | nn.LeakyReLU(0.2, True)]
204 |
205 | for i in range(1, n_layers - 2):
206 | mult = 2 ** (i - 1)
207 | model += [nn.ReflectionPad2d(1),
208 | nn.utils.spectral_norm(
209 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
210 | nn.LeakyReLU(0.2, True)]
211 |
212 | mult = 2 ** (n_layers - 2 - 1)
213 | model += [nn.ReflectionPad2d(1),
214 | nn.utils.spectral_norm(
215 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
216 | nn.LeakyReLU(0.2, True)]
217 |
218 | # Class Activation Map
219 | mult = 2 ** (n_layers - 2)
220 | self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
221 | self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
222 | self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
223 | self.leaky_relu = nn.LeakyReLU(0.2, True)
224 |
225 | self.pad = nn.ReflectionPad2d(1)
226 | self.conv = nn.utils.spectral_norm(
227 | nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))
228 |
229 | self.model = nn.Sequential(*model)
230 |
231 | def forward(self, input):
232 | x = self.model(input)
233 |
234 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
235 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
236 | gap_weight = list(self.gap_fc.parameters())[0]
237 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
238 |
239 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
240 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
241 | gmp_weight = list(self.gmp_fc.parameters())[0]
242 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
243 |
244 | cam_logit = torch.cat([gap_logit, gmp_logit], 1)
245 | x = torch.cat([gap, gmp], 1)
246 | x = self.leaky_relu(self.conv1x1(x))
247 |
248 | heatmap = torch.sum(x, dim=1, keepdim=True)
249 |
250 | x = self.pad(x)
251 | out = self.conv(x)
252 |
253 | return out, cam_logit, heatmap
254 |
255 |
256 | class RhoClipper(object):
257 |
258 | def __init__(self, min, max):
259 | self.clip_min = min
260 | self.clip_max = max
261 | assert min < max
262 |
263 | def __call__(self, module):
264 |
265 | if hasattr(module, 'rho'):
266 | w = module.rho.data
267 | w = w.clamp(self.clip_min, self.clip_max)
268 | module.rho.data = w
269 |
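
A minimal smoke-test sketch for the generator and rho clipping (mirroring how `UGATIT.py` drives them; the sizes are illustrative, and `light=True` keeps the fully connected block small):

```python
import torch
from networks import ResnetGenerator, RhoClipper

gen = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=4, img_size=256, light=True)
clipper = RhoClipper(0, 1)

x = torch.randn(1, 3, 256, 256)
out, cam_logit, heatmap = gen(x)   # out: (1, 3, 256, 256), cam_logit: (1, 2), heatmap: (1, 1, 64, 64)
gen.apply(clipper)                 # keep every rho parameter (adaILN / ILN) inside [0, 1]
```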
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # NOTE: scipy.misc image I/O (imread/imresize/imsave) was removed from SciPy; cv2 is used instead
2 | import os, cv2, torch
3 | import numpy as np
4 |
5 | def load_test_data(image_path, size=256):
6 | img = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
7 | img = cv2.resize(img, (size, size))
8 | img = np.expand_dims(img, axis=0)
9 | img = preprocessing(img)
10 |
11 | return img
12 |
13 | def preprocessing(x):
14 | x = x/127.5 - 1 # -1 ~ 1
15 | return x
16 |
17 | def save_images(images, size, image_path):
18 | return imsave(inverse_transform(images), size, image_path)
19 |
20 | def inverse_transform(images):
21 | return (images+1.) / 2
22 |
23 | def imsave(images, size, path):
24 | return cv2.imwrite(path, cv2.cvtColor((merge(images, size) * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR))
25 |
26 | def merge(images, size):
27 | h, w = images.shape[1], images.shape[2]
28 | img = np.zeros((h * size[0], w * size[1], 3))
29 | for idx, image in enumerate(images):
30 | i = idx % size[1]
31 | j = idx // size[1]
32 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
33 |
34 | return img
35 |
36 | def check_folder(log_dir):
37 | if not os.path.exists(log_dir):
38 | os.makedirs(log_dir)
39 | return log_dir
40 |
41 | def str2bool(x):
42 | return x.lower() in ('true',)
43 |
44 | def cam(x, size = 256):
45 | x = x - np.min(x)
46 | cam_img = x / np.max(x)
47 | cam_img = np.uint8(255 * cam_img)
48 | cam_img = cv2.resize(cam_img, (size, size))
49 | cam_img = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET)
50 | return cam_img / 255.0
51 |
52 | def imagenet_norm(x):
53 | mean = [0.485, 0.456, 0.406]
54 | std = [0.229, 0.224, 0.225]
55 | mean = torch.FloatTensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)
56 | std = torch.FloatTensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)
57 | return (x - mean) / std
58 |
59 | def denorm(x):
60 | return x * 0.5 + 0.5
61 |
62 | def tensor2numpy(x):
63 | return x.detach().cpu().numpy().transpose(1,2,0)
64 |
65 | def RGB2BGR(x):
66 | return cv2.cvtColor(x, cv2.COLOR_RGB2BGR)
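
A minimal sketch of how these helpers compose when saving a sample image (the same pattern `UGATIT.py` uses; the tensor below is a stand-in for a generator output in [-1, 1]):

```python
import cv2, torch
from utils import denorm, tensor2numpy, RGB2BGR

fake = torch.tanh(torch.randn(3, 256, 256))  # pretend generator output, CHW in [-1, 1]
img = RGB2BGR(tensor2numpy(denorm(fake)))    # HWC float image in [0, 1], BGR channel order
cv2.imwrite('sample.png', img * 255.0)       # scale to [0, 255] before writing, as in UGATIT.py
```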
--------------------------------------------------------------------------------