├── .gitignore ├── LICENSE ├── README.md ├── UGATIT_inference.py ├── UGATIT_train.py ├── example.png ├── module ├── base_module.py ├── dataloader.py ├── discriminator.py ├── generator.py └── importer.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | dataset/ 2 | __pycache__/ 3 | output/ 4 | trained_model/ 5 | conversion/ 6 | .DS_Store 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 zassou65535 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # image_converter 2 | ## 概要 3 | 4 | 5 | 6 | メルアイコン変換器で用いた、UGATITのpytorch実装です。 7 | 詳しい解説はこちら。 8 | 9 | ## 想定環境 10 | python 3.7.1, VRAM32GB以上 11 | torch==1.4.0 12 | torchvision==0.5.0 13 | `pip install -r requirements.txt`でライブラリを揃えることができます。 14 | 15 | ## プログラム 16 | * `UGATIT_train.py`は学習を実行し、学習の過程と学習済みモデルを出力するプログラムです。 17 | * `UGATIT_inference.py`は`UGATIT_train.py`で出力した学習済みモデルを読み込み、推論(画像の変換)を実行、生成画像を出力するプログラムです。 18 | 19 | ## 使い方 20 | 以下では変換元ドメインをA、変換先ドメインをBと表現します。 21 | ### 学習の実行 22 | 1. `UGATIT_train.py`のあるディレクトリに`./dataset`ディレクトリを作成します 23 | 1. `./dataset`ディレクトリ内に`group_A`ディレクトリと`group_B`ディレクトリの2つを作成します。 24 | 1. `./dataset/group_A`ディレクトリに、Aに属する画像を`./dataset/group_A/*/*`という形式で好きな数入れます(画像のファイル形式はpng)。 25 | 1. `./dataset/group_B`ディレクトリに、Bに属する画像を`./dataset/group_B/*/*`という形式で好きな数入れます(画像のファイル形式はpng)。 26 | 1. `UGATIT_train.py`の置いてあるディレクトリで`python UGATIT_train.py`を実行することで、「A⇄B」の変換ができるよう目指して学習を実行します。 27 | * 学習の過程が`./output`以下に出力されます。 28 | * 学習済みモデルが`./trained_model/generator_A2B_trained_model_cpu.pth`として出力されます。 29 | ### 推論の実行 30 | 1. `UGATIT_inference.py`のあるディレクトリに`./conversion`ディレクトリを作成します 31 | 1. `./conversion`内に`target`ディレクトリを作成し、Aに属する画像を好きな数入れます。 32 | 1. `UGATIT_inference.py`の置いてあるディレクトリで`python UGATIT_inference.py`を実行して`./conversion/target`内の画像をBへ変換します 33 | * A→Bの変換結果が`./conversion/converted/`以下に出力されます。 34 | * 注意点として、`./trained_model`内に学習済みモデル`generator_A2B_trained_model_cpu.pth`がなければエラーとなります 35 | 36 | 学習には環境によっては12時間以上要する場合があります。 37 | 入力された画像は256×256にリサイズされた上で学習に使われます。出力画像も256×256です。 38 | 39 | ## 参考 40 | U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation 41 | U-GAT-IT — Official PyTorch Implementation 42 | -------------------------------------------------------------------------------- /UGATIT_inference.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | 3 | from module.importer import * 4 | from module.dataloader import * 5 | from module.generator import * 6 | 7 | #学習済みモデルの読み込み 8 | generator = Generator() 9 | generator.load_state_dict(torch.load('./trained_model/generator_A2B_trained_model_cpu.pth')) 10 | #推論モードに切り替え 11 | generator.eval() 12 | #変換対象となる画像の読み込み 13 | path_list = make_datapath_list('./conversion/target/**/*') 14 | train_dataset = GAN_Img_Dataset(file_list=path_list,transform=ImageTransform(256)) 15 | dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=len(path_list),shuffle=False) 16 | target = next(iter(dataloader)) 17 | #generatorへ入力、出力画像を得る 18 | converted,_,_ = generator.forward(target) 19 | #画像出力用にディレクトリを作成 20 | os.makedirs("./conversion/converted",exist_ok=True) 21 | #画像を出力 22 | for i,output_img in enumerate(converted): 23 | origin_filename = os.path.basename(path_list[i]) 24 | origin_filename_without_ex = os.path.splitext(origin_filename)[0] 25 | filename = "./conversion/converted/{}_converted{}.png".format(origin_filename_without_ex,i) 26 | #そこへ保存 27 | vutils.save_image(output_img,filename,normalize=True) 28 | print(origin_filename + " : converted") 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /UGATIT_train.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | 3 | from module.importer import * 4 | from module.discriminator import * 5 | from module.generator import * 6 | from module.dataloader import * 7 | from module.base_module import * 8 | 9 | #データセットAの、各データへのパスのフォーマット make_datapath_listへの引数 10 | dataroot_A = './dataset/group_A/**/*' 11 | #データセットBの、各データへのパスのフォーマット 12 | dataroot_B = './dataset/group_B/**/*' 13 | #バッチサイズ 14 | batch_size = 1 15 | #エポック数 16 | num_epochs = 40 17 | #generator,discriminatorのoptimizerに使う学習率 18 | learning_rate = 0.0001 19 | #Adamのweight decay(重み減衰)の度合い 20 | weight_decay = 0.0001 21 | #output_progress_intervalエポックごとに学習状況の画像を出力する 22 | output_progress_interval = 1 23 | 24 | #訓練データAの読み込み、データセット作成 25 | path_list_A = make_datapath_list(dataroot_A) 26 | transform_A = ImageModification(resize_pixel=256,x_move=[-0.05,0.05],y_move=[-0.05,0.05],min_scale=0.9) 27 | train_dataset_A = GAN_Img_Dataset(file_list=path_list_A,transform=transform_A) 28 | dataloader_A = torch.utils.data.DataLoader(train_dataset_A,batch_size=batch_size,shuffle=True) 29 | 30 | #訓練データBの読み込み、データセット作成 31 | path_list_B = make_datapath_list(dataroot_B) 32 | transform_B = ImageModification(resize_pixel=256,x_move=[-0.1,0.1],y_move=[-0.1,0.25],min_scale=0.7) 33 | train_dataset_B = GAN_Img_Dataset(file_list=path_list_B,transform=transform_B) 34 | dataloader_B = torch.utils.data.DataLoader(train_dataset_B,batch_size=batch_size,shuffle=True) 35 | 36 | #GPUが使用可能かどうか確認 37 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 38 | print("device:",device) 39 | 40 | # #ネットワークを初期化するための関数 41 | def weights_init(m): 42 | classname = m.__class__.__name__ 43 | if classname.find('Conv2d') != -1: 44 | #平均0.0,標準偏差0.02となるように初期化 45 | nn.init.normal_(m.weight.data, 0.0, 0.02) 46 | elif classname.find('ConvTranspose2d') != -1: 47 | #平均0.0,標準偏差0.02となるように初期化 48 | nn.init.normal_(m.weight.data, 0.0, 0.02) 49 | 50 | #各ネットワークのインスタンスを生成、デバイスに移動 51 | netG_A2B = Generator().to(device) 52 | netG_B2A = Generator().to(device) 53 | netD_GA = Discriminator(n_layers=7).to(device) 54 | netD_GB = Discriminator(n_layers=7).to(device) 55 | netD_LA = Discriminator(n_layers=5).to(device) 56 | netD_LB = Discriminator(n_layers=5).to(device) 57 | #ネットワークの初期化 58 | netG_A2B.apply(weights_init) 59 | netG_B2A.apply(weights_init) 60 | netD_GA.apply(weights_init) 61 | netD_GB.apply(weights_init) 62 | netD_LA.apply(weights_init) 63 | netD_LB.apply(weights_init) 64 | 65 | #損失関数の初期化 66 | L1_loss = nn.L1Loss().to(device) 67 | MSE_loss = nn.MSELoss().to(device) 68 | BCE_loss = nn.BCEWithLogitsLoss().to(device) 69 | #Adam optimizersをGeneratorとDiscriminatorに適用 70 | beta1 = 0.5 71 | beta2 = 0.999 72 | optimizerG = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),netG_B2A.parameters()),lr=learning_rate,betas=(beta1,beta2),weight_decay=weight_decay) 73 | optimizerD = torch.optim.Adam(itertools.chain(netD_GA.parameters(),netD_GB.parameters(),netD_LA.parameters(),netD_LB.parameters()),lr=learning_rate,betas=(beta1,beta2),weight_decay=weight_decay) 74 | 75 | #Generatorに使われているAdaILNとILNの、rhoの範囲を[0,1]に制限するためのモジュールを宣言 76 | Rho_Clipper = RhoClipper(0,1) 77 | 78 | #イテレーションを全部で何回実行することになるかを計算 79 | iteration_per_epoch = len(path_list_A) if len(path_list_A) (total_iteration // 2)): 145 | optimizerG.param_groups[0]['lr'] -= (learning_rate / (total_iteration//2)) 146 | optimizerD.param_groups[0]['lr'] -= (learning_rate / (total_iteration//2)) 147 | 148 | #------------------------- 149 | #discriminatorの学習 150 | #------------------------- 151 | #前のイテレーションでたまった傾きをリセット 152 | optimizerD.zero_grad() 153 | 154 | #本物画像から偽物画像を生成 155 | fake_A2B, _, _ = netG_A2B(real_A) 156 | fake_B2A, _, _ = netG_B2A(real_B) 157 | 158 | #本物画像に対しそれぞれ判定 159 | real_GA_logit, real_GA_cam_logit, _ = netD_GA(real_A) 160 | real_LA_logit, real_LA_cam_logit, _ = netD_LA(real_A) 161 | real_GB_logit, real_GB_cam_logit, _ = netD_GB(real_B) 162 | real_LB_logit, real_LB_cam_logit, _ = netD_LB(real_B) 163 | 164 | #偽物画像に対しそれぞれ判定 165 | fake_GA_logit, fake_GA_cam_logit, _ = netD_GA(fake_B2A) 166 | fake_LA_logit, fake_LA_cam_logit, _ = netD_LA(fake_B2A) 167 | fake_GB_logit, fake_GB_cam_logit, _ = netD_GB(fake_A2B) 168 | fake_LB_logit, fake_LB_cam_logit, _ = netD_LB(fake_A2B) 169 | 170 | #損失の計算 171 | D_ad_loss_GA = MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(device)) + MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(device)) 172 | D_ad_cam_loss_GA = MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(device)) + MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(device)) 173 | D_ad_loss_LA = MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(device)) + MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(device)) 174 | D_ad_cam_loss_LA = MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(device)) + MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(device)) 175 | D_ad_loss_GB = MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(device)) + MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(device)) 176 | D_ad_cam_loss_GB = MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(device)) + MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(device)) 177 | D_ad_loss_LB = MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(device)) + MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(device)) 178 | D_ad_cam_loss_LB = MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(device)) + MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(device)) 179 | 180 | D_loss_A = 1*(D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA) 181 | D_loss_B = 1*(D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB) 182 | 183 | Discriminator_loss = D_loss_A + D_loss_B 184 | #傾きを計算して 185 | Discriminator_loss.backward() 186 | #discriminatorのパラメーターを更新 187 | optimizerD.step() 188 | 189 | #後でグラフに出力するために記録 190 | D_losses_per_epoch.append(Discriminator_loss.item()) 191 | 192 | #------------------------- 193 | #Generatorの学習 194 | #------------------------- 195 | #前のイテレーションでたまった傾きをリセット 196 | optimizerG.zero_grad() 197 | 198 | #本物画像から偽物画像を生成 199 | fake_A2B, fake_A2B_cam_logit, _ = netG_A2B(real_A) 200 | fake_B2A, fake_B2A_cam_logit, _ = netG_B2A(real_B) 201 | 202 | #偽物画像から本物に戻ってくるのを目指す 203 | fake_A2B2A, _, _ = netG_B2A(fake_A2B) 204 | fake_B2A2B, _, _ = netG_A2B(fake_B2A) 205 | 206 | #変換先と同じドメインの本物画像から偽物画像を生成 207 | fake_A2A, fake_A2A_cam_logit, _ = netG_B2A(real_A) 208 | fake_B2B, fake_B2B_cam_logit, _ = netG_A2B(real_B) 209 | 210 | #生成された偽物画像についてそれぞれ判定 211 | fake_GA_logit, fake_GA_cam_logit, _ = netD_GA(fake_B2A) 212 | fake_LA_logit, fake_LA_cam_logit, _ = netD_LA(fake_B2A) 213 | fake_GB_logit, fake_GB_cam_logit, _ = netD_GB(fake_A2B) 214 | fake_LB_logit, fake_LB_cam_logit, _ = netD_LB(fake_A2B) 215 | 216 | #損失の計算 217 | G_ad_loss_GA = MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(device)) 218 | G_ad_cam_loss_GA = MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(device)) 219 | G_ad_loss_LA = MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(device)) 220 | G_ad_cam_loss_LA = MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(device)) 221 | G_ad_loss_GB = MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(device)) 222 | G_ad_cam_loss_GB = MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(device)) 223 | G_ad_loss_LB = MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(device)) 224 | G_ad_cam_loss_LB = MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(device)) 225 | 226 | G_recon_loss_A = L1_loss(fake_A2B2A, real_A) 227 | G_recon_loss_B = L1_loss(fake_B2A2B, real_B) 228 | 229 | G_identity_loss_A = L1_loss(fake_A2A, real_A) 230 | G_identity_loss_B = L1_loss(fake_B2B, real_B) 231 | 232 | G_cam_loss_A = BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(device)) + BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(device)) 233 | G_cam_loss_B = BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(device)) + BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(device)) 234 | 235 | G_loss_A = 1*(G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + 10*G_recon_loss_A + 10*G_identity_loss_A + 1000*G_cam_loss_A 236 | G_loss_B = 1*(G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + 10*G_recon_loss_B + 10*G_identity_loss_B + 1000*G_cam_loss_B 237 | 238 | Generator_loss = G_loss_A + G_loss_B 239 | #傾きを計算して 240 | Generator_loss.backward() 241 | #generatorのパラメーターを更新 242 | optimizerG.step() 243 | 244 | #Generatorに使われているAdaILNとILNの、rhoの範囲を制限する 245 | netG_A2B.apply(Rho_Clipper) 246 | netG_B2A.apply(Rho_Clipper) 247 | 248 | #後でグラフに出力するために記録 249 | G_losses_per_epoch.append(Generator_loss.item()) 250 | 251 | #学習状況をシェルに出力 252 | if iteration % 50 == 0: 253 | print('[%d/%d][iteration:%d]\tLoss_D: %.4f\tLoss_G: %.4f' 254 | % (epoch,num_epochs,iteration, 255 | Discriminator_loss.item(),Generator_loss.item())) 256 | 257 | iteration += 1 258 | #テスト用break 259 | #break 260 | 261 | #後で出力するためにepochごとにlossの平均を取り記録 262 | G_losses.append(torch.mean(torch.tensor(G_losses_per_epoch,dtype=torch.float64)).item()) 263 | D_losses.append(torch.mean(torch.tensor(D_losses_per_epoch,dtype=torch.float64)).item()) 264 | #Generatorの学習状況を画像として記録 265 | if (epoch % output_progress_interval == 0 or (epoch+1)==num_epochs): 266 | #ネットワークを推論モードにする 267 | netG_A2B.eval(),netG_B2A.eval() 268 | netD_GA.eval(),netD_GB.eval() 269 | netD_LA.eval(),netD_LB.eval() 270 | #画像出力用ディレクトリがなければ作成 271 | output_dir = "./output/epoch_{}".format(epoch+1) 272 | if not os.path.exists(output_dir): 273 | os.makedirs(output_dir) 274 | #画像の生成と出力 275 | #デバイスに配置されている画像をcpuに移す関数 276 | def move_to_cpu(imgs_on_device): 277 | imgs_on_cpu = [] 278 | for im in imgs_on_device: 279 | imgs_on_cpu.append(im.detach().cpu()) 280 | return tuple(imgs_on_cpu) 281 | 282 | fake_A2B, _, fake_A2B_heatmap = netG_A2B(sample_real_A) 283 | fake_B2A, _, fake_B2A_heatmap = netG_B2A(sample_real_B) 284 | 285 | fake_A2B2A, _, fake_A2B2A_heatmap = netG_B2A(fake_A2B) 286 | fake_B2A2B, _, fake_B2A2B_heatmap = netG_A2B(fake_B2A) 287 | 288 | fake_A2B,fake_A2B_heatmap = move_to_cpu([fake_A2B,fake_A2B_heatmap]) 289 | fake_B2A,fake_B2A_heatmap = move_to_cpu([fake_B2A,fake_B2A_heatmap]) 290 | fake_A2B2A,fake_A2B2A_heatmap = move_to_cpu([fake_A2B2A,fake_A2B2A_heatmap]) 291 | fake_B2A2B,fake_B2A2B_heatmap = move_to_cpu([fake_B2A2B,fake_B2A2B_heatmap]) 292 | 293 | fake_A2A, _, fake_A2A_heatmap = netG_B2A(sample_real_A) 294 | fake_B2B, _, fake_B2B_heatmap = netG_A2B(sample_real_B) 295 | 296 | fake_A2A,fake_A2A_heatmap = move_to_cpu([fake_A2A,fake_A2A_heatmap]) 297 | fake_B2B,fake_B2B_heatmap = move_to_cpu([fake_B2B,fake_B2B_heatmap]) 298 | sr_A,sr_B = move_to_cpu([sample_real_A,sample_real_B]) 299 | #A->B->Aの画像の出力 300 | output_how_much_progress("./output/epoch_{}/conversion_A2B2A.png".format(epoch+1),[sr_A,fake_A2B,fake_A2B2A]) 301 | #B->A->Bの画像の出力 302 | output_how_much_progress("./output/epoch_{}/conversion_B2A2B.png".format(epoch+1),[sr_B,fake_B2A,fake_B2A2B]) 303 | #ヒートマップ(A)の出力 304 | fake_A2A_heatmap = F.interpolate(fake_A2A_heatmap,size=(256,256)) 305 | fake_A2B_heatmap = F.interpolate(fake_A2B_heatmap,size=(256,256)) 306 | fake_A2B2A_heatmap = F.interpolate(fake_A2B2A_heatmap,size=(256,256)) 307 | output_how_much_progress("./output/epoch_{}/heatmap_A.png".format(epoch+1).format(epoch+1),[fake_A2A_heatmap,fake_A2B_heatmap,fake_A2B2A_heatmap]) 308 | #ヒートマップ(B)の出力 309 | fake_B2B_heatmap = F.interpolate(fake_B2B_heatmap,size=(256,256)) 310 | fake_B2A_heatmap = F.interpolate(fake_B2A_heatmap,size=(256,256)) 311 | fake_B2A2B_heatmap = F.interpolate(fake_B2A2B_heatmap,size=(256,256)) 312 | output_how_much_progress("./output/epoch_{}/heatmap_B.png".format(epoch+1).format(epoch+1),[fake_B2B_heatmap,fake_B2A_heatmap,fake_B2A2B_heatmap]) 313 | #テスト用break 314 | #break 315 | 316 | #学習にかかった時間を出力 317 | #学習終了時の時間を記録 318 | t_epoch_finish = time.time() 319 | total_time = t_epoch_finish - t_epoch_start 320 | with open('./output/time.txt', mode='w') as f: 321 | f.write("total_time: {:.4f} sec.\n".format(total_time)) 322 | f.write("dataset_A size: {}\n".format(len(path_list_A))) 323 | f.write("dataset_B size: {}\n".format(len(path_list_B))) 324 | f.write("num_epochs: {}\n".format(num_epochs)) 325 | f.write("batch_size: {}\n".format(batch_size)) 326 | 327 | #学習済みGeneratorのモデル(CPU向け)を出力 328 | #モデル出力用ディレクトリがなければ作成 329 | output_dir = "./trained_model" 330 | if not os.path.exists(output_dir): 331 | os.makedirs(output_dir) 332 | torch.save(netG_A2B.to('cpu').state_dict(),'./trained_model/generator_A2B_trained_model_cpu.pth') 333 | torch.save(netG_B2A.to('cpu').state_dict(),'./trained_model/generator_B2A_trained_model_cpu.pth') 334 | 335 | #lossのグラフを出力 336 | plt.clf() 337 | plt.figure(figsize=(10,5)) 338 | plt.title("Generator and Discriminator Loss During Training") 339 | plt.plot(G_losses,label="G") 340 | plt.plot(D_losses,label="D") 341 | plt.xlabel("epoch") 342 | plt.ylabel("Loss") 343 | plt.legend() 344 | plt.savefig('./output/loss.png') 345 | 346 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zassou65535/image_converter/4dd9a76154d43e6c339d5953b7b57287fd8e8204/example.png -------------------------------------------------------------------------------- /module/base_module.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | 3 | from .importer import * 4 | 5 | class ResnetBlock(nn.Module): 6 | def __init__(self, dim, use_bias): 7 | super(ResnetBlock,self).__init__() 8 | self.conv_block = nn.Sequential( 9 | nn.ReflectionPad2d(1), 10 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias), 11 | nn.InstanceNorm2d(dim), 12 | nn.ReLU(True), 13 | 14 | nn.ReflectionPad2d(1), 15 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias), 16 | nn.InstanceNorm2d(dim) 17 | ) 18 | 19 | def forward(self, x): 20 | out = x + self.conv_block(x) 21 | return out 22 | 23 | #上のResnetBlockのInstanceNorm2dを、後述のadaILNに差し替えたもの 24 | class ResnetAdaILNBlock(nn.Module): 25 | def __init__(self, dim, use_bias): 26 | super(ResnetAdaILNBlock,self).__init__() 27 | self.pad1 = nn.ReflectionPad2d(1) 28 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias) 29 | self.norm1 = adaILN(dim) 30 | self.relu1 = nn.ReLU(True) 31 | 32 | self.pad2 = nn.ReflectionPad2d(1) 33 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias) 34 | self.norm2 = adaILN(dim) 35 | 36 | def forward(self, x, gamma, beta): 37 | out = self.pad1(x) 38 | out = self.conv1(out) 39 | out = self.norm1(out, gamma, beta) 40 | out = self.relu1(out) 41 | out = self.pad2(out) 42 | out = self.conv2(out) 43 | out = self.norm2(out, gamma, beta) 44 | return out + x 45 | 46 | #入力されたTensorに対し 47 | #各チャネル別々に正規化したものと、そのレイヤー内でいっぺんに正規化をかけたもの 48 | #双方を割合rhoで混ぜ合わせ出力とする層 49 | #rhoは学習可能 50 | class adaILN(nn.Module): 51 | def __init__(self, num_features, eps=1e-5): 52 | super(adaILN,self).__init__() 53 | self.eps = eps 54 | #rhoを学習可能にする 55 | self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 56 | #初期値0.9 57 | self.rho.data.fill_(0.9) 58 | 59 | def forward(self, input, gamma, beta): 60 | #各チャネル別々に、特徴マップに対し正規化をかける 61 | #各チャネル別々に、特徴マップの縦横全ての平均と分散を計算 62 | in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True) 63 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) 64 | #レイヤーの全部の特徴マップに対し正規化をかける 65 | #特徴マップのチャネル+縦横全ての平均と分散を計算(そのレイヤーの全部の値に対して計算) 66 | ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True) 67 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) 68 | #self.rhoを割合としてout_inとout_lnを混ぜ合わせ出力とする 69 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln 70 | out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3) 71 | return out 72 | 73 | #adaILNの、gamma,betaを学習によって決められるようにしたバージョン 74 | class ILN(nn.Module): 75 | def __init__(self, num_features, eps=1e-5): 76 | super(ILN,self).__init__() 77 | self.eps = eps 78 | self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 79 | self.gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 80 | self.beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) 81 | self.rho.data.fill_(0.0) 82 | self.gamma.data.fill_(1.0) 83 | self.beta.data.fill_(0.0) 84 | 85 | def forward(self, input): 86 | in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True) 87 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) 88 | ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True) 89 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) 90 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln 91 | out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1) 92 | return out 93 | 94 | #Model.apply(Rho_Clipper)とすることで、 95 | #Model中に含まれるrhoの値を[min,max]の範囲に制限できる 96 | class RhoClipper(object): 97 | def __init__(self, min, max): 98 | self.clip_min = min 99 | self.clip_max = max 100 | assert min < max 101 | def __call__(self, module): 102 | if hasattr(module, 'rho'): 103 | w = module.rho.data 104 | w = w.clamp(self.clip_min, self.clip_max) 105 | module.rho.data = w 106 | -------------------------------------------------------------------------------- /module/dataloader.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | 3 | from .importer import * 4 | 5 | def make_datapath_list(target_path): 6 | #読み込むデータセットのパス 7 | #例えば 8 | #target_path = "./dataset/**/*"などとします 9 | #画像のファイル形式はpng 10 | path_list = []#画像ファイルパスのリストを作り、戻り値とする 11 | for path in glob.glob(target_path,recursive=True): 12 | if os.path.isfile(path): 13 | path_list.append(path) 14 | ##読み込むパスを全部表示 必要ならコメントアウトを外す 15 | #print(path) 16 | #読み込んだ画像の数を表示 17 | print("images : " + str(len(path_list))) 18 | path_list = sorted(path_list) 19 | return path_list 20 | 21 | class ImageTransform(): 22 | #画像の前処理クラス 23 | def __init__(self,resize_pixel): 24 | self.data_transform = transforms.Compose([ 25 | transforms.Resize((resize_pixel,resize_pixel)), 26 | transforms.ToTensor(), 27 | transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]) 28 | ]) 29 | def __call__(self,img): 30 | return self.data_transform(img) 31 | 32 | class ImageModification(): 33 | #画像の前処理クラス 画像の平行移動、拡大縮小なども行う 34 | def __init__(self,resize_pixel,x_move=[-0.1,0.1],y_move=[-0.1,0.2],min_scale=0.75): 35 | self.resize_pixel = resize_pixel 36 | self.x_move = x_move 37 | self.y_move = y_move 38 | self.min_scale = min_scale 39 | self.data_resize = transforms.Resize((resize_pixel*2,resize_pixel*2)) 40 | self.data_arrange = transforms.Compose([ 41 | transforms.Resize((resize_pixel,resize_pixel)), 42 | transforms.ToTensor(), 43 | transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]) 44 | ]) 45 | def __call__(self,img): 46 | img = self.data_resize(img) 47 | #中心(y,x)を何pixelずらすかを指定 48 | move_pixel_x = np.random.uniform(self.resize_pixel*self.x_move[0],self.resize_pixel*self.x_move[1]) 49 | move_pixel_y = np.random.uniform(self.resize_pixel*self.y_move[0],self.resize_pixel*self.y_move[1]) 50 | move_pixel = [move_pixel_x,move_pixel_y] 51 | #ずらす 52 | img = transforms.functional.affine(img,angle=0,translate=(move_pixel),scale=1,shear=0) 53 | #切り取る 54 | max_crop_size = 2*(self.resize_pixel - np.max(np.abs(move_pixel))) 55 | min_crop_size = 2*(self.resize_pixel*self.min_scale) 56 | crop_size = np.random.randint(min_crop_size,max_crop_size) 57 | img = transforms.functional.center_crop(img,crop_size) 58 | #Tensorに変換して出力 59 | img = self.data_arrange(img) 60 | return img 61 | 62 | class GAN_Img_Dataset(data.Dataset): 63 | #画像のデータセットクラス 64 | def __init__(self,file_list,transform): 65 | self.file_list = file_list 66 | self.transform = transform 67 | #画像の枚数を返す 68 | def __len__(self): 69 | return len(self.file_list) 70 | #前処理済み画像の、Tensor形式のデータを取得 71 | def __getitem__(self,index): 72 | img_path = self.file_list[index] 73 | img = Image.open(img_path)#[RGB][高さ][幅] 74 | img = img.convert('RGB')#pngをjpg形式に変換 75 | img_transformed = self.transform(img) 76 | return img_transformed 77 | 78 | #動作確認 79 | # path_list = make_datapath_list("../dataset/group_B/**/*") 80 | 81 | # transform = ImageModification(resize_pixel=256,x_move=[-0.1,0.1],y_move=[-0.1,0.25],min_scale=0.7) 82 | # dataset = GAN_Img_Dataset(file_list=path_list,transform=transform) 83 | 84 | # batch_size = 8 85 | # dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=False) 86 | 87 | # imgs = next(iter(dataloader)) 88 | # print(imgs.size()) 89 | 90 | # for i,img_transformed in enumerate(imgs): 91 | # img_transformed = img_transformed.detach() 92 | # vutils.save_image(img_transformed,"../output/test_img_{}.png".format(i),normalize=True) 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /module/discriminator.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | 3 | from .importer import * 4 | from .base_module import * 5 | 6 | class Discriminator(nn.Module): 7 | def __init__(self, input_nc=3, ndf=64, n_layers=5): 8 | super(Discriminator,self).__init__() 9 | #レイヤーは全部でn_layers個 10 | #最初の1レイヤー目 11 | model = [nn.ReflectionPad2d(1), 12 | nn.utils.spectral_norm( 13 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)), 14 | nn.LeakyReLU(0.2, True)] 15 | 16 | #(n_layers-2)個の中間のレイヤー 17 | for i in range(1, n_layers - 2): 18 | mult = 2 ** (i - 1) 19 | model += [nn.ReflectionPad2d(1), 20 | nn.utils.spectral_norm( 21 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)), 22 | nn.LeakyReLU(0.2, True)] 23 | 24 | #最後のレイヤー 25 | mult = 2 ** (n_layers - 2 - 1) 26 | model += [nn.ReflectionPad2d(1), 27 | nn.utils.spectral_norm( 28 | nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)), 29 | nn.LeakyReLU(0.2, True)] 30 | 31 | self.model = nn.Sequential(*model) 32 | 33 | #CAM用モジュール類を定義 34 | mult = 2 ** (n_layers - 2) 35 | self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False)) 36 | self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False)) 37 | self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True) 38 | self.leaky_relu = nn.LeakyReLU(0.2, True) 39 | 40 | self.pad = nn.ReflectionPad2d(1) 41 | self.conv = nn.utils.spectral_norm(nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False)) 42 | 43 | def forward(self, input): 44 | #論文中で、Encoderと呼ばれている層を通す 45 | x = self.model(input) 46 | 47 | #平均を取る操作によって 48 | # x : torch.Size([batch_size,channel,Height,Width])を 49 | #gap : torch.Size([batch_size,channel,1,1])に変換する 50 | # = Global Average Poolingを実行 51 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) 52 | #gap_logit : torch.Size([batch_size,1]) 53 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) 54 | #gap_weight : torch.Size([1,channel]) 55 | gap_weight = list(self.gap_fc.parameters())[0] 56 | #gap_weight.unsqueeze(2).unsqueeze(3) : torch.Size([1,channel,1,1]) 57 | #gap : torch.Size([batch_size,channel,Height,Width]) 58 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3) 59 | 60 | #最大値を取る操作によって 61 | # x : torch.Size([batch_size,channel,Height,Width])を 62 | #gmp : torch.Size([batch_size,channel,1,1])に変換する 63 | # = Global Max Poolingを実行 64 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) 65 | #gmp_logit : torch.Size([batch_size,1]) 66 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) 67 | #gmp_weight : torch.Size([1,channel]) 68 | gmp_weight = list(self.gmp_fc.parameters())[0] #元々[1]以降は存在しない つまりlen(list(self.gmp_fc.parameters())) = 1 69 | #gmp_weight.unsqueeze(2).unsqueeze(3) : torch.Size([1,channel,1,1]) 70 | #gmp : torch.Size([batch_size,channel,Height,Width]) 71 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) 72 | 73 | #cam_logit : torch.Size([batch_size,2]) 74 | cam_logit = torch.cat([gap_logit, gmp_logit], 1) 75 | x = torch.cat([gap, gmp], 1) 76 | #この時点で 77 | # x : torch.Size([batch_size,channel*2,Height,Width]) 78 | x = self.leaky_relu(self.conv1x1(x)) 79 | #この時点で 80 | # x : torch.Size([batch_size,channel,Height,Width]) 81 | 82 | #heatmap : torch.Size([batch_size,1,Height,Width]) 83 | heatmap = torch.sum(x, dim=1, keepdim=True) 84 | 85 | #self.pad = nn.ReflectionPad2d(1)を適用 86 | x = self.pad(x) 87 | #この時点で 88 | # x : torch.Size([batch_size,channel,Height+2,Width+2]) 89 | out = self.conv(x) 90 | #out : torch.Size([batch_size,1,Height-1,Width-1]) 91 | 92 | #out : torch.Size([batch_size,1,Height-1,Width-1]) 93 | #cam_logit : torch.Size([batch_size,2]) 94 | #heatmap : torch.Size([batch_size,1,Height,Width]) 95 | return out, cam_logit, heatmap 96 | -------------------------------------------------------------------------------- /module/generator.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | 3 | from .importer import * 4 | from .base_module import * 5 | 6 | class Generator(nn.Module): 7 | def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=6, img_size=256): 8 | super(Generator,self).__init__() 9 | self.input_nc = input_nc 10 | self.output_nc = output_nc 11 | self.ngf = ngf 12 | self.n_blocks = n_blocks 13 | self.img_size = img_size 14 | 15 | DownBlock = [] 16 | DownBlock += [nn.ReflectionPad2d(3), 17 | nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False), 18 | nn.InstanceNorm2d(ngf), 19 | nn.ReLU(True)] 20 | 21 | #2回分の畳み込み層を入れる 22 | n_downsampling = 2 23 | for i in range(n_downsampling): 24 | mult = 2**i 25 | DownBlock += [nn.ReflectionPad2d(1), 26 | nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False), 27 | nn.InstanceNorm2d(ngf * mult * 2), 28 | nn.ReLU(True)] 29 | 30 | #ResnetBlockをn_blocks=6個分入れる 31 | mult = 2**n_downsampling 32 | for i in range(n_blocks): 33 | DownBlock += [ResnetBlock(ngf * mult, use_bias=False)] 34 | 35 | #論文中で、Encoderと呼ばれている箇所を作成 36 | self.DownBlock = nn.Sequential(*DownBlock) 37 | 38 | # Class Activation Map 39 | self.gap_fc = nn.Linear(ngf * mult, 1, bias=False) 40 | self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False) 41 | self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True) 42 | self.relu = nn.ReLU(True) 43 | 44 | #γとβ用のモジュールを作成 45 | FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False), 46 | nn.ReLU(True), 47 | nn.Linear(ngf * mult, ngf * mult, bias=False), 48 | nn.ReLU(True)] 49 | self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False) 50 | self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False) 51 | 52 | self.FC = nn.Sequential(*FC) 53 | 54 | #Up-Samplingするための箇所を作成(論文中でDecoderと呼ばれている箇所) 55 | #AdaILN入りのResnetBlockをn_blocks=6個分作る 56 | for i in range(n_blocks): 57 | setattr(self, 'UpBlock1_' + str(i+1), ResnetAdaILNBlock(ngf * mult, use_bias=False)) 58 | 59 | #n_downsampling=2回分、Up-Samplingする層を追加 60 | UpBlock2 = [] 61 | for i in range(n_downsampling): 62 | mult = 2**(n_downsampling - i) 63 | UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'), 64 | nn.ReflectionPad2d(1), 65 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=False), 66 | ILN(int(ngf * mult / 2)), 67 | nn.ReLU(True)] 68 | 69 | UpBlock2 += [nn.ReflectionPad2d(3), 70 | nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=False), 71 | nn.Tanh()] 72 | 73 | #論文中で、Decoderと呼ばれている箇所を作成 74 | self.UpBlock2 = nn.Sequential(*UpBlock2) 75 | 76 | def forward(self, input): 77 | #論文中で、Encoderと呼ばれている層に通す 78 | x = self.DownBlock(input) 79 | 80 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) 81 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) 82 | gap_weight = list(self.gap_fc.parameters())[0] 83 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3) 84 | 85 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) 86 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) 87 | gmp_weight = list(self.gmp_fc.parameters())[0] 88 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) 89 | 90 | cam_logit = torch.cat([gap_logit, gmp_logit], 1) 91 | x = torch.cat([gap, gmp], 1) 92 | x = self.relu(self.conv1x1(x)) 93 | 94 | heatmap = torch.sum(x, dim=1, keepdim=True) 95 | 96 | #γとβを求める 97 | x_ = self.FC(x.view(x.shape[0], -1)) 98 | gamma, beta = self.gamma(x_), self.beta(x_) 99 | 100 | #論文中で、Decoderと呼ばれている層に通す 101 | #n_blocks=6個分のResnetBlockに通す 102 | for i in range(self.n_blocks): 103 | x = getattr(self, 'UpBlock1_' + str(i+1))(x, gamma, beta) 104 | #Upsample 105 | out = self.UpBlock2(x) 106 | 107 | return out, cam_logit, heatmap 108 | -------------------------------------------------------------------------------- /module/importer.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | 3 | import glob 4 | import os as os 5 | import os.path as osp 6 | import random 7 | import numpy as np 8 | import json 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import matplotlib as mpl 12 | mpl.use('Agg')# AGG(Anti-Grain Geometry engine) 13 | import matplotlib.pyplot as plt 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.optim as optim 18 | import torch.utils.data as data 19 | import torchvision 20 | from torchvision import models,transforms 21 | import torch.nn.init as init 22 | from torch.autograd import Function 23 | import torch.nn.functional as F 24 | import torchvision.utils as vutils 25 | 26 | import xml.etree.ElementTree as ET 27 | import itertools 28 | from math import sqrt 29 | import time 30 | import sys 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | kiwisolver==1.1.0 3 | matplotlib==3.1.3 4 | numpy 5 | opencv-python==4.2.0.32 6 | pandas==1.0.1 7 | pillow>=8.3.2 8 | pyparsing==2.4.6 9 | python-dateutil==2.8.1 10 | pytz==2019.3 11 | six==1.14.0 12 | tqdm==4.42.1 13 | --------------------------------------------------------------------------------