├── .gitignore
├── LICENSE
├── README.md
├── UGATIT_inference.py
├── UGATIT_train.py
├── example.png
├── module
│   ├── base_module.py
│   ├── dataloader.py
│   ├── discriminator.py
│   ├── generator.py
│   └── importer.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | dataset/
2 | __pycache__/
3 | output/
4 | trained_model/
5 | conversion/
6 | .DS_Store
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 zassou65535
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # image_converter
2 | ## Overview
3 | 
4 | 
5 | 
6 | A PyTorch implementation of U-GAT-IT, as used for the "Mel icon converter" (メルアイコン変換器).
7 | A detailed explanation (in Japanese) is available here.
8 | 
9 | ## Expected environment
10 | Python 3.7.1, 32 GB or more of VRAM
11 | torch==1.4.0
12 | torchvision==0.5.0
13 | The required libraries can be installed with `pip install -r requirements.txt`.
14 | 
15 | ## Scripts
16 | * `UGATIT_train.py` runs training, and outputs the training progress and the trained model.
17 | * `UGATIT_inference.py` loads the trained model produced by `UGATIT_train.py`, runs inference (image conversion), and outputs the generated images.
18 | 
19 | ## Usage
20 | In the following, the source domain is called A and the target domain is called B.
21 | ### Running training
22 | 1. Create a `./dataset` directory in the directory containing `UGATIT_train.py`.
23 | 1. Inside `./dataset`, create the two directories `group_A` and `group_B`.
24 | 1. Put any number of images belonging to A into `./dataset/group_A`, following the layout `./dataset/group_A/*/*` (PNG format); see the example layout after this list.
25 | 1. Put any number of images belonging to B into `./dataset/group_B`, following the layout `./dataset/group_B/*/*` (PNG format).
26 | 1. Run `python UGATIT_train.py` in the directory containing `UGATIT_train.py` to train the model toward A⇄B conversion.
27 |     * The training progress is written under `./output`.
28 |     * The trained model is saved as `./trained_model/generator_A2B_trained_model_cpu.pth`.
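For reference, here is one dataset layout that matches the `./dataset/group_A/*/*` and `./dataset/group_B/*/*` patterns above (the subdirectory and file names are only placeholders):

```
dataset/
├── group_A
│   └── any_subdirectory
│       ├── 0001.png
│       └── 0002.png
└── group_B
    └── any_subdirectory
        ├── 0001.png
        └── 0002.png
```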
29 | ### Running inference
30 | 1. Create a `./conversion` directory in the directory containing `UGATIT_inference.py`.
31 | 1. Create a `target` directory inside `./conversion` and put any number of images belonging to A into it.
32 | 1. Run `python UGATIT_inference.py` in the directory containing `UGATIT_inference.py` to convert the images in `./conversion/target` to domain B.
33 |     * The A→B conversion results are written under `./conversion/converted/`.
34 |     * Note that an error occurs if the trained model `generator_A2B_trained_model_cpu.pth` is not present in `./trained_model`.
35 | 
36 | Depending on the environment, training can take 12 hours or more.
37 | Input images are resized to 256×256 before being used for training. Output images are also 256×256.
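For programmatic use, the following is a minimal sketch of converting a single image with the trained A→B generator. It mirrors what `UGATIT_inference.py` does; the file names `input.png` and `converted.png` are placeholders.

```python
import torch
import torchvision.utils as vutils
from PIL import Image

from module.generator import Generator
from module.dataloader import ImageTransform

#Load the trained A->B generator produced by UGATIT_train.py
generator = Generator()
generator.load_state_dict(torch.load('./trained_model/generator_A2B_trained_model_cpu.pth'))
generator.eval()

#Preprocess one image the same way the inference script does (resize to 256x256, normalize)
img = Image.open('input.png').convert('RGB')
x = ImageTransform(256)(img).unsqueeze(0)  #shape: [1, 3, 256, 256]

#The generator also returns CAM logits and a heatmap, which are ignored here
with torch.no_grad():
    converted, _, _ = generator(x)

vutils.save_image(converted, 'converted.png', normalize=True)
```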
38 |
39 | ## References
40 | U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation
41 | U-GAT-IT — Official PyTorch Implementation
42 |
--------------------------------------------------------------------------------
/UGATIT_inference.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 |
3 | from module.importer import *
4 | from module.dataloader import *
5 | from module.generator import *
6 |
7 | #Load the trained model
8 | generator = Generator()
9 | generator.load_state_dict(torch.load('./trained_model/generator_A2B_trained_model_cpu.pth'))
10 | #Switch to inference (eval) mode
11 | generator.eval()
12 | #Load the images to be converted
13 | path_list = make_datapath_list('./conversion/target/**/*')
14 | train_dataset = GAN_Img_Dataset(file_list=path_list,transform=ImageTransform(256))
15 | dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=len(path_list),shuffle=False)
16 | target = next(iter(dataloader))
17 | #Feed the images to the generator and get the converted images
18 | converted,_,_ = generator(target)
19 | #Create the output directory for the converted images
20 | os.makedirs("./conversion/converted",exist_ok=True)
21 | #Save the converted images
22 | for i,output_img in enumerate(converted):
23 | origin_filename = os.path.basename(path_list[i])
24 | origin_filename_without_ex = os.path.splitext(origin_filename)[0]
25 | filename = "./conversion/converted/{}_converted{}.png".format(origin_filename_without_ex,i)
26 | #Save to that path
27 | vutils.save_image(output_img,filename,normalize=True)
28 | print(origin_filename + " : converted")
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/UGATIT_train.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 |
3 | from module.importer import *
4 | from module.discriminator import *
5 | from module.generator import *
6 | from module.dataloader import *
7 | from module.base_module import *
8 |
9 | #Path pattern for the images of dataset A (argument for make_datapath_list)
10 | dataroot_A = './dataset/group_A/**/*'
11 | #Path pattern for the images of dataset B
12 | dataroot_B = './dataset/group_B/**/*'
13 | #Batch size
14 | batch_size = 1
15 | #Number of epochs
16 | num_epochs = 40
17 | #Learning rate for the generator and discriminator optimizers
18 | learning_rate = 0.0001
19 | #Strength of Adam's weight decay
20 | weight_decay = 0.0001
21 | #Output progress images every output_progress_interval epochs
22 | output_progress_interval = 1
23 |
24 | #Load training data A and build the dataset
25 | path_list_A = make_datapath_list(dataroot_A)
26 | transform_A = ImageModification(resize_pixel=256,x_move=[-0.05,0.05],y_move=[-0.05,0.05],min_scale=0.9)
27 | train_dataset_A = GAN_Img_Dataset(file_list=path_list_A,transform=transform_A)
28 | dataloader_A = torch.utils.data.DataLoader(train_dataset_A,batch_size=batch_size,shuffle=True)
29 |
30 | #Load training data B and build the dataset
31 | path_list_B = make_datapath_list(dataroot_B)
32 | transform_B = ImageModification(resize_pixel=256,x_move=[-0.1,0.1],y_move=[-0.1,0.25],min_scale=0.7)
33 | train_dataset_B = GAN_Img_Dataset(file_list=path_list_B,transform=transform_B)
34 | dataloader_B = torch.utils.data.DataLoader(train_dataset_B,batch_size=batch_size,shuffle=True)
35 |
36 | #Check whether a GPU is available
37 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
38 | print("device:",device)
39 |
40 | #Function for initializing the network weights
41 | def weights_init(m):
42 | classname = m.__class__.__name__
43 | if classname.find('Conv2d') != -1:
44 | #Initialize with mean 0.0 and standard deviation 0.02
45 | nn.init.normal_(m.weight.data, 0.0, 0.02)
46 | elif classname.find('ConvTranspose2d') != -1:
47 | #Initialize with mean 0.0 and standard deviation 0.02
48 | nn.init.normal_(m.weight.data, 0.0, 0.02)
49 |
50 | #Instantiate each network and move it to the device
51 | netG_A2B = Generator().to(device)
52 | netG_B2A = Generator().to(device)
53 | netD_GA = Discriminator(n_layers=7).to(device)
54 | netD_GB = Discriminator(n_layers=7).to(device)
55 | netD_LA = Discriminator(n_layers=5).to(device)
56 | netD_LB = Discriminator(n_layers=5).to(device)
57 | #Initialize the networks
58 | netG_A2B.apply(weights_init)
59 | netG_B2A.apply(weights_init)
60 | netD_GA.apply(weights_init)
61 | netD_GB.apply(weights_init)
62 | netD_LA.apply(weights_init)
63 | netD_LB.apply(weights_init)
64 |
65 | #Initialize the loss functions
66 | L1_loss = nn.L1Loss().to(device)
67 | MSE_loss = nn.MSELoss().to(device)
68 | BCE_loss = nn.BCEWithLogitsLoss().to(device)
69 | #Set up Adam optimizers for the generators and discriminators
70 | beta1 = 0.5
71 | beta2 = 0.999
72 | optimizerG = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),netG_B2A.parameters()),lr=learning_rate,betas=(beta1,beta2),weight_decay=weight_decay)
73 | optimizerD = torch.optim.Adam(itertools.chain(netD_GA.parameters(),netD_GB.parameters(),netD_LA.parameters(),netD_LB.parameters()),lr=learning_rate,betas=(beta1,beta2),weight_decay=weight_decay)
74 |
75 | #Declare the module that clips rho of the AdaILN and ILN layers in the generators to the range [0,1]
76 | Rho_Clipper = RhoClipper(0,1)
77 |
78 | #Compute how many iterations will be run in total
79 | iteration_per_epoch = len(path_list_A) if len(path_list_A) < len(path_list_B) else len(path_list_B)
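#---------------------------------------------------------------------------
#The original lines 80-144 of this file are not included in this copy.
#The block below is a hedged reconstruction, inferred only from how the
#variables (total_iteration, sample_real_A/B, real_A/B, iteration, epoch,
#the loss lists, t_epoch_start and output_how_much_progress) are used
#further down; the author's actual code may differ.
#---------------------------------------------------------------------------
total_iteration = (iteration_per_epoch // batch_size) * num_epochs

#Fixed sample batches used below for the per-epoch progress images
sample_real_A = next(iter(dataloader_A)).to(device)
sample_real_B = next(iter(dataloader_B)).to(device)

#Helper assumed by the progress-image code below:
#saves the given list of image batches as one PNG grid
def output_how_much_progress(filename, img_list):
    vutils.save_image(torch.cat(img_list, dim=0), filename, normalize=True)

#Loss history per epoch and timing
G_losses = []
D_losses = []
iteration = 0
#Record the time at the start of training
t_epoch_start = time.time()

for epoch in range(num_epochs):
    #Switch the networks (back) to training mode
    netG_A2B.train(), netG_B2A.train()
    netD_GA.train(), netD_GB.train()
    netD_LA.train(), netD_LB.train()
    G_losses_per_epoch = []
    D_losses_per_epoch = []
    #Iterate over both datasets in parallel (the shorter one limits the epoch)
    for data_A, data_B in zip(dataloader_A, dataloader_B):
        real_A = data_A.to(device)
        real_B = data_B.to(device)
        #Linearly decay the learning rate over the second half of training
        if (iteration > (total_iteration // 2)):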
145 | optimizerG.param_groups[0]['lr'] -= (learning_rate / (total_iteration//2))
146 | optimizerD.param_groups[0]['lr'] -= (learning_rate / (total_iteration//2))
147 |
148 | #-------------------------
149 | #Train the discriminators
150 | #-------------------------
151 | #Reset the gradients accumulated in the previous iteration
152 | optimizerD.zero_grad()
153 |
154 | #Generate fake images from the real images
155 | fake_A2B, _, _ = netG_A2B(real_A)
156 | fake_B2A, _, _ = netG_B2A(real_B)
157 |
158 | #Run each discriminator on the real images
159 | real_GA_logit, real_GA_cam_logit, _ = netD_GA(real_A)
160 | real_LA_logit, real_LA_cam_logit, _ = netD_LA(real_A)
161 | real_GB_logit, real_GB_cam_logit, _ = netD_GB(real_B)
162 | real_LB_logit, real_LB_cam_logit, _ = netD_LB(real_B)
163 |
164 | #Run each discriminator on the fake images
165 | fake_GA_logit, fake_GA_cam_logit, _ = netD_GA(fake_B2A)
166 | fake_LA_logit, fake_LA_cam_logit, _ = netD_LA(fake_B2A)
167 | fake_GB_logit, fake_GB_cam_logit, _ = netD_GB(fake_A2B)
168 | fake_LB_logit, fake_LB_cam_logit, _ = netD_LB(fake_A2B)
169 |
170 | #Compute the losses
171 | D_ad_loss_GA = MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(device)) + MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(device))
172 | D_ad_cam_loss_GA = MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(device)) + MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(device))
173 | D_ad_loss_LA = MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(device)) + MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(device))
174 | D_ad_cam_loss_LA = MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(device)) + MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(device))
175 | D_ad_loss_GB = MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(device)) + MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(device))
176 | D_ad_cam_loss_GB = MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(device)) + MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(device))
177 | D_ad_loss_LB = MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(device)) + MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(device))
178 | D_ad_cam_loss_LB = MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(device)) + MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(device))
179 |
180 | D_loss_A = 1*(D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
181 | D_loss_B = 1*(D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)
182 |
183 | Discriminator_loss = D_loss_A + D_loss_B
184 | #Compute the gradients
185 | Discriminator_loss.backward()
186 | #and update the discriminator parameters
187 | optimizerD.step()
188 |
189 | #Record the loss for the graph output later
190 | D_losses_per_epoch.append(Discriminator_loss.item())
191 |
192 | #-------------------------
193 | #Train the generators
194 | #-------------------------
195 | #Reset the gradients accumulated in the previous iteration
196 | optimizerG.zero_grad()
197 |
198 | #Generate fake images from the real images
199 | fake_A2B, fake_A2B_cam_logit, _ = netG_A2B(real_A)
200 | fake_B2A, fake_B2A_cam_logit, _ = netG_B2A(real_B)
201 |
202 | #Try to map the fake image back to the original (cycle consistency)
203 | fake_A2B2A, _, _ = netG_B2A(fake_A2B)
204 | fake_B2A2B, _, _ = netG_A2B(fake_B2A)
205 |
206 | #Generate fake images from real images that are already in the target domain (identity mapping)
207 | fake_A2A, fake_A2A_cam_logit, _ = netG_B2A(real_A)
208 | fake_B2B, fake_B2B_cam_logit, _ = netG_A2B(real_B)
209 |
210 | #Run each discriminator on the generated fake images
211 | fake_GA_logit, fake_GA_cam_logit, _ = netD_GA(fake_B2A)
212 | fake_LA_logit, fake_LA_cam_logit, _ = netD_LA(fake_B2A)
213 | fake_GB_logit, fake_GB_cam_logit, _ = netD_GB(fake_A2B)
214 | fake_LB_logit, fake_LB_cam_logit, _ = netD_LB(fake_A2B)
215 |
216 | #Compute the losses
217 | G_ad_loss_GA = MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(device))
218 | G_ad_cam_loss_GA = MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(device))
219 | G_ad_loss_LA = MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(device))
220 | G_ad_cam_loss_LA = MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(device))
221 | G_ad_loss_GB = MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(device))
222 | G_ad_cam_loss_GB = MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(device))
223 | G_ad_loss_LB = MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(device))
224 | G_ad_cam_loss_LB = MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(device))
225 |
226 | G_recon_loss_A = L1_loss(fake_A2B2A, real_A)
227 | G_recon_loss_B = L1_loss(fake_B2A2B, real_B)
228 |
229 | G_identity_loss_A = L1_loss(fake_A2A, real_A)
230 | G_identity_loss_B = L1_loss(fake_B2B, real_B)
231 |
232 | G_cam_loss_A = BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(device)) + BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(device))
233 | G_cam_loss_B = BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(device)) + BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(device))
234 |
235 | G_loss_A = 1*(G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + 10*G_recon_loss_A + 10*G_identity_loss_A + 1000*G_cam_loss_A
236 | G_loss_B = 1*(G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + 10*G_recon_loss_B + 10*G_identity_loss_B + 1000*G_cam_loss_B
237 |
238 | Generator_loss = G_loss_A + G_loss_B
239 | #Compute the gradients
240 | Generator_loss.backward()
241 | #and update the generator parameters
242 | optimizerG.step()
243 |
244 | #Clip the range of rho in the AdaILN and ILN layers of the generators
245 | netG_A2B.apply(Rho_Clipper)
246 | netG_B2A.apply(Rho_Clipper)
247 |
248 | #Record the loss for the graph output later
249 | G_losses_per_epoch.append(Generator_loss.item())
250 |
251 | #Print the training progress to the shell
252 | if iteration % 50 == 0:
253 | print('[%d/%d][iteration:%d]\tLoss_D: %.4f\tLoss_G: %.4f'
254 | % (epoch,num_epochs,iteration,
255 | Discriminator_loss.item(),Generator_loss.item()))
256 |
257 | iteration += 1
258 | #break for quick testing
259 | #break
260 |
261 | #Record the mean loss of this epoch for the graph output later
262 | G_losses.append(torch.mean(torch.tensor(G_losses_per_epoch,dtype=torch.float64)).item())
263 | D_losses.append(torch.mean(torch.tensor(D_losses_per_epoch,dtype=torch.float64)).item())
264 | #Record the generators' progress as images
265 | if (epoch % output_progress_interval == 0 or (epoch+1)==num_epochs):
266 | #Switch the networks to inference (eval) mode
267 | netG_A2B.eval(),netG_B2A.eval()
268 | netD_GA.eval(),netD_GB.eval()
269 | netD_LA.eval(),netD_LB.eval()
270 | #Create the image output directory if it does not exist
271 | output_dir = "./output/epoch_{}".format(epoch+1)
272 | if not os.path.exists(output_dir):
273 | os.makedirs(output_dir)
274 | #Generate and save the images
275 | #Helper that moves images from the device to the CPU
276 | def move_to_cpu(imgs_on_device):
277 | imgs_on_cpu = []
278 | for im in imgs_on_device:
279 | imgs_on_cpu.append(im.detach().cpu())
280 | return tuple(imgs_on_cpu)
281 |
282 | fake_A2B, _, fake_A2B_heatmap = netG_A2B(sample_real_A)
283 | fake_B2A, _, fake_B2A_heatmap = netG_B2A(sample_real_B)
284 |
285 | fake_A2B2A, _, fake_A2B2A_heatmap = netG_B2A(fake_A2B)
286 | fake_B2A2B, _, fake_B2A2B_heatmap = netG_A2B(fake_B2A)
287 |
288 | fake_A2B,fake_A2B_heatmap = move_to_cpu([fake_A2B,fake_A2B_heatmap])
289 | fake_B2A,fake_B2A_heatmap = move_to_cpu([fake_B2A,fake_B2A_heatmap])
290 | fake_A2B2A,fake_A2B2A_heatmap = move_to_cpu([fake_A2B2A,fake_A2B2A_heatmap])
291 | fake_B2A2B,fake_B2A2B_heatmap = move_to_cpu([fake_B2A2B,fake_B2A2B_heatmap])
292 |
293 | fake_A2A, _, fake_A2A_heatmap = netG_B2A(sample_real_A)
294 | fake_B2B, _, fake_B2B_heatmap = netG_A2B(sample_real_B)
295 |
296 | fake_A2A,fake_A2A_heatmap = move_to_cpu([fake_A2A,fake_A2A_heatmap])
297 | fake_B2B,fake_B2B_heatmap = move_to_cpu([fake_B2B,fake_B2B_heatmap])
298 | sr_A,sr_B = move_to_cpu([sample_real_A,sample_real_B])
299 | #Output the A->B->A conversion images
300 | output_how_much_progress("./output/epoch_{}/conversion_A2B2A.png".format(epoch+1),[sr_A,fake_A2B,fake_A2B2A])
301 | #Output the B->A->B conversion images
302 | output_how_much_progress("./output/epoch_{}/conversion_B2A2B.png".format(epoch+1),[sr_B,fake_B2A,fake_B2A2B])
303 | #Output the heatmaps for A
304 | fake_A2A_heatmap = F.interpolate(fake_A2A_heatmap,size=(256,256))
305 | fake_A2B_heatmap = F.interpolate(fake_A2B_heatmap,size=(256,256))
306 | fake_A2B2A_heatmap = F.interpolate(fake_A2B2A_heatmap,size=(256,256))
307 | output_how_much_progress("./output/epoch_{}/heatmap_A.png".format(epoch+1),[fake_A2A_heatmap,fake_A2B_heatmap,fake_A2B2A_heatmap])
308 | #Output the heatmaps for B
309 | fake_B2B_heatmap = F.interpolate(fake_B2B_heatmap,size=(256,256))
310 | fake_B2A_heatmap = F.interpolate(fake_B2A_heatmap,size=(256,256))
311 | fake_B2A2B_heatmap = F.interpolate(fake_B2A2B_heatmap,size=(256,256))
312 | output_how_much_progress("./output/epoch_{}/heatmap_B.png".format(epoch+1),[fake_B2B_heatmap,fake_B2A_heatmap,fake_B2A2B_heatmap])
313 | #break for quick testing
314 | #break
315 |
316 | #Output the time taken for training
317 | #Record the time at the end of training
318 | t_epoch_finish = time.time()
319 | total_time = t_epoch_finish - t_epoch_start
320 | with open('./output/time.txt', mode='w') as f:
321 | f.write("total_time: {:.4f} sec.\n".format(total_time))
322 | f.write("dataset_A size: {}\n".format(len(path_list_A)))
323 | f.write("dataset_B size: {}\n".format(len(path_list_B)))
324 | f.write("num_epochs: {}\n".format(num_epochs))
325 | f.write("batch_size: {}\n".format(batch_size))
326 |
327 | #Save the trained generator models (for CPU)
328 | #Create the model output directory if it does not exist
329 | output_dir = "./trained_model"
330 | if not os.path.exists(output_dir):
331 | os.makedirs(output_dir)
332 | torch.save(netG_A2B.to('cpu').state_dict(),'./trained_model/generator_A2B_trained_model_cpu.pth')
333 | torch.save(netG_B2A.to('cpu').state_dict(),'./trained_model/generator_B2A_trained_model_cpu.pth')
334 |
335 | #Output the loss graph
336 | plt.clf()
337 | plt.figure(figsize=(10,5))
338 | plt.title("Generator and Discriminator Loss During Training")
339 | plt.plot(G_losses,label="G")
340 | plt.plot(D_losses,label="D")
341 | plt.xlabel("epoch")
342 | plt.ylabel("Loss")
343 | plt.legend()
344 | plt.savefig('./output/loss.png')
345 |
346 |
--------------------------------------------------------------------------------
/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zassou65535/image_converter/4dd9a76154d43e6c339d5953b7b57287fd8e8204/example.png
--------------------------------------------------------------------------------
/module/base_module.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 |
3 | from .importer import *
4 |
5 | class ResnetBlock(nn.Module):
6 | def __init__(self, dim, use_bias):
7 | super(ResnetBlock,self).__init__()
8 | self.conv_block = nn.Sequential(
9 | nn.ReflectionPad2d(1),
10 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
11 | nn.InstanceNorm2d(dim),
12 | nn.ReLU(True),
13 |
14 | nn.ReflectionPad2d(1),
15 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
16 | nn.InstanceNorm2d(dim)
17 | )
18 |
19 | def forward(self, x):
20 | out = x + self.conv_block(x)
21 | return out
22 |
23 | #Same as ResnetBlock above, but with InstanceNorm2d replaced by the adaILN defined below
24 | class ResnetAdaILNBlock(nn.Module):
25 | def __init__(self, dim, use_bias):
26 | super(ResnetAdaILNBlock,self).__init__()
27 | self.pad1 = nn.ReflectionPad2d(1)
28 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
29 | self.norm1 = adaILN(dim)
30 | self.relu1 = nn.ReLU(True)
31 |
32 | self.pad2 = nn.ReflectionPad2d(1)
33 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
34 | self.norm2 = adaILN(dim)
35 |
36 | def forward(self, x, gamma, beta):
37 | out = self.pad1(x)
38 | out = self.conv1(out)
39 | out = self.norm1(out, gamma, beta)
40 | out = self.relu1(out)
41 | out = self.pad2(out)
42 | out = self.conv2(out)
43 | out = self.norm2(out, gamma, beta)
44 | return out + x
45 |
46 | #For the input tensor, this layer mixes
47 | #an instance-normalized version (each channel normalized separately) and
48 | #a layer-normalized version (the whole layer normalized at once)
49 | #using the ratio rho, which is learnable
50 | class adaILN(nn.Module):
51 | def __init__(self, num_features, eps=1e-5):
52 | super(adaILN,self).__init__()
53 | self.eps = eps
54 | #Make rho a learnable parameter
55 | self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
56 | #Initial value 0.9
57 | self.rho.data.fill_(0.9)
58 |
59 | def forward(self, input, gamma, beta):
60 | #Instance normalization: normalize each channel's feature map separately
61 | #Compute the mean and variance over the spatial dimensions of each channel separately
62 | in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
63 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
64 | #Layer normalization: normalize over all feature maps of the layer at once
65 | #Compute the mean and variance over the channel and spatial dimensions (all values of the layer)
66 | ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
67 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
68 | #Mix out_in and out_ln with ratio self.rho to form the output
69 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
70 | out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
71 | return out
72 |
73 | #Variant of adaILN in which gamma and beta are learnable parameters of the layer itself
74 | class ILN(nn.Module):
75 | def __init__(self, num_features, eps=1e-5):
76 | super(ILN,self).__init__()
77 | self.eps = eps
78 | self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
79 | self.gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
80 | self.beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
81 | self.rho.data.fill_(0.0)
82 | self.gamma.data.fill_(1.0)
83 | self.beta.data.fill_(0.0)
84 |
85 | def forward(self, input):
86 | in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
87 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
88 | ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
89 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
90 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
91 | out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
92 | return out
93 |
94 | #Calling Model.apply(Rho_Clipper) clips
95 | #every rho value contained in Model to the range [min,max]
96 | class RhoClipper(object):
97 | def __init__(self, min, max):
98 | self.clip_min = min
99 | self.clip_max = max
100 | assert min < max
101 | def __call__(self, module):
102 | if hasattr(module, 'rho'):
103 | w = module.rho.data
104 | w = w.clamp(self.clip_min, self.clip_max)
105 | module.rho.data = w
106 |
--------------------------------------------------------------------------------
/module/dataloader.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 |
3 | from .importer import *
4 |
5 | def make_datapath_list(target_path):
6 | #Path pattern of the dataset to load
7 | #For example,
8 | #target_path = "./dataset/**/*"
9 | #Image files are expected to be png
10 | path_list = []#Build the list of image file paths and return it
11 | for path in glob.glob(target_path,recursive=True):
12 | if os.path.isfile(path):
13 | path_list.append(path)
14 | ##Print every path that is loaded; uncomment the next line if needed
15 | #print(path)
16 | #Print the number of loaded images
17 | print("images : " + str(len(path_list)))
18 | path_list = sorted(path_list)
19 | return path_list
20 |
21 | class ImageTransform():
22 | #Image preprocessing class
23 | def __init__(self,resize_pixel):
24 | self.data_transform = transforms.Compose([
25 | transforms.Resize((resize_pixel,resize_pixel)),
26 | transforms.ToTensor(),
27 | transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
28 | ])
29 | def __call__(self,img):
30 | return self.data_transform(img)
31 |
32 | class ImageModification():
33 | #Image preprocessing class that also applies random translation and scaling
34 | def __init__(self,resize_pixel,x_move=[-0.1,0.1],y_move=[-0.1,0.2],min_scale=0.75):
35 | self.resize_pixel = resize_pixel
36 | self.x_move = x_move
37 | self.y_move = y_move
38 | self.min_scale = min_scale
39 | self.data_resize = transforms.Resize((resize_pixel*2,resize_pixel*2))
40 | self.data_arrange = transforms.Compose([
41 | transforms.Resize((resize_pixel,resize_pixel)),
42 | transforms.ToTensor(),
43 | transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
44 | ])
45 | def __call__(self,img):
46 | img = self.data_resize(img)
47 | #Decide how many pixels to shift the center by (x,y)
48 | move_pixel_x = np.random.uniform(self.resize_pixel*self.x_move[0],self.resize_pixel*self.x_move[1])
49 | move_pixel_y = np.random.uniform(self.resize_pixel*self.y_move[0],self.resize_pixel*self.y_move[1])
50 | move_pixel = [move_pixel_x,move_pixel_y]
51 | #Shift
52 | img = transforms.functional.affine(img,angle=0,translate=(move_pixel),scale=1,shear=0)
53 | #Crop
54 | max_crop_size = 2*(self.resize_pixel - np.max(np.abs(move_pixel)))
55 | min_crop_size = 2*(self.resize_pixel*self.min_scale)
56 | crop_size = np.random.randint(min_crop_size,max_crop_size)
57 | img = transforms.functional.center_crop(img,crop_size)
58 | #Convert to a tensor and return
59 | img = self.data_arrange(img)
60 | return img
61 |
62 | class GAN_Img_Dataset(data.Dataset):
63 | #Image dataset class
64 | def __init__(self,file_list,transform):
65 | self.file_list = file_list
66 | self.transform = transform
67 | #Return the number of images
68 | def __len__(self):
69 | return len(self.file_list)
70 | #Return a preprocessed image as a tensor
71 | def __getitem__(self,index):
72 | img_path = self.file_list[index]
73 | img = Image.open(img_path)#[RGB][height][width]
74 | img = img.convert('RGB')#Force 3 RGB channels (drops the alpha channel of a png)
75 | img_transformed = self.transform(img)
76 | return img_transformed
77 |
78 | #Sanity check
79 | # path_list = make_datapath_list("../dataset/group_B/**/*")
80 |
81 | # transform = ImageModification(resize_pixel=256,x_move=[-0.1,0.1],y_move=[-0.1,0.25],min_scale=0.7)
82 | # dataset = GAN_Img_Dataset(file_list=path_list,transform=transform)
83 |
84 | # batch_size = 8
85 | # dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=False)
86 |
87 | # imgs = next(iter(dataloader))
88 | # print(imgs.size())
89 |
90 | # for i,img_transformed in enumerate(imgs):
91 | # img_transformed = img_transformed.detach()
92 | # vutils.save_image(img_transformed,"../output/test_img_{}.png".format(i),normalize=True)
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/module/discriminator.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 |
3 | from .importer import *
4 | from .base_module import *
5 |
6 | class Discriminator(nn.Module):
7 | def __init__(self, input_nc=3, ndf=64, n_layers=5):
8 | super(Discriminator,self).__init__()
9 | #There are n_layers layers in total
10 | #First layer
11 | model = [nn.ReflectionPad2d(1),
12 | nn.utils.spectral_norm(
13 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
14 | nn.LeakyReLU(0.2, True)]
15 |
16 | #Intermediate layers
17 | for i in range(1, n_layers - 2):
18 | mult = 2 ** (i - 1)
19 | model += [nn.ReflectionPad2d(1),
20 | nn.utils.spectral_norm(
21 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
22 | nn.LeakyReLU(0.2, True)]
23 |
24 | #Last layer
25 | mult = 2 ** (n_layers - 2 - 1)
26 | model += [nn.ReflectionPad2d(1),
27 | nn.utils.spectral_norm(
28 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
29 | nn.LeakyReLU(0.2, True)]
30 |
31 | self.model = nn.Sequential(*model)
32 |
33 | #Modules for CAM (Class Activation Mapping)
34 | mult = 2 ** (n_layers - 2)
35 | self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
36 | self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
37 | self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
38 | self.leaky_relu = nn.LeakyReLU(0.2, True)
39 |
40 | self.pad = nn.ReflectionPad2d(1)
41 | self.conv = nn.utils.spectral_norm(nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))
42 |
43 | def forward(self, input):
44 | #Pass the input through the layers called the Encoder in the paper
45 | x = self.model(input)
46 |
47 | #Taking the average converts
48 | # x : torch.Size([batch_size,channel,Height,Width]) into
49 | #gap : torch.Size([batch_size,channel,1,1])
50 | # = Global Average Pooling
51 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
52 | #gap_logit : torch.Size([batch_size,1])
53 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
54 | #gap_weight : torch.Size([1,channel])
55 | gap_weight = list(self.gap_fc.parameters())[0]
56 | #gap_weight.unsqueeze(2).unsqueeze(3) : torch.Size([1,channel,1,1])
57 | #gap : torch.Size([batch_size,channel,Height,Width])
58 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
59 |
60 | #Taking the maximum converts
61 | # x : torch.Size([batch_size,channel,Height,Width]) into
62 | #gmp : torch.Size([batch_size,channel,1,1])
63 | # = Global Max Pooling
64 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
65 | #gmp_logit : torch.Size([batch_size,1])
66 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
67 | #gmp_weight : torch.Size([1,channel])
68 | gmp_weight = list(self.gmp_fc.parameters())[0] #There is no element [1] or later; len(list(self.gmp_fc.parameters())) = 1 because the Linear layer has no bias
69 | #gmp_weight.unsqueeze(2).unsqueeze(3) : torch.Size([1,channel,1,1])
70 | #gmp : torch.Size([batch_size,channel,Height,Width])
71 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
72 |
73 | #cam_logit : torch.Size([batch_size,2])
74 | cam_logit = torch.cat([gap_logit, gmp_logit], 1)
75 | x = torch.cat([gap, gmp], 1)
76 | #At this point
77 | # x : torch.Size([batch_size,channel*2,Height,Width])
78 | x = self.leaky_relu(self.conv1x1(x))
79 | #At this point
80 | # x : torch.Size([batch_size,channel,Height,Width])
81 |
82 | #heatmap : torch.Size([batch_size,1,Height,Width])
83 | heatmap = torch.sum(x, dim=1, keepdim=True)
84 |
85 | #Apply self.pad = nn.ReflectionPad2d(1)
86 | x = self.pad(x)
87 | #At this point
88 | # x : torch.Size([batch_size,channel,Height+2,Width+2])
89 | out = self.conv(x)
90 | #out : torch.Size([batch_size,1,Height-1,Width-1])
91 |
92 | #out : torch.Size([batch_size,1,Height-1,Width-1])
93 | #cam_logit : torch.Size([batch_size,2])
94 | #heatmap : torch.Size([batch_size,1,Height,Width])
95 | return out, cam_logit, heatmap
96 |
--------------------------------------------------------------------------------
/module/generator.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 |
3 | from .importer import *
4 | from .base_module import *
5 |
6 | class Generator(nn.Module):
7 | def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=6, img_size=256):
8 | super(Generator,self).__init__()
9 | self.input_nc = input_nc
10 | self.output_nc = output_nc
11 | self.ngf = ngf
12 | self.n_blocks = n_blocks
13 | self.img_size = img_size
14 |
15 | DownBlock = []
16 | DownBlock += [nn.ReflectionPad2d(3),
17 | nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False),
18 | nn.InstanceNorm2d(ngf),
19 | nn.ReLU(True)]
20 |
21 | #Add two downsampling convolution blocks
22 | n_downsampling = 2
23 | for i in range(n_downsampling):
24 | mult = 2**i
25 | DownBlock += [nn.ReflectionPad2d(1),
26 | nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False),
27 | nn.InstanceNorm2d(ngf * mult * 2),
28 | nn.ReLU(True)]
29 |
30 | #Add n_blocks=6 ResnetBlocks
31 | mult = 2**n_downsampling
32 | for i in range(n_blocks):
33 | DownBlock += [ResnetBlock(ngf * mult, use_bias=False)]
34 |
35 | #Build the part called the Encoder in the paper
36 | self.DownBlock = nn.Sequential(*DownBlock)
37 |
38 | # Class Activation Map
39 | self.gap_fc = nn.Linear(ngf * mult, 1, bias=False)
40 | self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False)
41 | self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True)
42 | self.relu = nn.ReLU(True)
43 |
44 | #Modules that produce gamma and beta
45 | FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False),
46 | nn.ReLU(True),
47 | nn.Linear(ngf * mult, ngf * mult, bias=False),
48 | nn.ReLU(True)]
49 | self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False)
50 | self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False)
51 |
52 | self.FC = nn.Sequential(*FC)
53 |
54 | #Build the up-sampling part (called the Decoder in the paper)
55 | #Create n_blocks=6 ResnetBlocks that use AdaILN
56 | for i in range(n_blocks):
57 | setattr(self, 'UpBlock1_' + str(i+1), ResnetAdaILNBlock(ngf * mult, use_bias=False))
58 |
59 | #Add n_downsampling=2 up-sampling stages
60 | UpBlock2 = []
61 | for i in range(n_downsampling):
62 | mult = 2**(n_downsampling - i)
63 | UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'),
64 | nn.ReflectionPad2d(1),
65 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=False),
66 | ILN(int(ngf * mult / 2)),
67 | nn.ReLU(True)]
68 |
69 | UpBlock2 += [nn.ReflectionPad2d(3),
70 | nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=False),
71 | nn.Tanh()]
72 |
73 | #Build the part called the Decoder in the paper
74 | self.UpBlock2 = nn.Sequential(*UpBlock2)
75 |
76 | def forward(self, input):
77 | #Pass the input through the Encoder part
78 | x = self.DownBlock(input)
79 |
80 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
81 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
82 | gap_weight = list(self.gap_fc.parameters())[0]
83 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
84 |
85 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
86 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
87 | gmp_weight = list(self.gmp_fc.parameters())[0]
88 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
89 |
90 | cam_logit = torch.cat([gap_logit, gmp_logit], 1)
91 | x = torch.cat([gap, gmp], 1)
92 | x = self.relu(self.conv1x1(x))
93 |
94 | heatmap = torch.sum(x, dim=1, keepdim=True)
95 |
96 | #Compute gamma and beta
97 | x_ = self.FC(x.view(x.shape[0], -1))
98 | gamma, beta = self.gamma(x_), self.beta(x_)
99 |
100 | #Pass through the Decoder part
101 | #First through the n_blocks=6 AdaILN ResnetBlocks
102 | for i in range(self.n_blocks):
103 | x = getattr(self, 'UpBlock1_' + str(i+1))(x, gamma, beta)
104 | #Upsample
105 | out = self.UpBlock2(x)
106 |
107 | return out, cam_logit, heatmap
108 |
--------------------------------------------------------------------------------
/module/importer.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 |
3 | import glob
4 | import os
5 | import os.path as osp
6 | import random
7 | import numpy as np
8 | import json
9 | from PIL import Image
10 | from tqdm import tqdm
11 | import matplotlib as mpl
12 | mpl.use('Agg')# AGG(Anti-Grain Geometry engine)
13 | import matplotlib.pyplot as plt
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.optim as optim
18 | import torch.utils.data as data
19 | import torchvision
20 | from torchvision import models,transforms
21 | import torch.nn.init as init
22 | from torch.autograd import Function
23 | import torch.nn.functional as F
24 | import torchvision.utils as vutils
25 |
26 | import xml.etree.ElementTree as ET
27 | import itertools
28 | from math import sqrt
29 | import time
30 | import sys
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | kiwisolver==1.1.0
3 | matplotlib==3.1.3
4 | numpy
5 | opencv-python==4.2.0.32
6 | pandas==1.0.1
7 | pillow>=8.3.2
8 | pyparsing==2.4.6
9 | python-dateutil==2.8.1
10 | pytz==2019.3
11 | six==1.14.0
12 | tqdm==4.42.1
13 |
--------------------------------------------------------------------------------