├── metrics ├── __init__.py ├── README.md ├── reconstruction.py ├── calc_inception.py ├── local_editing.py ├── fid.py └── inception.py ├── training ├── __init__.py ├── lpips │ ├── weights │ │ ├── v0.0 │ │ │ ├── alex.pth │ │ │ ├── vgg.pth │ │ │ └── squeeze.pth │ │ └── v0.1 │ │ │ ├── alex.pth │ │ │ ├── vgg.pth │ │ │ └── squeeze.pth │ ├── base_model.py │ ├── pretrained_networks.py │ ├── __init__.py │ └── networks_basic.py ├── op │ ├── __init__.py │ ├── fused_bias_act.cpp │ ├── upfirdn2d.cpp │ ├── fused_act.py │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── dataset_ddp.py └── dataset.py ├── demo ├── .gitignore ├── static │ └── components │ │ ├── img │ │ ├── eraser.png │ │ ├── celeba_hq │ │ │ ├── 102885.jpg │ │ │ ├── 103999.jpg │ │ │ ├── 36_real.png │ │ │ ├── 54_real.png │ │ │ ├── 60_real.png │ │ │ ├── 67_real.png │ │ │ ├── 69_real.png │ │ │ └── 85_real.png │ │ ├── afhq │ │ │ ├── flickr_cat_000003.jpg │ │ │ ├── flickr_cat_000023.jpg │ │ │ ├── flickr_cat_000043.jpg │ │ │ ├── flickr_cat_000061.jpg │ │ │ ├── flickr_dog_000005.jpg │ │ │ ├── flickr_dog_000029.jpg │ │ │ ├── flickr_dog_000044.jpg │ │ │ ├── flickr_dog_000099.jpg │ │ │ ├── flickr_wild_000033.jpg │ │ │ ├── flickr_wild_000076.jpg │ │ │ ├── flickr_wild_000160.jpg │ │ │ ├── flickr_wild_000171.jpg │ │ │ ├── flickr_wild_000909.jpg │ │ │ ├── flickr_wild_000941.jpg │ │ │ └── flickr_wild_000964.jpg │ │ └── lsun │ │ │ ├── car │ │ │ ├── 00e30bfb6ac0e67150917f0e2fb06b7a1582e4a9.png │ │ │ ├── 0107fe0cf4eefd5094bbe8fb0c84e5c7b5c22f63.png │ │ │ ├── 015896f0a5959447958c42c15e1f1afcaef67c9e.png │ │ │ ├── 037107ca4d41ecdbca9e549e86926d6a2a1446c5.png │ │ │ ├── 03d76705df3a29347834dedda66ef0481bbc886d.png │ │ │ ├── 0614a3900c9766d271488235f1406dc24303525f.png │ │ │ ├── 08b62480b12cdc8a7e040b867be0518721098ac4.png │ │ │ ├── 08c9f4ec9cc4cef3300bea0048b4f5c365a31c78.png │ │ │ ├── 0926aebe03e5080d1adcb07e272995d568da9385.png │ │ │ ├── 0d419ebabf279b852243bfd8e5f0dbebed6a436f.png │ │ │ ├── 0f45c777ad4fe12f15ca6930edfb0a565c5676e7.png │ │ │ ├── 0ff36fa76703ae4f092ccdbb893f492c5080b60a.png │ │ │ ├── 1027a57db2caa8fc898ab2559b30a24a8157258f.png │ │ │ ├── 107f944fb978a3bf186268642ee7be6c432fb33e.png │ │ │ ├── 1093077a965f8df84170c4db3c0773236e5c0b24.png │ │ │ ├── 12d95ca342f4a69377485cfbf506c4d3b4f6c22c.png │ │ │ ├── 188942e55cb6a1e691b7b55a895e82c64275e613.png │ │ │ └── 1f4119cc547f5d7e24cf479a12b302ea100833a8.png │ │ │ └── church_outdoor │ │ │ ├── 004443093196d8d7ea0812986ba199efaa43a80c.png │ │ │ ├── 008f3dca6af124e4f4fdf95b0af49f4a37e4baf1.png │ │ │ ├── 00d6e3a86d5057ebf2861d2344db798bebd752c6.png │ │ │ ├── 0334103b18b1df87e54a65bf381fb8637b3fd8a9.png │ │ │ ├── 036741824f3cda9f38ad4462c3320086ea6b0f61.png │ │ │ ├── 03ab4af1907cc92db3eeab31705947e2b048ac1b.png │ │ │ ├── 03ba8b2062a5abdbe5308ae26c05fa7f07b40c4c.png │ │ │ ├── 03d22ab2109d199626540e73de8c025a5523cbea.png │ │ │ ├── 042af7ff14b7dcbc4b2d332290112caceed06e9f.png │ │ │ ├── 04c8ef914dde49827e857a6ab875d0e715a7fc4e.png │ │ │ ├── 05945271e611aaf8d204d7aed7b36aaf994f7a0d.png │ │ │ ├── 05a1047e68e92aefad7c8e38d2734d0368009d11.png │ │ │ ├── 06e346fc835a06e815d43538790c4361710156d4.png │ │ │ ├── 081d1b08c1e258163959f61940059ed11f63c25d.png │ │ │ ├── 0877034e3766ef1a6eab3b1a433daf82da487b9e.png │ │ │ ├── 09692e955d1caaa8f28d08dfcac8795de98d2cbf.png │ │ │ ├── 09fe47f9bd401f6805b754fa5ec333b9eb350374.png │ │ │ ├── 0a113c2718b0e735eca901f2555f34923b23f5ee.png │ │ │ ├── 0a2b4a195b22a8a1a720fdd48161941e536a010d.png │ │ │ ├── 0a7b78f0446d76a60e308a993000370cc017ab46.png │ │ │ ├── 
0ae0b61a9f3e0e21f8d912ed9e305b1fe39555f5.png │ │ │ ├── 0af7632a0e267739e391716aace4187b10ce5f9a.png │ │ │ ├── 0bc79797d806698900aea13d6392f9c5d95bb973.png │ │ │ ├── 0be3712901cedaff96288df4f9e23f242b0453f1.png │ │ │ ├── 0c2cb4f4bbf6806b02e0aef92c9f29a69051b03c.png │ │ │ ├── 0cffcf748599ea5d3ce7e35b0cc6f10d41ea08fa.png │ │ │ ├── 0f391454ed6fa9d70d8e45802a97757c754ca0fe.png │ │ │ ├── 0f3ef9ab115540f3a336790df30a097bfec14be2.png │ │ │ ├── 0f8857dd9369f0dcf9a4247377859e50a1cab285.png │ │ │ └── 0fea9254a7d66c824e458d948c3dbe9b983fd228.png │ │ ├── css │ │ ├── image-picker.css │ │ ├── main.scss │ │ └── main.css │ │ └── js │ │ ├── image-picker.min.js │ │ └── main.js └── templates │ ├── layout.html │ └── index.html ├── assets ├── teaser.jpg └── teaser_video.jpg ├── semantic_manipulation ├── 0_neg_indices.npy ├── 0_pos_indices.npy ├── 10_neg_indices.npy ├── 10_pos_indices.npy ├── 11_neg_indices.npy ├── 11_pos_indices.npy ├── 12_neg_indices.npy ├── 12_pos_indices.npy ├── 13_neg_indices.npy ├── 13_pos_indices.npy ├── 14_neg_indices.npy ├── 14_pos_indices.npy ├── 15_neg_indices.npy ├── 15_pos_indices.npy ├── 16_neg_indices.npy ├── 16_pos_indices.npy ├── 17_neg_indices.npy ├── 17_pos_indices.npy ├── 18_neg_indices.npy ├── 18_pos_indices.npy ├── 19_neg_indices.npy ├── 19_pos_indices.npy ├── 1_neg_indices.npy ├── 1_pos_indices.npy ├── 20_neg_indices.npy ├── 20_pos_indices.npy ├── 21_neg_indices.npy ├── 21_pos_indices.npy ├── 22_neg_indices.npy ├── 22_pos_indices.npy ├── 23_neg_indices.npy ├── 23_pos_indices.npy ├── 24_neg_indices.npy ├── 24_pos_indices.npy ├── 25_neg_indices.npy ├── 25_pos_indices.npy ├── 26_neg_indices.npy ├── 26_pos_indices.npy ├── 27_neg_indices.npy ├── 27_pos_indices.npy ├── 28_neg_indices.npy ├── 28_pos_indices.npy ├── 29_neg_indices.npy ├── 29_pos_indices.npy ├── 2_neg_indices.npy ├── 2_pos_indices.npy ├── 30_neg_indices.npy ├── 30_pos_indices.npy ├── 31_neg_indices.npy ├── 31_pos_indices.npy ├── 32_neg_indices.npy ├── 32_pos_indices.npy ├── 33_neg_indices.npy ├── 33_pos_indices.npy ├── 34_neg_indices.npy ├── 34_pos_indices.npy ├── 35_neg_indices.npy ├── 35_pos_indices.npy ├── 36_neg_indices.npy ├── 36_pos_indices.npy ├── 37_neg_indices.npy ├── 37_pos_indices.npy ├── 38_neg_indices.npy ├── 38_pos_indices.npy ├── 39_neg_indices.npy ├── 39_pos_indices.npy ├── 3_neg_indices.npy ├── 3_pos_indices.npy ├── 4_neg_indices.npy ├── 4_pos_indices.npy ├── 5_neg_indices.npy ├── 5_pos_indices.npy ├── 6_neg_indices.npy ├── 6_pos_indices.npy ├── 7_neg_indices.npy ├── 7_pos_indices.npy ├── 8_neg_indices.npy ├── 8_pos_indices.npy ├── 9_neg_indices.npy └── 9_pos_indices.npy ├── install.sh ├── preprocessor ├── README.md ├── prepare_data.py └── pair_masks.py ├── .gitignore ├── download.sh ├── demo.py └── README.md /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | static/generated/ 3 | ckpt/ -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/assets/teaser.jpg 
-------------------------------------------------------------------------------- /assets/teaser_video.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/assets/teaser_video.jpg -------------------------------------------------------------------------------- /training/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/training/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /training/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/training/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /training/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/training/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /training/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/training/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /demo/static/components/img/eraser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/eraser.png -------------------------------------------------------------------------------- /semantic_manipulation/0_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/0_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/0_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/0_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/10_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/10_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/10_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/10_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/11_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/11_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/11_pos_indices.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/11_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/12_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/12_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/12_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/12_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/13_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/13_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/13_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/13_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/14_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/14_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/14_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/14_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/15_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/15_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/15_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/15_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/16_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/16_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/16_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/16_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/17_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/17_neg_indices.npy -------------------------------------------------------------------------------- 
/semantic_manipulation/17_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/17_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/18_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/18_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/18_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/18_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/19_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/19_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/19_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/19_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/1_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/1_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/1_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/1_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/20_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/20_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/20_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/20_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/21_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/21_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/21_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/21_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/22_neg_indices.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/22_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/22_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/22_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/23_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/23_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/23_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/23_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/24_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/24_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/24_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/24_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/25_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/25_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/25_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/25_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/26_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/26_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/26_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/26_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/27_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/27_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/27_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/27_pos_indices.npy -------------------------------------------------------------------------------- 
/semantic_manipulation/28_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/28_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/28_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/28_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/29_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/29_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/29_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/29_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/2_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/2_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/2_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/2_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/30_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/30_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/30_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/30_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/31_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/31_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/31_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/31_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/32_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/32_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/32_pos_indices.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/32_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/33_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/33_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/33_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/33_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/34_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/34_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/34_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/34_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/35_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/35_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/35_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/35_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/36_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/36_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/36_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/36_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/37_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/37_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/37_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/37_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/38_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/38_neg_indices.npy -------------------------------------------------------------------------------- 
/semantic_manipulation/38_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/38_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/39_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/39_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/39_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/39_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/3_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/3_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/3_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/3_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/4_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/4_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/4_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/4_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/5_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/5_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/5_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/5_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/6_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/6_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/6_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/6_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/7_neg_indices.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/7_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/7_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/7_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/8_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/8_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/8_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/8_pos_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/9_neg_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/9_neg_indices.npy -------------------------------------------------------------------------------- /semantic_manipulation/9_pos_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/semantic_manipulation/9_pos_indices.npy -------------------------------------------------------------------------------- /training/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/training/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /training/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/training/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /demo/static/components/img/celeba_hq/102885.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/celeba_hq/102885.jpg -------------------------------------------------------------------------------- /demo/static/components/img/celeba_hq/103999.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/celeba_hq/103999.jpg -------------------------------------------------------------------------------- /demo/static/components/img/celeba_hq/36_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/celeba_hq/36_real.png -------------------------------------------------------------------------------- /demo/static/components/img/celeba_hq/54_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/celeba_hq/54_real.png 
-------------------------------------------------------------------------------- /demo/static/components/img/celeba_hq/60_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/celeba_hq/60_real.png -------------------------------------------------------------------------------- /demo/static/components/img/celeba_hq/67_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/celeba_hq/67_real.png -------------------------------------------------------------------------------- /demo/static/components/img/celeba_hq/69_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/celeba_hq/69_real.png -------------------------------------------------------------------------------- /demo/static/components/img/celeba_hq/85_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/celeba_hq/85_real.png -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_cat_000003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_cat_000003.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_cat_000023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_cat_000023.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_cat_000043.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_cat_000043.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_cat_000061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_cat_000061.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_dog_000005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_dog_000005.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_dog_000029.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_dog_000029.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_dog_000044.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_dog_000044.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_dog_000099.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_dog_000099.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_wild_000033.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_wild_000033.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_wild_000076.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_wild_000076.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_wild_000160.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_wild_000160.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_wild_000171.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_wild_000171.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_wild_000909.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_wild_000909.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_wild_000941.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_wild_000941.jpg -------------------------------------------------------------------------------- /demo/static/components/img/afhq/flickr_wild_000964.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/afhq/flickr_wild_000964.jpg -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/00e30bfb6ac0e67150917f0e2fb06b7a1582e4a9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/00e30bfb6ac0e67150917f0e2fb06b7a1582e4a9.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/0107fe0cf4eefd5094bbe8fb0c84e5c7b5c22f63.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/0107fe0cf4eefd5094bbe8fb0c84e5c7b5c22f63.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/015896f0a5959447958c42c15e1f1afcaef67c9e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/015896f0a5959447958c42c15e1f1afcaef67c9e.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/037107ca4d41ecdbca9e549e86926d6a2a1446c5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/037107ca4d41ecdbca9e549e86926d6a2a1446c5.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/03d76705df3a29347834dedda66ef0481bbc886d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/03d76705df3a29347834dedda66ef0481bbc886d.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/0614a3900c9766d271488235f1406dc24303525f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/0614a3900c9766d271488235f1406dc24303525f.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/08b62480b12cdc8a7e040b867be0518721098ac4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/08b62480b12cdc8a7e040b867be0518721098ac4.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/08c9f4ec9cc4cef3300bea0048b4f5c365a31c78.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/08c9f4ec9cc4cef3300bea0048b4f5c365a31c78.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/0926aebe03e5080d1adcb07e272995d568da9385.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/0926aebe03e5080d1adcb07e272995d568da9385.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/0d419ebabf279b852243bfd8e5f0dbebed6a436f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/0d419ebabf279b852243bfd8e5f0dbebed6a436f.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/0f45c777ad4fe12f15ca6930edfb0a565c5676e7.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/0f45c777ad4fe12f15ca6930edfb0a565c5676e7.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/0ff36fa76703ae4f092ccdbb893f492c5080b60a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/0ff36fa76703ae4f092ccdbb893f492c5080b60a.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/1027a57db2caa8fc898ab2559b30a24a8157258f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/1027a57db2caa8fc898ab2559b30a24a8157258f.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/107f944fb978a3bf186268642ee7be6c432fb33e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/107f944fb978a3bf186268642ee7be6c432fb33e.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/1093077a965f8df84170c4db3c0773236e5c0b24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/1093077a965f8df84170c4db3c0773236e5c0b24.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/12d95ca342f4a69377485cfbf506c4d3b4f6c22c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/12d95ca342f4a69377485cfbf506c4d3b4f6c22c.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/188942e55cb6a1e691b7b55a895e82c64275e613.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/188942e55cb6a1e691b7b55a895e82c64275e613.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/car/1f4119cc547f5d7e24cf479a12b302ea100833a8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/car/1f4119cc547f5d7e24cf479a12b302ea100833a8.png -------------------------------------------------------------------------------- /training/op/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/op/__init__.py 3 | """ 4 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 5 | from .upfirdn2d import upfirdn2d -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/004443093196d8d7ea0812986ba199efaa43a80c.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/004443093196d8d7ea0812986ba199efaa43a80c.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/008f3dca6af124e4f4fdf95b0af49f4a37e4baf1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/008f3dca6af124e4f4fdf95b0af49f4a37e4baf1.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/00d6e3a86d5057ebf2861d2344db798bebd752c6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/00d6e3a86d5057ebf2861d2344db798bebd752c6.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0334103b18b1df87e54a65bf381fb8637b3fd8a9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0334103b18b1df87e54a65bf381fb8637b3fd8a9.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/036741824f3cda9f38ad4462c3320086ea6b0f61.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/036741824f3cda9f38ad4462c3320086ea6b0f61.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/03ab4af1907cc92db3eeab31705947e2b048ac1b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/03ab4af1907cc92db3eeab31705947e2b048ac1b.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/03ba8b2062a5abdbe5308ae26c05fa7f07b40c4c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/03ba8b2062a5abdbe5308ae26c05fa7f07b40c4c.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/03d22ab2109d199626540e73de8c025a5523cbea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/03d22ab2109d199626540e73de8c025a5523cbea.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/042af7ff14b7dcbc4b2d332290112caceed06e9f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/042af7ff14b7dcbc4b2d332290112caceed06e9f.png -------------------------------------------------------------------------------- 
/demo/static/components/img/lsun/church_outdoor/04c8ef914dde49827e857a6ab875d0e715a7fc4e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/04c8ef914dde49827e857a6ab875d0e715a7fc4e.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/05945271e611aaf8d204d7aed7b36aaf994f7a0d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/05945271e611aaf8d204d7aed7b36aaf994f7a0d.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/05a1047e68e92aefad7c8e38d2734d0368009d11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/05a1047e68e92aefad7c8e38d2734d0368009d11.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/06e346fc835a06e815d43538790c4361710156d4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/06e346fc835a06e815d43538790c4361710156d4.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/081d1b08c1e258163959f61940059ed11f63c25d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/081d1b08c1e258163959f61940059ed11f63c25d.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0877034e3766ef1a6eab3b1a433daf82da487b9e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0877034e3766ef1a6eab3b1a433daf82da487b9e.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/09692e955d1caaa8f28d08dfcac8795de98d2cbf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/09692e955d1caaa8f28d08dfcac8795de98d2cbf.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/09fe47f9bd401f6805b754fa5ec333b9eb350374.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/09fe47f9bd401f6805b754fa5ec333b9eb350374.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0a113c2718b0e735eca901f2555f34923b23f5ee.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0a113c2718b0e735eca901f2555f34923b23f5ee.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0a2b4a195b22a8a1a720fdd48161941e536a010d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0a2b4a195b22a8a1a720fdd48161941e536a010d.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0a7b78f0446d76a60e308a993000370cc017ab46.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0a7b78f0446d76a60e308a993000370cc017ab46.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0ae0b61a9f3e0e21f8d912ed9e305b1fe39555f5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0ae0b61a9f3e0e21f8d912ed9e305b1fe39555f5.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0af7632a0e267739e391716aace4187b10ce5f9a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0af7632a0e267739e391716aace4187b10ce5f9a.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0bc79797d806698900aea13d6392f9c5d95bb973.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0bc79797d806698900aea13d6392f9c5d95bb973.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0be3712901cedaff96288df4f9e23f242b0453f1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0be3712901cedaff96288df4f9e23f242b0453f1.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0c2cb4f4bbf6806b02e0aef92c9f29a69051b03c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0c2cb4f4bbf6806b02e0aef92c9f29a69051b03c.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0cffcf748599ea5d3ce7e35b0cc6f10d41ea08fa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0cffcf748599ea5d3ce7e35b0cc6f10d41ea08fa.png -------------------------------------------------------------------------------- 
/demo/static/components/img/lsun/church_outdoor/0f391454ed6fa9d70d8e45802a97757c754ca0fe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0f391454ed6fa9d70d8e45802a97757c754ca0fe.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0f3ef9ab115540f3a336790df30a097bfec14be2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0f3ef9ab115540f3a336790df30a097bfec14be2.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0f8857dd9369f0dcf9a4247377859e50a1cab285.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0f8857dd9369f0dcf9a4247377859e50a1cab285.png -------------------------------------------------------------------------------- /demo/static/components/img/lsun/church_outdoor/0fea9254a7d66c824e458d948c3dbe9b983fd228.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/StyleMapGAN/HEAD/demo/static/components/img/lsun/church_outdoor/0fea9254a7d66c824e458d948c3dbe9b983fd228.png -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | conda install -y pytorch=1.4.0 torchvision=0.5.0 -c pytorch 4 | conda install -y numpy=1.18.1 scikit-image=0.16.2 tqdm 5 | conda install -y -c anaconda ipython=7.13.0 6 | pip install lmdb==0.98 opencv-python==4.2.0.34 munch==2.5.0 7 | pip install -U scikit-image==0.15.0 scipy==1.2.1 matplotlib scikit-learn 8 | pip install flask==1.0.2 pillow==7.0.0 -------------------------------------------------------------------------------- /preprocessor/README.md: -------------------------------------------------------------------------------- 1 | ## Dataset 2 | 3 | Transform raw images(ex. jpg, png) to LMDB format. Refer to `download.sh`. 4 | ``` 5 | python prepare_data.py [raw images path] --out [destination path] --size [TARGET_SIZE] 6 | 7 | ``` 8 | 9 | ## Local editing 10 | 11 | We use an overall mask of original and reference mask, so we need a pair of images which has a similar target mask with each other. `download.sh create-lmdb-dataset celeba_hq` already offers precalculated pairs of images for local editing. But you can pair your own images based on the similarity of target semantic(e.g., nose, hair) mask, please modify `pair_masks.py` for your purposes. 
12 | ``` 13 | python pair_masks.py 14 | ``` -------------------------------------------------------------------------------- /training/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | 3 | #include 4 | 5 | 6 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 7 | int act, int grad, float alpha, float scale); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 14 | int act, int grad, float alpha, float scale) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(bias); 17 | 18 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 23 | } -------------------------------------------------------------------------------- /demo/static/components/css/image-picker.css: -------------------------------------------------------------------------------- 1 | /* 2 | Refer to https://github.com/rvera/image-picker/blob/master/image-picker/image-picker.css 3 | */ 4 | 5 | ul.thumbnails.image_picker_selector { 6 | overflow: auto; 7 | list-style-image: none; 8 | list-style-position: outside; 9 | list-style-type: none; 10 | padding: 0px; 11 | margin: 0px; } 12 | ul.thumbnails.image_picker_selector ul { 13 | overflow: auto; 14 | list-style-image: none; 15 | list-style-position: outside; 16 | list-style-type: none; 17 | padding: 0px; 18 | margin: 0px; } 19 | ul.thumbnails.image_picker_selector li.group {width:100%;} 20 | ul.thumbnails.image_picker_selector li.group_title { 21 | float: none; } 22 | ul.thumbnails.image_picker_selector li { 23 | margin: 0px 12px 12px 0px; 24 | float: left; } 25 | ul.thumbnails.image_picker_selector li .thumbnail { 26 | padding: 6px; 27 | border: 1px solid #dddddd; 28 | -webkit-user-select: none; 29 | -moz-user-select: none; 30 | -ms-user-select: none; } 31 | ul.thumbnails.image_picker_selector li .thumbnail img { 32 | -webkit-user-drag: none; } 33 | ul.thumbnails.image_picker_selector li .thumbnail.selected { 34 | background: #0088cc; } 35 | -------------------------------------------------------------------------------- /training/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Refer to https : //github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | 3 | #include 4 | 5 | torch::Tensor 6 | upfirdn2d_op(const torch::Tensor &input, const torch::Tensor &kernel, 7 | int up_x, int up_y, int down_x, int down_y, 8 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) \ 13 | CHECK_CUDA(x); \ 14 | CHECK_CONTIGUOUS(x) 15 | 16 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 17 | int up_x, int up_y, int down_x, int down_y, 18 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) 19 | { 20 | CHECK_CUDA(input); 21 | 
CHECK_CUDA(kernel); 22 | 23 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 24 | } 25 | 26 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 27 | { 28 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 29 | } -------------------------------------------------------------------------------- /training/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/base_model.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/trainer.py 4 | """ 5 | import os 6 | import torch 7 | from torch.autograd import Variable 8 | from pdb import set_trace as st 9 | from IPython import embed 10 | 11 | 12 | class BaseModel: 13 | def __init__(self): 14 | pass 15 | 16 | def name(self): 17 | return "BaseModel" 18 | 19 | def initialize(self, use_gpu=True, gpu_ids=[0]): 20 | self.use_gpu = use_gpu 21 | self.gpu_ids = gpu_ids 22 | 23 | def forward(self): 24 | pass 25 | 26 | def get_image_paths(self): 27 | pass 28 | 29 | def optimize_parameters(self): 30 | pass 31 | 32 | def get_current_visuals(self): 33 | return self.input 34 | 35 | def get_current_errors(self): 36 | return {} 37 | 38 | def save(self, label): 39 | pass 40 | 41 | # helper saving function that can be used by subclasses 42 | def save_network(self, network, path, network_label, epoch_label): 43 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 44 | save_path = os.path.join(path, save_filename) 45 | torch.save(network.state_dict(), save_path) 46 | 47 | # helper loading function that can be used by subclasses 48 | def load_network(self, network, network_label, epoch_label): 49 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 50 | save_path = os.path.join(self.save_dir, save_filename) 51 | print("Loading network from %s" % save_path) 52 | network.load_state_dict(torch.load(save_path)) 53 | 54 | def update_learning_rate(): 55 | pass 56 | 57 | def get_image_paths(self): 58 | return self.image_paths 59 | 60 | def save_done(self, flag=False): 61 | np.save(os.path.join(self.save_dir, "done_flag"), flag) 62 | np.savetxt( 63 | os.path.join(self.save_dir, "done_flag"), 64 | [ 65 | flag, 66 | ], 67 | fmt="%i", 68 | ) 69 | -------------------------------------------------------------------------------- /training/dataset_ddp.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | # Dataset code for the DDP training setting. 
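# Note (annotation, not in the original file): the LMDB environment is opened lazily
# in _open() instead of __init__, presumably so that each DDP process / DataLoader
# worker creates its own handle (lmdb environments are not safe to share across
# forked processes). __len__ opens and immediately closes the environment once,
# only to read the stored "length" key.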
12 | 13 | from io import BytesIO 14 | import lmdb 15 | from PIL import Image 16 | from torch.utils.data import Dataset 17 | from torch.utils import data 18 | import numpy as np 19 | import random 20 | import re, os 21 | from torchvision import transforms 22 | import torch 23 | 24 | 25 | class MultiResolutionDataset(Dataset): 26 | def __init__(self, path, transform, resolution=256): 27 | self.path = path 28 | self.resolution = resolution 29 | self.transform = transform 30 | self.length = None 31 | 32 | def _open(self): 33 | self.env = lmdb.open( 34 | self.path, 35 | max_readers=32, 36 | readonly=True, 37 | lock=False, 38 | readahead=False, 39 | meminit=False, 40 | ) 41 | 42 | if not self.env: 43 | raise IOError(f"Cannot open lmdb dataset {self.path}") 44 | 45 | with self.env.begin(write=False) as txn: 46 | self.length = int(txn.get("length".encode("utf-8")).decode("utf-8")) 47 | 48 | def _close(self): 49 | if self.env is not None: 50 | self.env.close() 51 | self.env = None 52 | 53 | def __len__(self): 54 | if self.length is None: 55 | self._open() 56 | self._close() 57 | 58 | return self.length 59 | 60 | def __getitem__(self, index): 61 | if self.env is None: 62 | self._open() 63 | 64 | with self.env.begin(write=False) as txn: 65 | key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8") 66 | img_bytes = txn.get(key) 67 | 68 | buffer = BytesIO(img_bytes) 69 | img = Image.open(buffer) 70 | img = self.transform(img) 71 | 72 | return img 73 | -------------------------------------------------------------------------------- /metrics/README.md: -------------------------------------------------------------------------------- 1 | ## Metrics 2 | 3 | * Reconstruction: LPIPS, MSE 4 | * W interpolation: FIDlerp 5 | * Generation: FID from Gaussian distribution 6 | * Local editing: MSEsrc, MSEref / Detectability (Refer to [CNNDetection](https://github.com/PeterWang512/CNNDetection)) 7 | 8 | 9 | All command lines should be run in `StyleMapGAN/` 10 | ```bash 11 | bash download.sh prepare-fid-calculation 12 | ``` 13 | 14 | Reconstruction 15 | ```bash 16 | # First, reconstruct images 17 | # python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type reconstruction --test_lmdb data/celeba_hq/LMDB_test 18 | python -m metrics.reconstruction --image_folder_path expr/reconstruction/celeba_hq 19 | ``` 20 | 21 | W interpolation 22 | ```bash 23 | # First, interpolate images 24 | # python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type w_interpolation --test_lmdb data/celeba_hq/LMDB_test 25 | 26 | # Second, precalculate mean and variance of dataset: fid/celeba_hq_stats_256_29000.pkl 27 | # But, we already provided them. 
28 | # python -m metrics.fid --size 256 --dataset celeba_hq --generated_image_path data/celeba_hq/LMDB_train 29 | 30 | # CelebA-HQ 31 | python -m metrics.fid --comparative_fid_pkl metrics/fid_stats/celeba_hq_stats_256_29000.pkl --dataset celeba_hq --generated_image_path expr/w_interpolation/celeba_hq 32 | 33 | # AFHQ 34 | python -m metrics.fid --comparative_fid_pkl metrics/fid_stats/afhq_stats_256_15130.pkl --dataset afhq --generated_image_path expr/w_interpolation/afhq 35 | ``` 36 | 37 | Generation 38 | ```bash 39 | python -m metrics.fid --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --comparative_fid_pkl metrics/fid_stats/celeba_hq_stats_256_29000.pkl --dataset celeba_hq 40 | ``` 41 | 42 | Local editing 43 | 44 | For MSEsrc, MSEref 45 | ```bash 46 | # First, generate local edited image 47 | # for part in nose hair background eye eyebrow lip neck cloth skin ear; do 48 | # python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type local_editing --test_lmdb data/celeba_hq/LMDB_test --local_editing_part $part 49 | # done 50 | python -m metrics.local_editing --data_dir expr/local_editing/celeba_hq 51 | ``` -------------------------------------------------------------------------------- /demo/templates/layout.html: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | {% if title %} 13 | {{ title }} 14 | {% endif %} 15 | 16 | 18 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 31 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 45 | 46 | 47 | 48 | {% block content %}{% endblock %} 49 | 50 | 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | core* 3 | expr/ 4 | preprocessor/bisenet.ckpt 5 | metrics/pt_inception-2015-12-05-6726825d.pth 6 | metrics/fid_stats/ 7 | stylegan2-ffhq-config-f* 8 | archive/ 9 | 10 | # Demo 11 | **/__pycache__/ 12 | demo/static/generated/ 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.pyc 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | wandb/ 146 | *.lmdb/ 147 | -------------------------------------------------------------------------------- /metrics/reconstruction.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | 11 | import argparse 12 | import os 13 | from torchvision import transforms 14 | import training.lpips as lpips 15 | import torch.nn as nn 16 | from PIL import Image 17 | import torch 18 | 19 | if __name__ == "__main__": 20 | device = "cuda" 21 | 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument( 25 | "--image_folder_path", type=str, default="expr/reconstruction/celeba_hq" 26 | ) 27 | args = parser.parse_args() 28 | image_folder_path = args.image_folder_path 29 | 30 | # images(0~1) are converted to -1 ~ 1 31 | transform = transforms.Compose( 32 | [ 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 35 | ] 36 | ) 37 | 38 | mse_loss = nn.MSELoss(size_average=True) 39 | percept = lpips.PerceptualLoss( 40 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 41 | ) 42 | 43 | fake_filenames = [] 44 | real_filenames = [] 45 | 46 | print(image_folder_path) 47 | dataset_len = 500 48 | 49 | for i in range(dataset_len): 50 | fake_filename = os.path.join(image_folder_path, f"{i}_recon.png") 51 | real_filename = os.path.join(image_folder_path, f"{i}_real.png") 52 | 53 | if os.path.isfile(fake_filename) and os.path.isfile(real_filename): 54 | fake_filenames.append(fake_filename) 55 | real_filenames.append(real_filename) 56 | else: 57 | print(f"{fake_filename} or {real_filename} doesn't exists") 58 | break 59 | 60 | print(len(fake_filenames), len(real_filenames)) 61 | assert len(fake_filenames) == dataset_len 62 | assert len(real_filenames) == dataset_len 63 | 64 | mse_results = [] 65 | lpips_results = [] 66 | 67 | with torch.no_grad(): 68 | for fake_filename, real_filename in zip(fake_filenames, real_filenames): 69 | fake_img = transform(Image.open(fake_filename).convert("RGB")).to(device) 70 | real_img = transform(Image.open(real_filename).convert("RGB")).to(device) 71 | # assert real_img.shape == (3, 256, 256) 72 | 73 | mse = mse_loss(fake_img, real_img).item() 74 | lpips = percept(fake_img, real_img).item() 75 | 76 | mse_results.append(mse) 77 | lpips_results.append(lpips) 78 | 79 | mse_mean = sum(mse_results) / len(mse_results) 80 | lpips_mean = sum(lpips_results) / len(lpips_results) 81 | 82 | print(mse_mean, lpips_mean) 83 | -------------------------------------------------------------------------------- /training/op/fused_act.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py 3 | """ 4 | import os 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Function 8 | from torch.utils.cpp_extension import load 9 | 10 | 11 | module_path = os.path.dirname(__file__) 12 | fused = load( 13 | "fused", 14 | sources=[ 15 | os.path.join(module_path, "fused_bias_act.cpp"), 16 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 17 | ], 18 | ) 19 | 20 | 21 | class FusedLeakyReLUFunctionBackward(Function): 22 | @staticmethod 23 | def forward(ctx, grad_output, out, negative_slope, scale): 24 | ctx.save_for_backward(out) 25 | ctx.negative_slope = negative_slope 26 | ctx.scale = scale 27 | 28 | empty = grad_output.new_empty(0) 29 | 30 | grad_input = fused.fused_bias_act( 31 | grad_output, empty, out, 3, 1, negative_slope, scale 32 | ) 33 | 34 | dim = [0] 35 | 36 | if grad_input.ndim > 2: 37 | dim += list(range(2, grad_input.ndim)) 38 | 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | return grad_input, grad_bias 42 | 43 | @staticmethod 44 | def backward(ctx, gradgrad_input, 
gradgrad_bias): 45 | (out,) = ctx.saved_tensors 46 | gradgrad_out = fused.fused_bias_act( 47 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 48 | ) 49 | 50 | return gradgrad_out, None, None, None 51 | 52 | 53 | class FusedLeakyReLUFunction(Function): 54 | @staticmethod 55 | def forward(ctx, input, bias, negative_slope, scale): 56 | empty = input.new_empty(0) 57 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 58 | ctx.save_for_backward(out) 59 | ctx.negative_slope = negative_slope 60 | ctx.scale = scale 61 | 62 | return out 63 | 64 | @staticmethod 65 | def backward(ctx, grad_output): 66 | (out,) = ctx.saved_tensors 67 | 68 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 69 | grad_output, out, ctx.negative_slope, ctx.scale 70 | ) 71 | 72 | return grad_input, grad_bias, None, None 73 | 74 | 75 | class FusedLeakyReLU(nn.Module): 76 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 77 | super().__init__() 78 | 79 | self.bias = nn.Parameter(torch.zeros(channel)) 80 | self.negative_slope = negative_slope 81 | self.scale = scale 82 | 83 | def forward(self, input): 84 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 85 | 86 | 87 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 88 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 89 | -------------------------------------------------------------------------------- /demo/static/components/css/main.scss: -------------------------------------------------------------------------------- 1 | /* 2 | Refer to https://github.com/quolc/neural-collage/blob/master/static/demo_feature_blending/css/main.scss 3 | */ 4 | 5 | body { 6 | background: #f0f0f0; 7 | } 8 | 9 | #main-column { 10 | background: #fff; 11 | } 12 | 13 | // header 14 | #title { 15 | text-align: center; 16 | h2 { 17 | margin: 10px auto 30px; 18 | } 19 | h4 { 20 | font-size: large; 21 | margin: 20px auto 10px; 22 | } 23 | } 24 | 25 | // base class selection 26 | #base-class { 27 | padding: 10px; 28 | } 29 | 30 | // main ui 31 | #main-ui { 32 | margin: 20px 0 10px; 33 | text-align: center; 34 | 35 | p { 36 | margin-bottom: 5px; 37 | } 38 | 39 | #sketch-control { 40 | margin: 5px 0 0; 41 | text-align: center; 42 | button { 43 | margin: 0 auto; 44 | font-size: smaller; 45 | } 46 | } 47 | 48 | #result-control { 49 | text-align: center; 50 | .slider { 51 | margin: 10px auto 0; 52 | } 53 | } 54 | 55 | #main-ui-middle { 56 | padding: 0; 57 | button { 58 | width: max-content; 59 | font-size: large; 60 | } 61 | } 62 | 63 | .p5-wrapper { 64 | position: relative; 65 | width: 100%; 66 | border: solid 1px #ccc; 67 | } 68 | .p5-wrapper:before { 69 | content:""; 70 | display: block; 71 | padding-top: 100%; 72 | } 73 | .p5 { 74 | position: absolute; 75 | top: 0; 76 | left: 0; 77 | bottom: 0; 78 | right: 0; 79 | } 80 | } 81 | 82 | // palette 83 | #palette { 84 | margin: 10px 0 10px; 85 | 86 | #palette-body { 87 | min-height: 58px; 88 | margin: 0; 89 | padding: 0; 90 | overflow: auto; 91 | list-style-image: none; 92 | list-style-position: outside; 93 | list-style-type: none; 94 | 95 | .palette-item { 96 | float: left; 97 | margin-right: 4px; 98 | width: 58px; 99 | height: 58px; 100 | opacity: 0.6; 101 | img { 102 | width: 54px; 103 | height: 54px; 104 | margin: 2px; 105 | } 106 | &.selected { 107 | opacity: 1.0; 108 | } 109 | } 110 | } 111 | } 112 | 113 | // class-list 114 | #class-list { 115 | margin: 20px 0 30px; 116 | 117 | #class-picker-ui { 118 | 
margin: 5px 0; 119 | button{ 120 | margin-left: 5px; 121 | } 122 | } 123 | 124 | .image_picker_selector { 125 | max-height: 445px; 126 | overflow: scroll; 127 | 128 | img { 129 | width: 75px; 130 | height: 75px; 131 | } 132 | li { 133 | margin: 0px 4px 4px 0px !important; 134 | .thumbnail { 135 | background: #fff; 136 | padding: 4px !important; 137 | &.selected { 138 | background: #0088cc; 139 | } 140 | } 141 | } 142 | } 143 | } 144 | 145 | // p5 146 | .p5Canvas { 147 | width: 100% !important; 148 | height: 100% !important; 149 | } -------------------------------------------------------------------------------- /preprocessor/prepare_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/prepare_data.py 3 | """ 4 | 5 | import argparse 6 | from io import BytesIO 7 | import multiprocessing 8 | from functools import partial 9 | 10 | from PIL import Image 11 | import lmdb 12 | from tqdm import tqdm 13 | from torchvision import datasets 14 | from torchvision.transforms import functional as trans_fn 15 | 16 | 17 | def resize_and_convert(img, size, resample, quality=100): 18 | img = trans_fn.resize(img, (size, size), resample) 19 | # img = trans_fn.center_crop(img, size) 20 | buffer = BytesIO() 21 | img.save(buffer, format="jpeg", quality=quality) 22 | val = buffer.getvalue() 23 | 24 | return val 25 | 26 | 27 | def resize_multiple( 28 | img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100 29 | ): 30 | imgs = [] 31 | 32 | for size in sizes: 33 | imgs.append(resize_and_convert(img, size, resample, quality)) 34 | 35 | return imgs 36 | 37 | 38 | def resize_worker(img_file, sizes, resample): 39 | i, file = img_file 40 | img = Image.open(file) 41 | img = img.convert("RGB") 42 | out = resize_multiple(img, sizes=sizes, resample=resample) 43 | 44 | return i, out 45 | 46 | 47 | def prepare( 48 | env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS 49 | ): 50 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 51 | files = sorted(dataset.imgs, key=lambda x: x[0]) 52 | files = [(i, file) for i, (file, label) in enumerate(files)] 53 | total = 0 54 | 55 | with multiprocessing.Pool(n_worker) as pool: 56 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 57 | for size, img in zip(sizes, imgs): 58 | key = f"{size}-{str(i).zfill(5)}".encode("utf-8") 59 | 60 | with env.begin(write=True) as txn: 61 | txn.put(key, img) 62 | 63 | total += 1 64 | 65 | with env.begin(write=True) as txn: 66 | txn.put("length".encode("utf-8"), str(total).encode("utf-8")) 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument("--out", type=str) 72 | parser.add_argument("--size", type=str, default="128,256,512,1024") 73 | parser.add_argument("--n_worker", type=int, default=5) 74 | parser.add_argument("--resample", type=str, default="bilinear") 75 | parser.add_argument("path", type=str) 76 | 77 | args = parser.parse_args() 78 | 79 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} 80 | resample = resample_map[args.resample] 81 | 82 | sizes = [int(s.strip()) for s in args.size.split(",")] 83 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) 84 | 85 | imgset = datasets.ImageFolder(args.path) 86 | 87 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 88 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 89 | 
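The LMDB written by `prepare_data.py` stores one JPEG-encoded buffer per `"{size}-{index:05d}"` key plus a `"length"` key, which is exactly the layout that `MultiResolutionDataset` in `training/dataset.py` reads back. Below is a minimal read-back sketch for sanity-checking an output LMDB; the helper name `read_lmdb_image` and the example path are illustrative, not part of the repository.

```python
from io import BytesIO

import lmdb
from PIL import Image


def read_lmdb_image(lmdb_path, index, size=256):
    # Flags mirror those used by MultiResolutionDataset (read-only, no locking).
    env = lmdb.open(
        lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False
    )
    with env.begin(write=False) as txn:
        length = int(txn.get("length".encode("utf-8")).decode("utf-8"))
        assert 0 <= index < length, "index out of range"
        key = f"{size}-{str(index).zfill(5)}".encode("utf-8")
        img_bytes = txn.get(key)
    env.close()
    # Entries were saved as JPEG (quality=100) regardless of the source extension.
    return Image.open(BytesIO(img_bytes))


# Example (hypothetical path): img = read_lmdb_image("data/celeba_hq/LMDB_test", 0, size=256)
```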
-------------------------------------------------------------------------------- /demo/static/components/css/main.css: -------------------------------------------------------------------------------- 1 | /* 2 | Refer to https://github.com/quolc/neural-collage/blob/master/static/demo_feature_blending/css/main.css 3 | */ 4 | 5 | body { 6 | background: #f0f0f0; } 7 | 8 | #main-column { 9 | background: #fff; } 10 | 11 | #title { 12 | text-align: center; } 13 | #title h2 { 14 | margin: 10px auto 30px; } 15 | #title h4 { 16 | font-size: large; 17 | margin: 20px auto 10px; } 18 | 19 | #base-class { 20 | padding: 10px; } 21 | 22 | #main-ui { 23 | margin: 20px 0 10px; 24 | text-align: center; } 25 | #main-ui p { 26 | margin-bottom: 5px; } 27 | #main-ui #sketch-control { 28 | margin: 5px 0 0; 29 | text-align: center; } 30 | #main-ui #sketch-control button { 31 | margin: 0 auto; 32 | font-size: smaller; } 33 | #main-ui #result-control { 34 | text-align: center; } 35 | #main-ui #result-control .slider { 36 | margin: 10px auto 0; } 37 | #main-ui #main-ui-middle { 38 | padding: 0; } 39 | #main-ui #main-ui-middle button { 40 | width: max-content; 41 | font-size: large; } 42 | #main-ui .p5-wrapper { 43 | position: relative; 44 | width: 100%; 45 | border: solid 1px #ccc; } 46 | #main-ui .p5-wrapper:before { 47 | content: ""; 48 | display: block; 49 | padding-top: 100%; } 50 | #main-ui .p5 { 51 | position: absolute; 52 | top: 0; 53 | left: 0; 54 | bottom: 0; 55 | right: 0; } 56 | 57 | #palette { 58 | margin: 10px 0 10px; } 59 | #palette #palette-body { 60 | min-height: 58px; 61 | margin: 0; 62 | padding: 0; 63 | overflow: auto; 64 | list-style-image: none; 65 | list-style-position: outside; 66 | list-style-type: none; } 67 | #palette #palette-body .palette-item { 68 | float: left; 69 | margin-right: 4px; 70 | width: 90px; 71 | height: 58px; 72 | opacity: 0.3; } 73 | #palette #palette-body .palette-item img { 74 | width: 54px; 75 | height: 54px; 76 | margin: 2px; } 77 | #palette #palette-body .palette-item.selected { 78 | opacity: 1.0; } 79 | 80 | #class-list { 81 | margin: 20px 0 30px; } 82 | #class-list #class-picker-ui { 83 | margin: 5px 0; } 84 | #class-list #class-picker-ui button { 85 | margin-left: 5px; } 86 | #class-list .image_picker_selector { 87 | max-height: 445px; 88 | overflow: scroll; } 89 | #class-list .image_picker_selector img { 90 | width: 75px; 91 | height: 75px; } 92 | #class-list .image_picker_selector li { 93 | margin: 0px 4px 4px 0px !important; } 94 | #class-list .image_picker_selector li .thumbnail { 95 | background: #fff; 96 | padding: 4px !important; } 97 | #class-list .image_picker_selector li .thumbnail.selected { 98 | background: #0088cc; } 99 | 100 | .p5Canvas { 101 | width: 100% !important; 102 | height: 100% !important; } 103 | 104 | /*# sourceMappingURL=main.css.map */ 105 | -------------------------------------------------------------------------------- /training/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /demo/templates/index.html: -------------------------------------------------------------------------------- 1 | 4 | 5 | {% extends "layout.html" %} 6 | {% block content %} 7 | 8 |
<!-- index.html body: markup lost during extraction; only the visible text survives. Title: "Demo: StyleMapGAN (CVPR21)". Subtitle: "Exploiting Spatial Dimensions of Latent in GAN for Real-time Image Editing". Panels: "Reference Images / Canvas Draw", "Original Image / Canvas Move", "Generated Result", "Selected Reference Images", "Sample Images". -->
82 | 83 | {% endblock %} -------------------------------------------------------------------------------- /metrics/calc_inception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/calc_inception.py 3 | """ 4 | 5 | import argparse 6 | import pickle 7 | import os 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torch.utils.data import DataLoader 13 | from torchvision import transforms 14 | from torchvision.models import inception_v3, Inception3 15 | import numpy as np 16 | from tqdm import tqdm 17 | 18 | from metrics.inception import InceptionV3 19 | from training.dataset import MultiResolutionDataset 20 | 21 | 22 | class Inception3Feature(Inception3): 23 | def forward(self, x): 24 | if x.shape[2] != 299 or x.shape[3] != 299: 25 | x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True) 26 | 27 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 28 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 29 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 30 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 31 | 32 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 33 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 34 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 35 | 36 | x = self.Mixed_5b(x) # 35 x 35 x 192 37 | x = self.Mixed_5c(x) # 35 x 35 x 256 38 | x = self.Mixed_5d(x) # 35 x 35 x 288 39 | 40 | x = self.Mixed_6a(x) # 35 x 35 x 288 41 | x = self.Mixed_6b(x) # 17 x 17 x 768 42 | x = self.Mixed_6c(x) # 17 x 17 x 768 43 | x = self.Mixed_6d(x) # 17 x 17 x 768 44 | x = self.Mixed_6e(x) # 17 x 17 x 768 45 | 46 | x = self.Mixed_7a(x) # 17 x 17 x 768 47 | x = self.Mixed_7b(x) # 8 x 8 x 1280 48 | x = self.Mixed_7c(x) # 8 x 8 x 2048 49 | 50 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 51 | 52 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 53 | 54 | 55 | def load_patched_inception_v3(): 56 | # inception = inception_v3(pretrained=True) 57 | # inception_feat = Inception3Feature() 58 | # inception_feat.load_state_dict(inception.state_dict()) 59 | inception_feat = InceptionV3([3], normalize_input=False) 60 | 61 | return inception_feat 62 | 63 | 64 | @torch.no_grad() 65 | def extract_features(loader, inception, device): 66 | pbar = tqdm(loader) 67 | 68 | feature_list = [] 69 | 70 | for img in pbar: 71 | img = img.to(device) 72 | feature = inception(img)[0].view(img.shape[0], -1) 73 | feature_list.append(feature.to("cpu")) 74 | 75 | features = torch.cat(feature_list, 0) 76 | 77 | return features 78 | 79 | 80 | if __name__ == "__main__": 81 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 82 | 83 | parser = argparse.ArgumentParser( 84 | description="Calculate Inception v3 features for datasets" 85 | ) 86 | parser.add_argument("--size", type=int, default=256) 87 | parser.add_argument("--batch", default=64, type=int, help="batch size") 88 | parser.add_argument("--n_sample", type=int, default=50000) 89 | parser.add_argument("--flip", action="store_true") 90 | parser.add_argument("path", metavar="PATH", help="path to datset lmdb file") 91 | 92 | args = parser.parse_args() 93 | 94 | inception = load_patched_inception_v3() 95 | inception = nn.DataParallel(inception).eval().to(device) 96 | 97 | transform = transforms.Compose( 98 | [ 99 | transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0), 100 | transforms.ToTensor(), 101 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 102 | ] 103 
| ) 104 | 105 | dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size) 106 | loader = DataLoader(dset, batch_size=args.batch, num_workers=4) 107 | 108 | features = extract_features(loader, inception, device).numpy() 109 | 110 | features = features[: args.n_sample] 111 | 112 | print(f"extracted {features.shape[0]} features") 113 | 114 | mean = np.mean(features, 0) 115 | cov = np.cov(features, rowvar=False) 116 | 117 | name = os.path.splitext(os.path.basename(args.path))[0] 118 | 119 | with open(f"inception_{name}.pkl", "wb") as f: 120 | pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f) 121 | -------------------------------------------------------------------------------- /metrics/local_editing.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import argparse 12 | import pickle 13 | 14 | import torch 15 | from torch import nn 16 | import numpy as np 17 | from scipy import linalg 18 | from torchvision import utils, transforms 19 | from torch.nn import functional as F 20 | from training.dataset import DataSetTestLocalEditing 21 | from torch.utils import data 22 | import random 23 | import time 24 | 25 | from tqdm import tqdm 26 | import os 27 | 28 | torch.manual_seed(0) 29 | torch.cuda.manual_seed_all(0) 30 | random.seed(0) 31 | 32 | 33 | def data_sampler(dataset, shuffle): 34 | if shuffle: 35 | return data.RandomSampler(dataset) 36 | else: 37 | return data.SequentialSampler(dataset) 38 | 39 | 40 | if __name__ == "__main__": 41 | device = "cuda" 42 | 43 | parser = argparse.ArgumentParser() 44 | 45 | parser.add_argument("--batch", type=int, default=4) 46 | parser.add_argument("--num_workers", type=int, default=10) 47 | parser.add_argument("--data_dir") 48 | 49 | args = parser.parse_args() 50 | 51 | size = 256 52 | channel_multiplier = 2 53 | 54 | batch_size = args.batch 55 | 56 | if "celeba_hq" in args.data_dir: 57 | parts = os.listdir(args.data_dir) 58 | elif "afhq" in args.data_dir: 59 | parts = [""] 60 | 61 | print("parts", parts) 62 | 63 | transform = transforms.Compose( 64 | [ 65 | transforms.ToTensor(), 66 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 67 | ] 68 | ) 69 | mse_loss = nn.MSELoss(reduction="none") 70 | 71 | inMSE_all_part = [] 72 | outMSE_all_part = [] 73 | 74 | with torch.no_grad(): 75 | for part in parts: 76 | print(part) 77 | inMSE_ref_list = [] 78 | outMSE_src_list = [] 79 | 80 | dataset = DataSetTestLocalEditing( 81 | os.path.join(args.data_dir, part), transform 82 | ) 83 | assert len(dataset) == 500 / 2 84 | 85 | loader = data.DataLoader( 86 | dataset, 87 | batch_size, 88 | sampler=data_sampler(dataset, shuffle=False), 89 | num_workers=args.num_workers, 90 | pin_memory=True, 91 | ) 92 | 93 | for mask, image_reference, image_source, image_synthesized in tqdm( 94 | loader, mininterval=1 95 | ): 96 | N = len(mask) 97 | 98 | mask, image_reference, image_source, image_synthesized = ( 99 | mask.to(device), 100 | image_reference.to(device), 101 | image_source.to(device), 102 | image_synthesized.to(device), 103 | ) 104 | 105 | MSE_between_src = mse_loss(image_synthesized, image_source) 106 | 
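                # Annotation (not in the original file): MSE_between_src / MSE_between_ref are
                # masked below. inMSE_ref averages the error to the *reference* image inside the
                # edited region (mask == 1), while outMSE_src averages the error to the *source*
                # image outside it (mask == -1); these are the MSEref / MSEsrc numbers referred
                # to in metrics/README.md.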
MSE_between_ref = mse_loss(image_synthesized, image_reference) 107 | 108 | inMSE_mask_count = (mask == 1).sum() 109 | outMSE_mask_count = (mask == -1).sum() 110 | 111 | if inMSE_mask_count == 0: 112 | # print("no mask is found") 113 | continue 114 | 115 | assert inMSE_mask_count + outMSE_mask_count == size * size * N * 3 116 | 117 | dummy = torch.zeros(MSE_between_src.shape, device=device) 118 | 119 | inMSE_ref = torch.where(mask == 1, MSE_between_ref, dummy) 120 | inMSE_ref = inMSE_ref.sum() / inMSE_mask_count 121 | 122 | outMSE_src = torch.where(mask == -1, MSE_between_src, dummy) 123 | outMSE_src = outMSE_src.sum() / outMSE_mask_count 124 | inMSE_ref_list.append(inMSE_ref.mean()) 125 | outMSE_src_list.append(outMSE_src.mean()) 126 | 127 | inMSE_ref = sum(inMSE_ref_list) / len(inMSE_ref_list) 128 | outMSE_src = sum(outMSE_src_list) / len(outMSE_src_list) 129 | 130 | inMSE_all_part.append(inMSE_ref) 131 | outMSE_all_part.append(outMSE_src) 132 | 133 | print(f"{inMSE_ref:.3f}, {outMSE_src:.3f}") 134 | 135 | print(f"average in inMSE_ref, outMSE_src") 136 | print( 137 | f"{sum(inMSE_all_part) / len(inMSE_all_part):.3f}, {sum(outMSE_all_part) / len(outMSE_all_part):.3f}" 138 | ) 139 | -------------------------------------------------------------------------------- /preprocessor/pair_masks.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | from io import BytesIO 12 | import sys 13 | from os import path 14 | import lmdb 15 | from PIL import Image 16 | import argparse 17 | import numpy as np 18 | import os 19 | import pickle 20 | import itertools 21 | import torch.nn.functional as F 22 | from tqdm import tqdm 23 | import torch 24 | from torch import nn 25 | from torch.utils.data import Dataset 26 | from torch.utils import data 27 | from torchvision import transforms, utils 28 | from torchvision.utils import save_image 29 | 30 | 31 | sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) 32 | from training.dataset import GTMaskDataset 33 | 34 | 35 | def sample_data(loader): 36 | while True: 37 | for batch in loader: 38 | yield batch 39 | 40 | 41 | @torch.no_grad() 42 | def group_pair_GT(): 43 | device = "cuda" 44 | args.n_sample = 500 45 | 46 | transform = transforms.Compose( 47 | [ 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 50 | ] 51 | ) 52 | 53 | images_size = 256 # you can use other resolution for calculating if your LMDB(args.path) has different resolution. 
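    # Annotation (not in the original file): pairing strategy used here. For every pair of
    # test images, a per-part similarity between their GT segmentation masks is computed
    # (plain pixel agreement for "all", IoU of the part's boolean masks otherwise); pairs
    # are then sorted by similarity and greedily matched so that, for each part, every
    # image ends up in exactly one pair.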
54 | dataset = GTMaskDataset(args.path, transform, images_size) 55 | 56 | parts_index = { 57 | "all": None, 58 | "background": [0], 59 | "skin": [1], 60 | "eyebrow": [6, 7], 61 | "eye": [3, 4, 5], 62 | "ear": [8, 9, 15], 63 | "nose": [2], 64 | "lip": [10, 11, 12], 65 | "neck": [16, 17], 66 | "cloth": [18], 67 | "hair": [13, 14], 68 | } 69 | 70 | indexes = range(args.n_sample) 71 | 72 | similarity_dict = {} 73 | parts = parts_index.keys() 74 | 75 | for part in parts: 76 | similarity_dict[part] = {} 77 | 78 | for src, ref in tqdm( 79 | itertools.combinations(indexes, 2), 80 | total=sum(1 for _ in itertools.combinations(indexes, 2)), 81 | ): 82 | _, mask1 = dataset[src] 83 | _, mask2 = dataset[ref] 84 | mask1 = mask1.to(device) 85 | mask2 = mask2.to(device) 86 | for part in parts: 87 | if part == "all": 88 | similarity = torch.sum(mask1 == mask2).item() / (images_size ** 2) 89 | similarity_dict["all"][src, ref] = similarity 90 | else: 91 | part1 = torch.zeros( 92 | [images_size, images_size], dtype=torch.bool, device=device 93 | ) 94 | part2 = torch.zeros( 95 | [images_size, images_size], dtype=torch.bool, device=device 96 | ) 97 | 98 | for p in parts_index[part]: 99 | part1 = part1 | (mask1 == p) 100 | part2 = part2 | (mask2 == p) 101 | 102 | intersection = (part1 & part2).sum().float().item() 103 | union = (part1 | part2).sum().float().item() 104 | if union == 0: 105 | similarity_dict[part][src, ref] = 0.0 106 | else: 107 | sim = intersection / union 108 | similarity_dict[part][src, ref] = sim 109 | 110 | sorted_similarity = {} 111 | 112 | for part, similarities in similarity_dict.items(): 113 | all_indexes = set(range(args.n_sample)) 114 | sorted_similarity[part] = [] 115 | 116 | sorted_list = sorted(similarities.items(), key=(lambda x: x[1]), reverse=True) 117 | 118 | for (i1, i2), prob in sorted_list: 119 | if (i1 in all_indexes) and (i2 in all_indexes): 120 | all_indexes -= {i1, i2} 121 | sorted_similarity[part].append(((i1, i2), prob)) 122 | elif len(all_indexes) == 0: 123 | break 124 | 125 | assert len(sorted_similarity[part]) == args.n_sample // 2 126 | 127 | with open( 128 | f"{args.save_dir}/{args.dataset_name}_test_{args.mask_origin}_sorted_pair.pkl", 129 | "wb", 130 | ) as handle: 131 | pickle.dump(sorted_similarity, handle) 132 | 133 | 134 | if __name__ == "__main__": 135 | device = "cuda" 136 | 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument("--num_workers", type=int, default=10) 139 | parser.add_argument("--batch", type=int, default=16) 140 | parser.add_argument( 141 | "--save_dir", type=str, default="../data/celeba_hq/local_editing" 142 | ) 143 | 144 | args = parser.parse_args() 145 | args.dataset_name = "celeba_hq" 146 | os.makedirs(args.save_dir, exist_ok=True) 147 | args.path = f"../data/{args.dataset_name}" 148 | 149 | with torch.no_grad(): 150 | # our CelebA-HQ test dataset contains 500 images 151 | # change this value if you have the different number of GT_labels 152 | args.n_sample = 500 153 | group_pair_GT() 154 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | 11 | CMD=$1 12 | DATASET=$2 13 | 14 | if [ $CMD == "prepare-fid-calculation" ]; then 15 | # download pretrained network and stats(mean, var) of each dataset to calculate FID 16 | URL="https://docs.google.com/uc?export=download&id=1pCr4lNCON7IZcNVdskIDXhFJ3jYuge1w" 17 | NETWORK_FOLDER="./metrics" 18 | NETWORK_FILE=$NETWORK_FOLDER/pt_inception-2015-12-05-6726825d.pth 19 | wget --no-check-certificate $URL -O $NETWORK_FILE 20 | 21 | # download precalculated statistics in several datasets 22 | URL="https://docs.google.com/uc?export=download&id=1sJ7AYaY3JTVFqI6Dodnzx81KydadRpnj" 23 | mkdir -p "./metrics/fid_stats" 24 | ZIP_FILE="./metrics/fid_stats.zip" 25 | wget --no-check-certificate -r $URL -O $ZIP_FILE 26 | unzip $ZIP_FILE -d "./metrics/fid_stats" 27 | rm $ZIP_FILE 28 | 29 | elif [ $CMD == "create-lmdb-dataset" ]; then 30 | if [ $DATASET == "celeba_hq" ]; then 31 | URL="https://docs.google.com/uc?export=download&id=1R72NB79CX0MpnmWSli2SMu-Wp-M0xI-o" 32 | DATASET_FOLDER="./data/celeba_hq" 33 | ZIP_FILE=$DATASET_FOLDER/celeba_hq_raw.zip 34 | elif [ $DATASET == "afhq" ]; then 35 | URL="https://docs.google.com/uc?export=download&id=1Pf4f6Y27lQX9y9vjeSQnoOQntw_ln7il" 36 | DATASET_FOLDER="./data/afhq" 37 | ZIP_FILE=$DATASET_FOLDER/afhq_raw.zip 38 | else 39 | echo "Unknown DATASET" 40 | exit 1 41 | fi 42 | mkdir -p $DATASET_FOLDER 43 | wget --no-check-certificate -r $URL -O $ZIP_FILE 44 | unzip $ZIP_FILE -d $DATASET_FOLDER 45 | rm $ZIP_FILE 46 | 47 | # raw images to LMDB format 48 | TARGET_SIZE=256,1024 49 | for DATASET_TYPE in "train" "test" "val"; do 50 | python preprocessor/prepare_data.py --out $DATASET_FOLDER/LMDB_$DATASET_TYPE --size $TARGET_SIZE $DATASET_FOLDER/raw_images/$DATASET_TYPE 51 | done 52 | 53 | # for local editing 54 | FOLDERNAME=$DATASET_FOLDER/local_editing 55 | mkdir -p $FOLDERNAME 56 | 57 | if [ $DATASET == "celeba_hq" ]; then 58 | URL="https://docs.google.com/uc?export=download&id=1_4Cxd7uH8Zqlutu5zUNJfVpgljqh7Olf" 59 | wget -r --no-check-certificate $URL -O $FOLDERNAME/GT_labels.zip 60 | unzip $FOLDERNAME/GT_labels.zip -d $FOLDERNAME 61 | rm $FOLDERNAME/GT_labels.zip 62 | URL="https://docs.google.com/uc?export=download&id=1dy-3UxETpI58xroeAGXqidSxRCgV71SV" 63 | wget -r --no-check-certificate $URL -O $FOLDERNAME/LMDB_test_mask.zip 64 | unzip $FOLDERNAME/LMDB_test_mask.zip -d $FOLDERNAME 65 | rm $FOLDERNAME/LMDB_test_mask.zip 66 | URL="https://docs.google.com/uc?export=download&id=1rCxK0ybho9Xqexvec0g0khPPLL_cRZfx" 67 | wget --no-check-certificate $URL -O $FOLDERNAME/celeba_hq_test_GT_sorted_pair.pkl 68 | URL="https://docs.google.com/uc?export=download&id=1g4tatzpPsycq2h4B2NjejgiuHUGIvuB0" 69 | wget --no-check-certificate $URL -O $FOLDERNAME/CelebA-HQ-to-CelebA-mapping.txt 70 | fi 71 | 72 | elif [ $CMD == "download-pretrained-network-256" ]; then 73 | # 20M-image-trained models 74 | if [ $DATASET == "celeba_hq" ]; then 75 | URL="https://docs.google.com/uc?export=download&id=1Up6qELYFF1cV0HREnHpykKN2Ordr1xpp" 76 | elif [ $DATASET == "afhq" ]; then 77 | URL="https://docs.google.com/uc?export=download&id=1gKSxyBWUc53OaRwFZ6w2CuNkiLf3_X9C" 78 | elif [ $DATASET == "lsun_car" ]; then 79 | URL="https://docs.google.com/uc?export=download&id=1P77_21yBcgF5AMs8hMBT8m0eEhR5j8_Q" 80 | elif [ $DATASET == "lsun_church" ]; then 81 | URL="https://docs.google.com/uc?export=download&id=1sxdDn2dK1Ilqv9KSrXqABSUywV25pbin" 82 | 83 | # 5M-image-trained models used in our paper for comparison with other baselines and for ablation studies. 
84 | elif [ $DATASET == "celeba_hq_5M" ]; then 85 | URL="https://docs.google.com/uc?export=download&id=1-t9WkasJzn4-pljZI5619SMyvDXB_9iv" 86 | elif [ $DATASET == "afhq_5M" ]; then 87 | URL="https://docs.google.com/uc?export=download&id=1on4L_2WAl8PpH4iU1wrRCftbE_j_CrTI" 88 | else 89 | echo "Unknown DATASET" 90 | exit 1 91 | fi 92 | 93 | NETWORK_FOLDER="./expr/checkpoints" 94 | NETWORK_FILE=$NETWORK_FOLDER/${DATASET}_256_8x8.pt 95 | mkdir -p $NETWORK_FOLDER 96 | wget -r --no-check-certificate $URL -O $NETWORK_FILE 97 | 98 | elif [ $CMD == "download-pretrained-network-1024" ]; then 99 | NETWORK_FOLDER="./expr/checkpoints" 100 | mkdir -p $NETWORK_FOLDER 101 | if [ $DATASET == "ffhq_16x16" ]; then 102 | URL="https://docs.google.com/uc?export=download&id=14wPqqpWIe34hh2LHoSsOhXidrm4bS3Sg" 103 | NETWORK_FILE=$NETWORK_FOLDER/ffhq_1024_16x16.pt 104 | elif [ $DATASET == "ffhq_32x32" ]; then 105 | URL="https://docs.google.com/uc?export=download&id=1UqBHEICkL1Ml2m56eG3_u9_KvddeJlDK" 106 | NETWORK_FILE=$NETWORK_FOLDER/ffhq_1024_32x32.pt 107 | else 108 | echo "Unknown DATASET" 109 | exit 1 110 | fi 111 | wget -r --no-check-certificate $URL -O $NETWORK_FILE 112 | 113 | else 114 | echo "Unknown CMD" 115 | exit 1 116 | fi 117 | -------------------------------------------------------------------------------- /training/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py 3 | """ 4 | import os 5 | 6 | import torch 7 | from torch.autograd import Function 8 | from torch.utils.cpp_extension import load 9 | 10 | 11 | module_path = os.path.dirname(__file__) 12 | upfirdn2d_op = load( 13 | "upfirdn2d", 14 | sources=[ 15 | os.path.join(module_path, "upfirdn2d.cpp"), 16 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 17 | ], 18 | ) 19 | 20 | 21 | class UpFirDn2dBackward(Function): 22 | @staticmethod 23 | def forward( 24 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 25 | ): 26 | 27 | up_x, up_y = up 28 | down_x, down_y = down 29 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 30 | 31 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 32 | 33 | grad_input = upfirdn2d_op.upfirdn2d( 34 | grad_output, 35 | grad_kernel, 36 | down_x, 37 | down_y, 38 | up_x, 39 | up_y, 40 | g_pad_x0, 41 | g_pad_x1, 42 | g_pad_y0, 43 | g_pad_y1, 44 | ) 45 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 46 | 47 | ctx.save_for_backward(kernel) 48 | 49 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 50 | 51 | ctx.up_x = up_x 52 | ctx.up_y = up_y 53 | ctx.down_x = down_x 54 | ctx.down_y = down_y 55 | ctx.pad_x0 = pad_x0 56 | ctx.pad_x1 = pad_x1 57 | ctx.pad_y0 = pad_y0 58 | ctx.pad_y1 = pad_y1 59 | ctx.in_size = in_size 60 | ctx.out_size = out_size 61 | 62 | return grad_input 63 | 64 | @staticmethod 65 | def backward(ctx, gradgrad_input): 66 | (kernel,) = ctx.saved_tensors 67 | 68 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 69 | 70 | gradgrad_out = upfirdn2d_op.upfirdn2d( 71 | gradgrad_input, 72 | kernel, 73 | ctx.up_x, 74 | ctx.up_y, 75 | ctx.down_x, 76 | ctx.down_y, 77 | ctx.pad_x0, 78 | ctx.pad_x1, 79 | ctx.pad_y0, 80 | ctx.pad_y1, 81 | ) 82 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 83 | gradgrad_out = gradgrad_out.view( 84 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 85 | ) 86 | 87 | return gradgrad_out, 
None, None, None, None, None, None, None, None 88 | 89 | 90 | class UpFirDn2d(Function): 91 | @staticmethod 92 | def forward(ctx, input, kernel, up, down, pad): 93 | up_x, up_y = up 94 | down_x, down_y = down 95 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 96 | 97 | kernel_h, kernel_w = kernel.shape 98 | batch, channel, in_h, in_w = input.shape 99 | ctx.in_size = input.shape 100 | 101 | input = input.reshape(-1, in_h, in_w, 1) 102 | 103 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 104 | 105 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 106 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 107 | ctx.out_size = (out_h, out_w) 108 | 109 | ctx.up = (up_x, up_y) 110 | ctx.down = (down_x, down_y) 111 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 112 | 113 | g_pad_x0 = kernel_w - pad_x0 - 1 114 | g_pad_y0 = kernel_h - pad_y0 - 1 115 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 116 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 117 | 118 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 119 | 120 | out = upfirdn2d_op.upfirdn2d( 121 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 122 | ) 123 | # out = out.view(major, out_h, out_w, minor) 124 | out = out.view(-1, channel, out_h, out_w) 125 | 126 | return out 127 | 128 | @staticmethod 129 | def backward(ctx, grad_output): 130 | kernel, grad_kernel = ctx.saved_tensors 131 | 132 | grad_input = UpFirDn2dBackward.apply( 133 | grad_output, 134 | kernel, 135 | grad_kernel, 136 | ctx.up, 137 | ctx.down, 138 | ctx.pad, 139 | ctx.g_pad, 140 | ctx.in_size, 141 | ctx.out_size, 142 | ) 143 | 144 | return grad_input, None, None, None, None 145 | 146 | 147 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 148 | out = UpFirDn2d.apply( 149 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 150 | ) 151 | 152 | return out 153 | 154 | 155 | def upfirdn2d_native( 156 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 157 | ): 158 | _, in_h, in_w, minor = input.shape 159 | kernel_h, kernel_w = kernel.shape 160 | 161 | out = input.view(-1, in_h, 1, in_w, 1, minor) 162 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 163 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 164 | 165 | out = F.pad( 166 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 167 | ) 168 | out = out[ 169 | :, 170 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 171 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 172 | :, 173 | ] 174 | 175 | out = out.permute(0, 3, 1, 2) 176 | out = out.reshape( 177 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 178 | ) 179 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 180 | out = F.conv2d(out, w) 181 | out = out.reshape( 182 | -1, 183 | minor, 184 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 185 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 186 | ) 187 | out = out.permute(0, 2, 3, 1) 188 | 189 | return out[:, ::down_y, ::down_x, :] 190 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | """ 4 | StyleMapGAN 5 | Copyright (c) 2021-present NAVER Corp. 6 | 7 | This work is licensed under the Creative Commons Attribution-NonCommercial 8 | 4.0 International License. 
To view a copy of this license, visit 9 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 10 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 11 | """ 12 | 13 | import lmdb 14 | from PIL import Image 15 | from torch.utils.data import Dataset 16 | from torch.utils import data 17 | import numpy as np 18 | import random 19 | import re, os 20 | from torchvision import transforms 21 | import torch 22 | 23 | 24 | class MultiResolutionDataset(Dataset): 25 | def __init__(self, path, transform, resolution=256): 26 | self.env = lmdb.open( 27 | path, 28 | max_readers=32, 29 | readonly=True, 30 | lock=False, 31 | readahead=False, 32 | meminit=False, 33 | ) 34 | 35 | if not self.env: 36 | raise IOError("Cannot open lmdb dataset", path) 37 | 38 | with self.env.begin(write=False) as txn: 39 | self.length = int(txn.get("length".encode("utf-8")).decode("utf-8")) 40 | 41 | self.resolution = resolution 42 | self.transform = transform 43 | 44 | def __len__(self): 45 | return self.length 46 | 47 | def __getitem__(self, index): 48 | with self.env.begin(write=False) as txn: 49 | key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8") 50 | img_bytes = txn.get(key) 51 | 52 | buffer = BytesIO(img_bytes) 53 | img = Image.open(buffer) 54 | img = self.transform(img) 55 | 56 | return img 57 | 58 | 59 | class GTMaskDataset(Dataset): 60 | def __init__(self, dataset_folder, transform, resolution=256): 61 | 62 | self.env = lmdb.open( 63 | f"{dataset_folder}/LMDB_test", 64 | max_readers=32, 65 | readonly=True, 66 | lock=False, 67 | readahead=False, 68 | meminit=False, 69 | ) 70 | 71 | if not self.env: 72 | raise IOError("Cannot open lmdb dataset", f"{dataset_folder}/LMDB_test") 73 | 74 | with self.env.begin(write=False) as txn: 75 | self.length = int(txn.get("length".encode("utf-8")).decode("utf-8")) 76 | 77 | self.resolution = resolution 78 | self.transform = transform 79 | 80 | # convert filename to celeba_hq index 81 | CelebA_HQ_to_CelebA = ( 82 | f"{dataset_folder}/local_editing/CelebA-HQ-to-CelebA-mapping.txt" 83 | ) 84 | CelebA_to_CelebA_HQ_dict = {} 85 | 86 | original_test_path = f"{dataset_folder}/raw_images/test/images" 87 | mask_label_path = f"{dataset_folder}/local_editing/GT_labels" 88 | 89 | with open(CelebA_HQ_to_CelebA, "r") as fp: 90 | read_line = fp.readline() 91 | attrs = re.sub(" +", " ", read_line).strip().split(" ") 92 | while True: 93 | read_line = fp.readline() 94 | 95 | if not read_line: 96 | break 97 | 98 | idx, orig_idx, orig_file = ( 99 | re.sub(" +", " ", read_line).strip().split(" ") 100 | ) 101 | 102 | CelebA_to_CelebA_HQ_dict[orig_file] = idx 103 | 104 | self.mask = [] 105 | 106 | for filename in os.listdir(original_test_path): 107 | CelebA_HQ_filename = CelebA_to_CelebA_HQ_dict[filename] 108 | CelebA_HQ_filename = CelebA_HQ_filename + ".png" 109 | self.mask.append(os.path.join(mask_label_path, CelebA_HQ_filename)) 110 | 111 | def __len__(self): 112 | return self.length 113 | 114 | def __getitem__(self, index): 115 | with self.env.begin(write=False) as txn: 116 | key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8") 117 | img_bytes = txn.get(key) 118 | 119 | buffer = BytesIO(img_bytes) 120 | img = Image.open(buffer) 121 | img = self.transform(img) 122 | 123 | mask = Image.open(self.mask[index]) 124 | 125 | mask = mask.resize((self.resolution, self.resolution), Image.NEAREST) 126 | mask = transforms.ToTensor()(mask) 127 | 128 | mask = mask.squeeze() 129 | mask *= 255 130 | mask = mask.long() 131 | 132 | assert mask.shape == (self.resolution, 
self.resolution) 133 | return img, mask 134 | 135 | 136 | class DataSetFromDir(Dataset): 137 | def __init__(self, main_dir, transform): 138 | self.main_dir = main_dir 139 | self.transform = transform 140 | all_imgs = os.listdir(main_dir) 141 | self.total_imgs = [] 142 | 143 | for img in all_imgs: 144 | if ".png" in img: 145 | self.total_imgs.append(img) 146 | 147 | def __len__(self): 148 | return len(self.total_imgs) 149 | 150 | def __getitem__(self, idx): 151 | img_loc = os.path.join(self.main_dir, self.total_imgs[idx]) 152 | image = Image.open(img_loc).convert("RGB") 153 | tensor_image = self.transform(image) 154 | return tensor_image 155 | 156 | 157 | class DataSetTestLocalEditing(Dataset): 158 | def __init__(self, main_dir, transform): 159 | self.main_dir = main_dir 160 | self.transform = transform 161 | 162 | all_imgs = os.listdir(os.path.join(main_dir, "mask")) 163 | self.total_imgs = [] 164 | 165 | for img in all_imgs: 166 | if ".png" in img: 167 | self.total_imgs.append(img) 168 | 169 | def __len__(self): 170 | return len(self.total_imgs) 171 | 172 | def __getitem__(self, idx): 173 | image_mask = self.transform( 174 | Image.open( 175 | os.path.join(self.main_dir, "mask", self.total_imgs[idx]) 176 | ).convert("RGB") 177 | ) 178 | image_reference = self.transform( 179 | Image.open( 180 | os.path.join(self.main_dir, "reference_image", self.total_imgs[idx]) 181 | ).convert("RGB") 182 | ) 183 | # image_reference_recon = self.transform(Image.open(os.path.join(self.main_dir, 'reference_image', self.total_imgs[idx].replace('.png', '_recon_img.png'))).convert("RGB")) 184 | 185 | image_source = self.transform( 186 | Image.open( 187 | os.path.join(self.main_dir, "source_image", self.total_imgs[idx]) 188 | ).convert("RGB") 189 | ) 190 | # image_source_recon = self.transform(Image.open(os.path.join(self.main_dir, 'source_image', self.total_imgs[idx].replace('.png', '_recon_img.png'))).convert("RGB")) 191 | 192 | image_synthesized = self.transform( 193 | Image.open( 194 | os.path.join(self.main_dir, "synthesized_image", self.total_imgs[idx]) 195 | ).convert("RGB") 196 | ) 197 | 198 | return image_mask, image_reference, image_source, image_synthesized 199 | -------------------------------------------------------------------------------- /demo/static/components/js/image-picker.min.js: -------------------------------------------------------------------------------- 1 | // Image Picker 2 | // by Rodrigo Vera 3 | // 4 | // Version 0.3.1 5 | // Full source at https://github.com/rvera/image-picker 6 | // MIT License, https://github.com/rvera/image-picker/blob/master/LICENSE 7 | // Image Picker 8 | // by Rodrigo Vera 9 | // 10 | // Version 0.3.0 11 | // Full source at https://github.com/rvera/image-picker 12 | // MIT License, https://github.com/rvera/image-picker/blob/master/LICENSE 13 | (function(){var ImagePicker,ImagePickerOption,both_array_are_equal,sanitized_options,bind=function(fn,me){return function(){return fn.apply(me,arguments)}},indexOf=[].indexOf||function(item){for(var i=0,l=this.length;i");this.picker_options=[];this.recursively_parse_option_groups(this.select,this.picker);return this.picker};ImagePicker.prototype.recursively_parse_option_groups=function(scoped_dom,target_container){var container,j,k,len,len1,option,option_group,ref,ref1,results;ref=scoped_dom.children("optgroup");for(j=0,len=ref.length;j");container.append(jQuery("
  • "+option_group.attr("label")+"
  • "));target_container.append(jQuery("
  • ").append(container));this.recursively_parse_option_groups(option_group,container)}ref1=function(){var l,len1,ref1,results1;ref1=scoped_dom.children("option");results1=[];for(l=0,len1=ref1.length;l0};ImagePicker.prototype.selected_values=function(){if(this.multiple){return this.select.val()||[]}else{return[this.select.val()]}};ImagePicker.prototype.toggle=function(imagepicker_option,original_event){var new_values,old_values,selected_value;old_values=this.selected_values();selected_value=imagepicker_option.value().toString();if(this.multiple){if(indexOf.call(this.selected_values(),selected_value)>=0){new_values=this.selected_values();new_values.splice(jQuery.inArray(selected_value,old_values),1);this.select.val([]);this.select.val(new_values)}else{if(this.opts.limit!=null&&this.selected_values().length>=this.opts.limit){if(this.opts.limit_reached!=null){this.opts.limit_reached.call(this.select)}}else{this.select.val(this.selected_values().concat(selected_value))}}}else{if(this.has_implicit_blanks()&&imagepicker_option.is_selected()){this.select.val("")}else{this.select.val(selected_value)}}if(!both_array_are_equal(old_values,this.selected_values())){this.select.change();if(this.opts.changed!=null){return this.opts.changed.call(this.select,old_values,this.selected_values(),original_event)}}};return ImagePicker}();ImagePickerOption=function(){function ImagePickerOption(option_element,picker,opts1){this.picker=picker;this.opts=opts1!=null?opts1:{};this.clicked=bind(this.clicked,this);this.option=jQuery(option_element);this.create_node()}ImagePickerOption.prototype.destroy=function(){return this.node.find(".thumbnail").off("click",this.clicked)};ImagePickerOption.prototype.has_image=function(){return this.option.data("img-src")!=null};ImagePickerOption.prototype.is_blank=function(){return!(this.value()!=null&&this.value()!=="")};ImagePickerOption.prototype.is_selected=function(){var select_value;select_value=this.picker.select.val();if(this.picker.multiple){return jQuery.inArray(this.value(),select_value)>=0}else{return this.value()===select_value}};ImagePickerOption.prototype.mark_as_selected=function(){return this.node.find(".thumbnail").addClass("selected")};ImagePickerOption.prototype.unmark_as_selected=function(){return this.node.find(".thumbnail").removeClass("selected")};ImagePickerOption.prototype.value=function(){return this.option.val()};ImagePickerOption.prototype.label=function(){if(this.option.data("img-label")){return this.option.data("img-label")}else{return this.option.text()}};ImagePickerOption.prototype.clicked=function(event){this.picker.toggle(this,event);if(this.opts.clicked!=null){this.opts.clicked.call(this.picker.select,this,event)}if(this.opts.selected!=null&&this.is_selected()){return this.opts.selected.call(this.picker.select,this,event)}};ImagePickerOption.prototype.create_node=function(){var image,imgAlt,imgClass,thumbnail;this.node=jQuery("
  • ");if(this.option.data("font_awesome")){image=jQuery("");image.attr("class","fa-fw "+this.option.data("img-src"))}else{image=jQuery("");image.attr("src",this.option.data("img-src"))}thumbnail=jQuery("
    ");imgClass=this.option.data("img-class");if(imgClass){this.node.addClass(imgClass);image.addClass(imgClass);thumbnail.addClass(imgClass)}imgAlt=this.option.data("img-alt");if(imgAlt){image.attr("alt",imgAlt)}thumbnail.on("click",this.clicked);thumbnail.append(image);if(this.opts.show_label){thumbnail.append(jQuery("
<p/>
    ").html(this.label()))}this.node.append(thumbnail);return this.node};return ImagePickerOption}()}).call(this); -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import flask 12 | from flask import Flask, render_template, request, redirect, url_for 13 | import numpy as np 14 | import base64 15 | import os 16 | import secrets 17 | import argparse 18 | from PIL import Image 19 | 20 | ###### 21 | import torch 22 | from torch import nn 23 | from training.model import Generator, Encoder 24 | import torch.nn.functional as F 25 | import torchvision.transforms.functional as TF 26 | from torchvision import transforms 27 | import io 28 | 29 | app = Flask( 30 | __name__, 31 | template_folder="demo/templates", 32 | static_url_path="/demo/static", 33 | static_folder="demo/static", 34 | ) 35 | 36 | app.config["MAX_CONTENT_LENGTH"] = 10000000 # allow 10 MB post 37 | 38 | # for 1 gpu only. 39 | class Model(nn.Module): 40 | def __init__(self): 41 | super(Model, self).__init__() 42 | self.g_ema = Generator( 43 | train_args.size, 44 | train_args.mapping_layer_num, 45 | train_args.latent_channel_size, 46 | train_args.latent_spatial_size, 47 | lr_mul=train_args.lr_mul, 48 | channel_multiplier=train_args.channel_multiplier, 49 | normalize_mode=train_args.normalize_mode, 50 | small_generator=train_args.small_generator, 51 | ) 52 | self.e_ema = Encoder( 53 | train_args.size, 54 | train_args.latent_channel_size, 55 | train_args.latent_spatial_size, 56 | channel_multiplier=train_args.channel_multiplier, 57 | ) 58 | self.device = device 59 | 60 | def forward(self, original_image, references, masks, shift_values): 61 | 62 | combined = torch.cat([original_image, references], dim=0) 63 | 64 | ws = self.e_ema(combined) 65 | original_stylemap, reference_stylemaps = torch.split( 66 | ws, [1, len(ws) - 1], dim=0 67 | ) 68 | 69 | mixed = self.g_ema( 70 | [original_stylemap, reference_stylemaps], 71 | input_is_stylecode=True, 72 | mix_space="demo", 73 | mask=[masks, shift_values, args.interpolation_step], 74 | )[0] 75 | 76 | return mixed 77 | 78 | 79 | @app.route("/") 80 | def index(): 81 | image_paths = [] 82 | return render_template( 83 | "index.html", 84 | canvas_size=train_args.size, 85 | base_path=base_path, 86 | image_paths=list(os.listdir(base_path)), 87 | ) 88 | 89 | 90 | # "#010FFF" -> (1, 15, 255) 91 | def hex2val(hex): 92 | if len(hex) != 7: 93 | raise Exception("invalid hex") 94 | val = int(hex[1:], 16) 95 | return np.array([val >> 16, (val >> 8) & 255, val & 255]) 96 | 97 | 98 | @torch.no_grad() 99 | def my_morphed_images( 100 | original, references, masks, shift_values, interpolation=8, save_dir=None 101 | ): 102 | original_image = Image.open(base_path + original) 103 | reference_images = [] 104 | 105 | for ref in references: 106 | reference_images.append( 107 | TF.to_tensor( 108 | Image.open(base_path + ref).resize((train_args.size, train_args.size)) 109 | ) 110 | ) 111 | 112 | original_image = TF.to_tensor(original_image).unsqueeze(0) 113 | original_image = F.interpolate( 114 | original_image, size=(train_args.size, 
train_args.size) 115 | ) 116 | original_image = (original_image - 0.5) * 2 117 | 118 | reference_images = torch.stack(reference_images) 119 | reference_images = F.interpolate( 120 | reference_images, size=(train_args.size, train_args.size) 121 | ) 122 | reference_images = (reference_images - 0.5) * 2 123 | 124 | masks = masks[: len(references)] 125 | masks = torch.from_numpy(np.stack(masks)) 126 | 127 | original_image, reference_images, masks = ( 128 | original_image.to(device), 129 | reference_images.to(device), 130 | masks.to(device), 131 | ) 132 | 133 | mixed = model(original_image, reference_images, masks, shift_values).cpu() 134 | mixed = np.asarray( 135 | np.clip(mixed * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8 136 | ).transpose( 137 | (0, 2, 3, 1) 138 | ) # 0~255 139 | 140 | return mixed 141 | 142 | 143 | @app.route("/post", methods=["POST"]) 144 | def post(): 145 | if request.method == "POST": 146 | user_id = request.json["id"] 147 | original = request.json["original"] 148 | references = request.json["references"] 149 | colors = [hex2val(hex) for hex in request.json["colors"]] 150 | data_reference_bin = [] 151 | shift_values = request.json["shift_original"] 152 | save_dir = f"demo/static/generated/{user_id}" 153 | masks = [] 154 | 155 | if not os.path.exists(save_dir): 156 | os.makedirs(save_dir, exist_ok=True) 157 | 158 | for i, d_ref in enumerate(request.json["data_reference"]): 159 | data_reference_bin.append(base64.b64decode(d_ref)) 160 | 161 | with open(f"{save_dir}/classmap_reference_{i}.png", "wb") as f: 162 | f.write(bytearray(data_reference_bin[i])) 163 | 164 | for i in range(len(colors)): 165 | class_map = Image.open(io.BytesIO(data_reference_bin[i])) 166 | class_map = np.array(class_map)[:, :, :3] 167 | mask = np.array( 168 | (np.isclose(class_map, colors[i], atol=2.0)).all(axis=2), dtype=np.uint8 169 | ) # * 255 170 | mask = np.asarray(mask, dtype=np.float32).reshape( 171 | (1, mask.shape[0], mask.shape[1]) 172 | ) 173 | masks.append(mask) 174 | 175 | generated_images = my_morphed_images( 176 | original, 177 | references, 178 | masks, 179 | shift_values, 180 | interpolation=args.interpolation_step, 181 | save_dir=save_dir, 182 | ) 183 | paths = [] 184 | 185 | for i in range(args.interpolation_step): 186 | path = f"{save_dir}/{i}.png" 187 | Image.fromarray(generated_images[i]).save(path) 188 | paths.append(path + "?{}".format(secrets.token_urlsafe(16))) 189 | 190 | return flask.jsonify(result=paths) 191 | else: 192 | return redirect(url_for("index")) 193 | 194 | 195 | if __name__ == "__main__": 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument( 198 | "--dataset", 199 | type=str, 200 | default="celeba_hq", 201 | choices=["celeba_hq", "afhq", "lsun/church_outdoor", "lsun/car"], 202 | ) 203 | parser.add_argument("--interpolation_step", type=int, default=16) 204 | parser.add_argument("--ckpt", type=str, required=True) 205 | parser.add_argument( 206 | "--MAX_CONTENT_LENGTH", type=int, default=10000000 207 | ) # allow maximum 10 MB POST 208 | args = parser.parse_args() 209 | 210 | device = "cuda" 211 | base_path = f"demo/static/components/img/{args.dataset}/" 212 | ckpt = torch.load(args.ckpt) 213 | 214 | train_args = ckpt["train_args"] 215 | print("train_args: ", train_args) 216 | 217 | model = Model().to(device) 218 | model.g_ema.load_state_dict(ckpt["g_ema"]) 219 | model.e_ema.load_state_dict(ckpt["e_ema"]) 220 | model.eval() 221 | 222 | app.debug = True 223 | app.run(host="127.0.0.1", port=6006) 224 | 
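The `/post` route in `demo.py` above expects the JSON payload that the browser UI assembles: filenames taken from `demo/static/components/img/<dataset>/`, one base64-encoded class-map PNG per reference image, and one hex color per reference marking the painted region in that class map; it responds with `{"result": [...]}`, a list of paths to the interpolated frames. The snippet below is a minimal sketch of calling that endpoint outside the browser, assuming the server is already running on `127.0.0.1:6006`; the `requests` dependency, the example filenames, and the flat `[0, 0]` value for `shift_original` are illustrative assumptions rather than part of the repository.

```python
# Minimal sketch of a client for the /post endpoint defined in demo.py above.
# Assumptions (not from the repository): the `requests` package is installed,
# the demo server is running locally, the example filenames exist under the
# dataset's image folder, and shift_original is a simple [dx, dy] value.
import base64
import requests

SERVER = "http://127.0.0.1:6006"

# A class map is an RGB PNG, drawn in the browser UI, where the region to copy
# from each reference image is painted in the matching entry of "colors".
with open("classmap_reference_0.png", "rb") as f:
    classmap_b64 = base64.b64encode(f.read()).decode("ascii")

payload = {
    "id": "local-test",            # used to name demo/static/generated/<id>/
    "original": "source.png",      # filename of the source image (hypothetical)
    "references": ["reference.png"],  # reference image filenames (hypothetical)
    "colors": ["#010FFF"],         # one hex color per reference mask
    "data_reference": [classmap_b64],  # one base64 class-map PNG per reference
    "shift_original": [0, 0],      # mask shift; exact structure assumed here
}

resp = requests.post(f"{SERVER}/post", json=payload)
resp.raise_for_status()
# The server answers with {"result": [...]} -- paths of the interpolated frames.
for path in resp.json()["result"]:
    print(path)
```

The number of returned paths equals `--interpolation_step` (16 by default), since the server saves one frame per interpolation step under `demo/static/generated/<id>/`.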
-------------------------------------------------------------------------------- /training/lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/pretrained_networks.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py 4 | """ 5 | from collections import namedtuple 6 | import torch 7 | from torchvision import models as tv 8 | from IPython import embed 9 | 10 | 11 | class squeezenet(torch.nn.Module): 12 | def __init__(self, requires_grad=False, pretrained=True): 13 | super(squeezenet, self).__init__() 14 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 15 | self.slice1 = torch.nn.Sequential() 16 | self.slice2 = torch.nn.Sequential() 17 | self.slice3 = torch.nn.Sequential() 18 | self.slice4 = torch.nn.Sequential() 19 | self.slice5 = torch.nn.Sequential() 20 | self.slice6 = torch.nn.Sequential() 21 | self.slice7 = torch.nn.Sequential() 22 | self.N_slices = 7 23 | for x in range(2): 24 | self.slice1.add_module(str(x), pretrained_features[x]) 25 | for x in range(2, 5): 26 | self.slice2.add_module(str(x), pretrained_features[x]) 27 | for x in range(5, 8): 28 | self.slice3.add_module(str(x), pretrained_features[x]) 29 | for x in range(8, 10): 30 | self.slice4.add_module(str(x), pretrained_features[x]) 31 | for x in range(10, 11): 32 | self.slice5.add_module(str(x), pretrained_features[x]) 33 | for x in range(11, 12): 34 | self.slice6.add_module(str(x), pretrained_features[x]) 35 | for x in range(12, 13): 36 | self.slice7.add_module(str(x), pretrained_features[x]) 37 | if not requires_grad: 38 | for param in self.parameters(): 39 | param.requires_grad = False 40 | 41 | def forward(self, X): 42 | h = self.slice1(X) 43 | h_relu1 = h 44 | h = self.slice2(h) 45 | h_relu2 = h 46 | h = self.slice3(h) 47 | h_relu3 = h 48 | h = self.slice4(h) 49 | h_relu4 = h 50 | h = self.slice5(h) 51 | h_relu5 = h 52 | h = self.slice6(h) 53 | h_relu6 = h 54 | h = self.slice7(h) 55 | h_relu7 = h 56 | vgg_outputs = namedtuple( 57 | "SqueezeOutputs", 58 | ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"], 59 | ) 60 | out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) 61 | 62 | return out 63 | 64 | 65 | class alexnet(torch.nn.Module): 66 | def __init__(self, requires_grad=False, pretrained=True): 67 | super(alexnet, self).__init__() 68 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 69 | self.slice1 = torch.nn.Sequential() 70 | self.slice2 = torch.nn.Sequential() 71 | self.slice3 = torch.nn.Sequential() 72 | self.slice4 = torch.nn.Sequential() 73 | self.slice5 = torch.nn.Sequential() 74 | self.N_slices = 5 75 | for x in range(2): 76 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 77 | for x in range(2, 5): 78 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 79 | for x in range(5, 8): 80 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 81 | for x in range(8, 10): 82 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 83 | for x in range(10, 12): 84 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 85 | if not requires_grad: 86 | for param in self.parameters(): 87 | param.requires_grad = False 88 | 89 | def forward(self, X): 90 | h = self.slice1(X) 91 | h_relu1 = h 92 | h = self.slice2(h) 93 | h_relu2 = h 94 | h = self.slice3(h) 95 | h_relu3 = h 96 | h = 
self.slice4(h) 97 | h_relu4 = h 98 | h = self.slice5(h) 99 | h_relu5 = h 100 | alexnet_outputs = namedtuple( 101 | "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"] 102 | ) 103 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 104 | 105 | return out 106 | 107 | 108 | class vgg16(torch.nn.Module): 109 | def __init__(self, requires_grad=False, pretrained=True): 110 | super(vgg16, self).__init__() 111 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 112 | self.slice1 = torch.nn.Sequential() 113 | self.slice2 = torch.nn.Sequential() 114 | self.slice3 = torch.nn.Sequential() 115 | self.slice4 = torch.nn.Sequential() 116 | self.slice5 = torch.nn.Sequential() 117 | self.N_slices = 5 118 | for x in range(4): 119 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 120 | for x in range(4, 9): 121 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 122 | for x in range(9, 16): 123 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 124 | for x in range(16, 23): 125 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 126 | for x in range(23, 30): 127 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 128 | if not requires_grad: 129 | for param in self.parameters(): 130 | param.requires_grad = False 131 | 132 | def forward(self, X): 133 | h = self.slice1(X) 134 | h_relu1_2 = h 135 | h = self.slice2(h) 136 | h_relu2_2 = h 137 | h = self.slice3(h) 138 | h_relu3_3 = h 139 | h = self.slice4(h) 140 | h_relu4_3 = h 141 | h = self.slice5(h) 142 | h_relu5_3 = h 143 | vgg_outputs = namedtuple( 144 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 145 | ) 146 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 147 | 148 | return out 149 | 150 | 151 | class resnet(torch.nn.Module): 152 | def __init__(self, requires_grad=False, pretrained=True, num=18): 153 | super(resnet, self).__init__() 154 | if num == 18: 155 | self.net = tv.resnet18(pretrained=pretrained) 156 | elif num == 34: 157 | self.net = tv.resnet34(pretrained=pretrained) 158 | elif num == 50: 159 | self.net = tv.resnet50(pretrained=pretrained) 160 | elif num == 101: 161 | self.net = tv.resnet101(pretrained=pretrained) 162 | elif num == 152: 163 | self.net = tv.resnet152(pretrained=pretrained) 164 | self.N_slices = 5 165 | 166 | self.conv1 = self.net.conv1 167 | self.bn1 = self.net.bn1 168 | self.relu = self.net.relu 169 | self.maxpool = self.net.maxpool 170 | self.layer1 = self.net.layer1 171 | self.layer2 = self.net.layer2 172 | self.layer3 = self.net.layer3 173 | self.layer4 = self.net.layer4 174 | 175 | def forward(self, X): 176 | h = self.conv1(X) 177 | h = self.bn1(h) 178 | h = self.relu(h) 179 | h_relu1 = h 180 | h = self.maxpool(h) 181 | h = self.layer1(h) 182 | h_conv2 = h 183 | h = self.layer2(h) 184 | h_conv3 = h 185 | h = self.layer3(h) 186 | h_conv4 = h 187 | h = self.layer4(h) 188 | h_conv5 = h 189 | 190 | outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]) 191 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 192 | 193 | return out 194 | -------------------------------------------------------------------------------- /training/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/__init__.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/__init__.py 4 | """ 5 | from __future__ import 
absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import numpy as np 10 | from skimage.measure import compare_ssim 11 | import torch 12 | from torch.autograd import Variable 13 | 14 | from . import dist_model 15 | 16 | 17 | class exportPerceptualLoss(torch.nn.Module): 18 | def __init__( 19 | self, model="net-lin", net="alex", colorspace="rgb", spatial=False, use_gpu=True 20 | ): # VGG using our perceptually-learned weights (LPIPS metric) 21 | super(exportPerceptualLoss, self).__init__() 22 | print("Setting up Perceptual loss...") 23 | self.use_gpu = use_gpu 24 | self.spatial = spatial 25 | self.model = dist_model.exportModel() 26 | self.model.initialize( 27 | model=model, 28 | net=net, 29 | use_gpu=use_gpu, 30 | colorspace=colorspace, 31 | spatial=self.spatial, 32 | ) 33 | print("...[%s] initialized" % self.model.name()) 34 | print("...Done") 35 | 36 | def forward(self, pred, target): 37 | return self.model.forward(target, pred) 38 | 39 | 40 | class PerceptualLoss(torch.nn.Module): 41 | def __init__( 42 | self, 43 | model="net-lin", 44 | net="alex", 45 | colorspace="rgb", 46 | spatial=False, 47 | use_gpu=True, 48 | gpu_ids=[0], 49 | ): # VGG using our perceptually-learned weights (LPIPS metric) 50 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 51 | super(PerceptualLoss, self).__init__() 52 | print("Setting up Perceptual loss...") 53 | self.use_gpu = use_gpu 54 | self.spatial = spatial 55 | self.gpu_ids = gpu_ids 56 | self.model = dist_model.DistModel() 57 | self.model.initialize( 58 | model=model, 59 | net=net, 60 | use_gpu=use_gpu, 61 | colorspace=colorspace, 62 | spatial=self.spatial, 63 | gpu_ids=gpu_ids, 64 | ) 65 | print("...[%s] initialized" % self.model.name()) 66 | print("...Done") 67 | 68 | def forward(self, pred, target, normalize=False): 69 | """ 70 | Pred and target are Variables. 
71 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 72 | If normalize is False, assumes the images are already between [-1,+1] 73 | 74 | Inputs pred and target are Nx3xHxW 75 | Output pytorch Variable N long 76 | """ 77 | 78 | if normalize: 79 | target = 2 * target - 1 80 | pred = 2 * pred - 1 81 | 82 | return self.model.forward(target, pred) 83 | 84 | 85 | def normalize_tensor(in_feat, eps=1e-10): 86 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) 87 | return in_feat / (norm_factor + eps) 88 | 89 | 90 | def l2(p0, p1, range=255.0): 91 | return 0.5 * np.mean((p0 / range - p1 / range) ** 2) 92 | 93 | 94 | def psnr(p0, p1, peak=255.0): 95 | return 10 * np.log10(peak ** 2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2)) 96 | 97 | 98 | def dssim(p0, p1, range=255.0): 99 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.0 100 | 101 | 102 | def rgb2lab(in_img, mean_cent=False): 103 | from skimage import color 104 | 105 | img_lab = color.rgb2lab(in_img) 106 | if mean_cent: 107 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 108 | return img_lab 109 | 110 | 111 | def tensor2np(tensor_obj): 112 | # change dimension of a tensor object into a numpy array 113 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 114 | 115 | 116 | def np2tensor(np_obj): 117 | # change dimenion of np array into tensor array 118 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 119 | 120 | 121 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 122 | # image tensor to lab tensor 123 | from skimage import color 124 | 125 | img = tensor2im(image_tensor) 126 | img_lab = color.rgb2lab(img) 127 | if mc_only: 128 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 129 | if to_norm and not mc_only: 130 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 131 | img_lab = img_lab / 100.0 132 | 133 | return np2tensor(img_lab) 134 | 135 | 136 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 137 | from skimage import color 138 | import warnings 139 | 140 | warnings.filterwarnings("ignore") 141 | 142 | lab = tensor2np(lab_tensor) * 100.0 143 | lab[:, :, 0] = lab[:, :, 0] + 50 144 | 145 | rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1) 146 | if return_inbnd: 147 | # convert back to lab, see if we match 148 | lab_back = color.rgb2lab(rgb_back.astype("uint8")) 149 | mask = 1.0 * np.isclose(lab_back, lab, atol=2.0) 150 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 151 | return (im2tensor(rgb_back), mask) 152 | else: 153 | return im2tensor(rgb_back) 154 | 155 | 156 | def rgb2lab(input): 157 | from skimage import color 158 | 159 | return color.rgb2lab(input / 255.0) 160 | 161 | 162 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 163 | image_numpy = image_tensor[0].cpu().float().numpy() 164 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 165 | return image_numpy.astype(imtype) 166 | 167 | 168 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 169 | return torch.Tensor( 170 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 171 | ) 172 | 173 | 174 | def tensor2vec(vector_tensor): 175 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 176 | 177 | 178 | def voc_ap(rec, prec, use_07_metric=False): 179 | """ap = voc_ap(rec, prec, [use_07_metric]) 180 | Compute VOC AP given precision and recall. 181 | If use_07_metric is true, uses the 182 | VOC 07 11 point method (default:False). 
183 | """ 184 | if use_07_metric: 185 | # 11 point metric 186 | ap = 0.0 187 | for t in np.arange(0.0, 1.1, 0.1): 188 | if np.sum(rec >= t) == 0: 189 | p = 0 190 | else: 191 | p = np.max(prec[rec >= t]) 192 | ap = ap + p / 11.0 193 | else: 194 | # correct AP calculation 195 | # first append sentinel values at the end 196 | mrec = np.concatenate(([0.0], rec, [1.0])) 197 | mpre = np.concatenate(([0.0], prec, [0.0])) 198 | 199 | # compute the precision envelope 200 | for i in range(mpre.size - 1, 0, -1): 201 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 202 | 203 | # to calculate area under PR curve, look for points 204 | # where X axis (recall) changes value 205 | i = np.where(mrec[1:] != mrec[:-1])[0] 206 | 207 | # and sum (\Delta recall) * prec 208 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 209 | return ap 210 | 211 | 212 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 213 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 214 | image_numpy = image_tensor[0].cpu().float().numpy() 215 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 216 | return image_numpy.astype(imtype) 217 | 218 | 219 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 220 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 221 | return torch.Tensor( 222 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 223 | ) 224 | -------------------------------------------------------------------------------- /metrics/fid.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | 11 | import argparse 12 | import pickle 13 | import os 14 | import torch 15 | from torch import nn 16 | import numpy as np 17 | from scipy import linalg 18 | from tqdm import tqdm 19 | from training.model import Generator 20 | from metrics.calc_inception import load_patched_inception_v3 21 | from torchvision import utils, transforms 22 | from training.dataset import MultiResolutionDataset, DataSetFromDir 23 | from torch.utils import data 24 | from torch.nn import functional as F 25 | from PIL import Image 26 | from torch.utils.data import Dataset 27 | 28 | 29 | class DPModel(nn.Module): 30 | def __init__(self, device, model_args): 31 | super(DPModel, self).__init__() 32 | self.g_ema = Generator( 33 | model_args.size, 34 | model_args.mapping_layer_num, 35 | model_args.latent_channel_size, 36 | model_args.latent_spatial_size, 37 | lr_mul=model_args.lr_mul, 38 | channel_multiplier=model_args.channel_multiplier, 39 | normalize_mode=model_args.normalize_mode, 40 | small_generator=model_args.small_generator, 41 | ) 42 | 43 | def forward(self, real_img): 44 | z = real_img 45 | fake_img, _ = self.g_ema(z) 46 | 47 | return fake_img 48 | 49 | 50 | def data_sampler(dataset, shuffle): 51 | if shuffle: 52 | return data.RandomSampler(dataset) 53 | 54 | else: 55 | return data.SequentialSampler(dataset) 56 | 57 | 58 | def make_noise(batch, latent_channel_size, device): 59 | return torch.randn(batch, latent_channel_size, device=device) 60 | 61 | 62 | @torch.no_grad() 63 | def extract_feature_from_samples(generator, inception, batch_size, n_sample, device): 64 | n_batch = n_sample // batch_size 65 | resid = n_sample - (n_batch * batch_size) 66 | if resid > 0: 67 | batch_sizes = [batch_size] * n_batch + [resid] 68 | else: 69 | batch_sizes = [batch_size] * n_batch 70 | features = [] 71 | 72 | for batch in tqdm(batch_sizes): 73 | latent = make_noise(batch, train_args.latent_channel_size, device) 74 | imgs = generator(latent) 75 | imgs = (imgs + 1) / 2 # -1 ~ 1 to 0~1 76 | imgs = torch.clamp(imgs, 0, 1, out=None) 77 | imgs = F.interpolate(imgs, size=(height, width), mode="bilinear") 78 | transformed = [] 79 | 80 | for img in imgs: 81 | transformed.append(transforms.Normalize(mean=mean, std=std)(img)) 82 | 83 | transformed = torch.stack(transformed, dim=0) 84 | 85 | assert transformed.shape == imgs.shape 86 | feat = inception(transformed)[0].view(imgs.shape[0], -1) 87 | features.append(feat.to("cpu")) 88 | 89 | features = torch.cat(features, 0) 90 | 91 | return features 92 | 93 | 94 | def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): 95 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 96 | 97 | if not np.isfinite(cov_sqrt).all(): 98 | print("product of cov matrices is singular") 99 | offset = np.eye(sample_cov.shape[0]) * eps 100 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 101 | 102 | if np.iscomplexobj(cov_sqrt): 103 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 104 | m = np.max(np.abs(cov_sqrt.imag)) 105 | 106 | raise ValueError(f"Imaginary component {m}") 107 | 108 | cov_sqrt = cov_sqrt.real 109 | 110 | mean_diff = sample_mean - real_mean 111 | mean_norm = mean_diff @ mean_diff 112 | 113 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 114 | 115 | fid = mean_norm + trace 116 | 117 | return fid 118 | 119 | 120 | def extract_feature_from_generated_samples( 121 | inception, batch_size, n_sample, device, transform 122 | ): 123 | 124 | features = [] 125 | 126 | try: # from LMDB 127 | dataset = 
MultiResolutionDataset( 128 | args.generated_image_path, transform, args.size 129 | ) 130 | except: # from raw images 131 | dataset = DataSetFromDir(args.generated_image_path, transform) 132 | 133 | loader = data.DataLoader( 134 | dataset, 135 | batch_size=batch_size, 136 | sampler=data_sampler(dataset, shuffle=True), 137 | num_workers=args.num_workers, 138 | pin_memory=True, 139 | ) 140 | 141 | # generated images should match with n sample 142 | print(len(loader), n_sample, batch_size) 143 | 144 | if n_sample % batch_size == 0: 145 | assert len(loader) == n_sample // batch_size 146 | else: 147 | assert len(loader) == n_sample // batch_size + 1 148 | 149 | for i, real_img in enumerate(tqdm(loader)): 150 | real_img = real_img.to(device) 151 | 152 | if args.batch * (i + 1) > n_sample: 153 | real_img = real_img[: n_sample - args.batch * i] 154 | 155 | feat = inception(real_img)[0].view(real_img.shape[0], -1) 156 | features.append(feat.to("cpu")) 157 | 158 | if args.batch * (i + 1) > n_sample: 159 | break 160 | 161 | features = torch.cat(features, 0) 162 | print(len(features)) 163 | assert len(features) == n_sample 164 | 165 | return features 166 | 167 | 168 | if __name__ == "__main__": 169 | device = "cuda" 170 | 171 | parser = argparse.ArgumentParser() 172 | 173 | parser.add_argument("--batch", type=int, default=16) 174 | parser.add_argument("--size", type=int, default=256) 175 | parser.add_argument("--comparative_fid_pkl", type=str, default=None) 176 | parser.add_argument("--num_workers", type=int, default=10) 177 | parser.add_argument("--generated_image_path", type=str) 178 | parser.add_argument("--ckpt", metavar="CHECKPOINT") 179 | parser.add_argument("--dataset", type=str, required=True) 180 | 181 | args = parser.parse_args() 182 | assert ((args.generated_image_path is None) and (args.ckpt is not None)) or ( 183 | (args.generated_image_path is not None) and (args.ckpt is None) 184 | ) 185 | 186 | if args.dataset == "celeba_hq": 187 | n_sample = 29000 188 | elif args.dataset == "afhq": 189 | n_sample = 15130 190 | elif args.dataset in ["lsun/car", "lsun/church_outdoor"]: 191 | n_sample = 50000 192 | elif args.dataset == "ffhq": 193 | n_sample = 69000 194 | 195 | inception = nn.DataParallel(load_patched_inception_v3()).to(device) 196 | inception.eval() 197 | height, width = 299, 299 198 | mean = [0.485, 0.456, 0.406] 199 | std = [0.229, 0.224, 0.225] 200 | 201 | if args.ckpt: 202 | assert args.comparative_fid_pkl is not None 203 | 204 | ckpt = torch.load(args.ckpt) 205 | train_args = ckpt["train_args"] 206 | assert args.size == train_args.size 207 | model = DPModel(device, train_args).to(device) 208 | model.g_ema.load_state_dict(ckpt["g_ema"]) 209 | model = nn.DataParallel(model) 210 | model.eval() 211 | 212 | features = extract_feature_from_samples( 213 | model, inception, args.batch, n_sample, device 214 | ).numpy() 215 | 216 | else: 217 | transform = transforms.Compose( 218 | [ 219 | transforms.Resize([args.size, args.size]), 220 | transforms.Resize([height, width]), 221 | transforms.ToTensor(), 222 | transforms.Normalize(mean=mean, std=std), 223 | ] 224 | ) 225 | 226 | features = extract_feature_from_generated_samples( 227 | inception, args.batch, n_sample, device, transform 228 | ).numpy() 229 | 230 | print(f"extracted {features.shape[0]} features") 231 | 232 | sample_mean = np.mean(features, 0) 233 | sample_cov = np.cov(features, rowvar=False) 234 | 235 | if (args.generated_image_path is not None) and (args.comparative_fid_pkl is None): 236 | with open( 237 | 
f"metrics/fid_stats/{args.dataset}_stats_{args.size}_{n_sample}.pkl", 238 | "wb", 239 | ) as handle: 240 | pickle.dump({"mean": sample_mean, "cov": sample_cov}, handle) 241 | else: 242 | with open(args.comparative_fid_pkl, "rb") as f: 243 | embeds = pickle.load(f) 244 | real_mean = embeds["mean"] 245 | real_cov = embeds["cov"] 246 | 247 | fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) 248 | 249 | print("fid:", fid) 250 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## StyleMapGAN - Official PyTorch Implementation 3 | 4 |

    5 | 6 | > **StyleMapGAN: Exploiting Spatial Dimensions of Latent in GAN for Real-time Image Editing**
    7 | > [Hyunsu Kim](https://github.com/blandocs), [Yunjey Choi](https://github.com/yunjey), [Junho Kim](https://github.com/taki0112), [Sungjoo Yoo](http://cmalab.snu.ac.kr), [Youngjung Uh](https://github.com/youngjung)
    8 | > In CVPR 2021.
    9 | 10 | > Paper: https://arxiv.org/abs/2104.14754
    11 | > 5-minute video (CVPR): https://www.youtube.com/watch?v=7sJqjm1qazk
    12 | > Demo video: https://youtu.be/qCapNyRA_Ng
    13 | 14 | > **Abstract:** *Generative adversarial networks (GANs) synthesize realistic images from random latent vectors. Although manipulating the latent vectors controls the synthesized outputs, editing real images with GANs suffers from i) time-consuming optimization for projecting real images to the latent vectors, ii) or inaccurate embedding through an encoder. We propose StyleMapGAN: the intermediate latent space has spatial dimensions, and a spatially variant modulation replaces AdaIN. It makes the embedding through an encoder more accurate than existing optimization-based methods while maintaining the properties of GANs. Experimental results demonstrate that our method significantly outperforms state-of-the-art models in various image manipulation tasks such as local editing and image interpolation. Last but not least, conventional editing methods on GANs are still valid on our StyleMapGAN. Source code is available at https://github.com/naver-ai/StyleMapGAN.* 15 | 16 | ## Demo 17 | 18 | Youtube video 19 | Click the figure to watch the teaser video. 20 | 21 | 22 |

    23 | 24 | Interactive demo app 25 | Run demo in your local machine. 26 | 27 | All test images are from [CelebA-HQ](https://arxiv.org/abs/1710.10196), [AFHQ](https://arxiv.org/abs/1912.01865), and [LSUN](https://www.yf.io/p/lsun). 28 | 29 | ```bash 30 | python demo.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --dataset celeba_hq 31 | ``` 32 | 33 | ## Installation 34 | 35 | ![ubuntu](https://img.shields.io/badge/ubuntu-16.04.5_LTS-green.svg?style=plastic) ![gcc 7.4.0](https://img.shields.io/badge/gcc-7.4.0-green.svg?style=plastic) ![CUDA](https://img.shields.io/badge/CUDA-10.0.130-green.svg?style=plastic) ![CUDA-driver](https://img.shields.io/badge/CUDA_driver-410.72-green.svg?style=plastic) ![cudnn7](https://img.shields.io/badge/cudnn-7.6.3-green.svg?style=plastic) ![conda](https://img.shields.io/badge/conda-4.8.4-green.svg?style=plastic) ![Python 3.6.12](https://img.shields.io/badge/python-3.6.12-green.svg?style=plastic) ![pytorch 1.4.0](https://img.shields.io/badge/pytorch-1.4.0-green.svg?style=plastic) 36 | 37 | 38 | Clone this repository: 39 | 40 | ```bash 41 | git clone https://github.com/naver-ai/StyleMapGAN.git 42 | cd StyleMapGAN/ 43 | ``` 44 | 45 | Install the dependencies: 46 | ```bash 47 | conda create -y -n stylemapgan python=3.6.12 48 | conda activate stylemapgan 49 | ./install.sh 50 | ``` 51 | 52 | ## Datasets and pre-trained networks 53 | We provide a script to download datasets used in StyleMapGAN and the corresponding pre-trained networks. The datasets and network checkpoints will be downloaded and stored in the `data` and `expr/checkpoints` directories, respectively. 54 | 55 | CelebA-HQ. To download the CelebA-HQ dataset and parse it, run the following commands: 56 | 57 | ```bash 58 | # Download raw images and create LMDB datasets using them 59 | # Additional files are also downloaded for local editing 60 | bash download.sh create-lmdb-dataset celeba_hq 61 | 62 | # Download the pretrained network (256x256) 63 | bash download.sh download-pretrained-network-256 celeba_hq # 20M-image-trained models 64 | bash download.sh download-pretrained-network-256 celeba_hq_5M # 5M-image-trained models used in our paper for comparison with other baselines and for ablation studies. 65 | 66 | # Download the pretrained network (1024x1024 image / 16x16 stylemap / Light version of Generator) 67 | bash download.sh download-pretrained-network-1024 ffhq_16x16 68 | ``` 69 | 70 | AFHQ. For AFHQ, change above commands from 'celeba_hq' to 'afhq'. 71 | 72 | 73 | ## Train network 74 | Implemented using DistributedDataParallel. 75 | 76 | ```bash 77 | # CelebA-HQ 78 | python train.py --dataset celeba_hq --train_lmdb data/celeba_hq/LMDB_train --val_lmdb data/celeba_hq/LMDB_val 79 | 80 | # AFHQ 81 | python train.py --dataset afhq --train_lmdb data/afhq/LMDB_train --val_lmdb data/afhq/LMDB_val 82 | 83 | # CelebA-HQ / 1024x1024 image / 16x16 stylemap / Light version of Generator 84 | python train.py --size 1024 --latent_spatial_size 16 --small_generator --dataset celeba_hq --train_lmdb data/celeba_hq/LMDB_train --val_lmdb data/celeba_hq/LMDB_val 85 | ``` 86 | 87 | 88 | ## Generate images 89 | 90 | Reconstruction 91 | Results are saved to `expr/reconstruction`. 
92 | 93 | ```bash 94 | # CelebA-HQ 95 | python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type reconstruction --test_lmdb data/celeba_hq/LMDB_test 96 | 97 | # AFHQ 98 | python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type reconstruction --test_lmdb data/afhq/LMDB_test 99 | 100 | ``` 101 | 102 | W interpolation 103 | Results are saved to `expr/w_interpolation`. 104 | 105 | ```bash 106 | # CelebA-HQ 107 | python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type w_interpolation --test_lmdb data/celeba_hq/LMDB_test 108 | 109 | # AFHQ 110 | python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type w_interpolation --test_lmdb data/afhq/LMDB_test 111 | ``` 112 | 113 | 114 | Local editing 115 | Results are saved to `expr/local_editing`. We pair images using a target semantic mask similarity. If you want to see details, please follow `preprocessor/README.md`. 116 | 117 | ```bash 118 | # Using GroundTruth(GT) segmentation masks for CelebA-HQ dataset. 119 | python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type local_editing --test_lmdb data/celeba_hq/LMDB_test --local_editing_part nose 120 | 121 | # Using half-and-half masks for AFHQ dataset. 122 | python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type local_editing --test_lmdb data/afhq/LMDB_test 123 | ``` 124 | 125 | Unaligned transplantation 126 | Results are saved to `expr/transplantation`. It shows local transplantations examples of AFHQ. We recommend the demo code instead of this. 127 | 128 | ```bash 129 | python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type transplantation --test_lmdb data/afhq/LMDB_test 130 | ``` 131 | 132 | Random Generation 133 | Results are saved to `expr/random_generation`. It shows random generation examples. 134 | 135 | ```bash 136 | python generate.py --mixing_type random_generation --ckpt expr/checkpoints/celeba_hq_256_8x8.pt 137 | ``` 138 | 139 | Style Mixing 140 | Results are saved to `expr/stylemixing`. It shows style mixing examples. 141 | 142 | ```bash 143 | python generate.py --mixing_type stylemixing --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --test_lmdb data/celeba_hq/LMDB_test 144 | ``` 145 | 146 | Semantic Manipulation 147 | Results are saved to `expr/semantic_manipulation`. It shows local semantic manipulation examples. 148 | 149 | ```bash 150 | python semantic_manipulation.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --LMDB data/celeba_hq/LMDB --svm_train_iter 10000 151 | ``` 152 | 153 | ## Metrics 154 | 155 | * Reconstruction: LPIPS, MSE 156 | * W interpolation: FIDlerp 157 | * Generation: FID 158 | * Local editing: MSEsrc, MSEref, Detectability (Refer to [CNNDetection](https://github.com/PeterWang512/CNNDetection)) 159 | 160 | If you want to see details, please follow `metrics/README.md`. 161 | 162 | ## License 163 | The source code, pre-trained models, and dataset are available under [Creative Commons BY-NC 4.0](LICENSE) license by NAVER Corporation. You can **use, copy, tranform and build upon** the material for **non-commercial purposes** as long as you give **appropriate credit** by citing our paper, and indicate if changes were made. 164 | 165 | For business inquiries, please contact clova-jobs@navercorp.com.
166 | For technical and other inquiries, please contact hyunsu1125.kim@navercorp.com. 167 | 168 | ## Citation 169 | If you find this work useful for your research, please cite our paper: 170 | ``` 171 | @inproceedings{kim2021stylemapgan, 172 | title={Exploiting Spatial Dimensions of Latent in GAN for Real-time Image Editing}, 173 | author={Kim, Hyunsu and Choi, Yunjey and Kim, Junho and Yoo, Sungjoo and Uh, Youngjung}, 174 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 175 | year={2021} 176 | } 177 | ``` 178 | ## Related Projects 179 | 180 | Model code starts from [StyleGAN2 PyTorch unofficial code](https://github.com/rosinality/stylegan2-pytorch), which refers to [StyleGAN2 official code](https://github.com/NVlabs/stylegan2). 181 | [LPIPS](https://github.com/richzhang/PerceptualSimilarity), [FID](https://github.com/mseitzer/pytorch-fid), and [CNNDetection](https://github.com/PeterWang512/CNNDetection) codes are used for evaluation. 182 | In semantic manipulation, we used [StyleGAN pretrained network](https://github.com/NVlabs/stylegan) to get positive and negative samples by ranking. 183 | The demo code starts from [Neural-Collage](https://github.com/quolc/neural-collage#web-based-demos). 184 | -------------------------------------------------------------------------------- /training/lpips/networks_basic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/networks_basic.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py 4 | """ 5 | from __future__ import absolute_import 6 | import sys 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | from torch.autograd import Variable 11 | import numpy as np 12 | from pdb import set_trace as st 13 | from skimage import color 14 | from IPython import embed 15 | from . 
import pretrained_networks as pn 16 | 17 | from training import lpips as util 18 | 19 | 20 | def spatial_average(in_tens, keepdim=True): 21 | return in_tens.mean([2, 3], keepdim=keepdim) 22 | 23 | 24 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 25 | in_H = in_tens.shape[2] 26 | scale_factor = 1.0 * out_H / in_H 27 | 28 | return nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)( 29 | in_tens 30 | ) 31 | 32 | 33 | # Learned perceptual metric 34 | class PNetLin(nn.Module): 35 | def __init__( 36 | self, 37 | pnet_type="vgg", 38 | pnet_rand=False, 39 | pnet_tune=False, 40 | use_dropout=True, 41 | spatial=False, 42 | version="0.1", 43 | lpips=True, 44 | ): 45 | super(PNetLin, self).__init__() 46 | 47 | self.pnet_type = pnet_type 48 | self.pnet_tune = pnet_tune 49 | self.pnet_rand = pnet_rand 50 | self.spatial = spatial 51 | self.lpips = lpips 52 | self.version = version 53 | self.scaling_layer = ScalingLayer() 54 | 55 | if self.pnet_type in ["vgg", "vgg16"]: 56 | net_type = pn.vgg16 57 | self.chns = [64, 128, 256, 512, 512] 58 | elif self.pnet_type == "alex": 59 | net_type = pn.alexnet 60 | self.chns = [64, 192, 384, 256, 256] 61 | elif self.pnet_type == "squeeze": 62 | net_type = pn.squeezenet 63 | self.chns = [64, 128, 256, 384, 384, 512, 512] 64 | self.L = len(self.chns) 65 | 66 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 67 | 68 | if lpips: 69 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 70 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 71 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 72 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 73 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 74 | self.lins = nn.ModuleList( 75 | [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 76 | ) 77 | 78 | if self.pnet_type == "squeeze": # 7 layers for squeezenet 79 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 80 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 81 | self.lins.extend([self.lin5, self.lin6]) 82 | 83 | def forward(self, in0, in1, retPerLayer=False): 84 | # v0.0 - original release had a bug, where input was not scaled 85 | in0_input, in1_input = ( 86 | (self.scaling_layer(in0), self.scaling_layer(in1)) 87 | if self.version == "0.1" 88 | else (in0, in1) 89 | ) 90 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 91 | feats0, feats1, diffs = {}, {}, {} 92 | 93 | for kk in range(self.L): 94 | feats0[kk], feats1[kk] = ( 95 | util.normalize_tensor(outs0[kk]), 96 | util.normalize_tensor(outs1[kk]), 97 | ) 98 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 99 | 100 | if self.lpips: 101 | if self.spatial: 102 | res = [ 103 | upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) 104 | for kk in range(self.L) 105 | ] 106 | else: 107 | res = [ 108 | spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) 109 | for kk in range(self.L) 110 | ] 111 | else: 112 | if self.spatial: 113 | res = [ 114 | upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) 115 | for kk in range(self.L) 116 | ] 117 | else: 118 | res = [ 119 | spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) 120 | for kk in range(self.L) 121 | ] 122 | 123 | val = res[0] 124 | for l in range(1, self.L): 125 | val += res[l] 126 | 127 | if retPerLayer: 128 | return (val, res) 129 | else: 130 | return val 131 | 132 | 133 | class ScalingLayer(nn.Module): 134 | def 
__init__(self): 135 | super(ScalingLayer, self).__init__() 136 | self.register_buffer( 137 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 138 | ) 139 | self.register_buffer( 140 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 141 | ) 142 | 143 | def forward(self, inp): 144 | return (inp - self.shift) / self.scale 145 | 146 | 147 | class NetLinLayer(nn.Module): 148 | """ A single linear layer which does a 1x1 conv """ 149 | 150 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 151 | super(NetLinLayer, self).__init__() 152 | 153 | layers = ( 154 | [ 155 | nn.Dropout(), 156 | ] 157 | if (use_dropout) 158 | else [] 159 | ) 160 | layers += [ 161 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 162 | ] 163 | self.model = nn.Sequential(*layers) 164 | 165 | 166 | class Dist2LogitLayer(nn.Module): 167 | """ takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) """ 168 | 169 | def __init__(self, chn_mid=32, use_sigmoid=True): 170 | super(Dist2LogitLayer, self).__init__() 171 | 172 | layers = [ 173 | nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), 174 | ] 175 | layers += [ 176 | nn.LeakyReLU(0.2, True), 177 | ] 178 | layers += [ 179 | nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), 180 | ] 181 | layers += [ 182 | nn.LeakyReLU(0.2, True), 183 | ] 184 | layers += [ 185 | nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), 186 | ] 187 | if use_sigmoid: 188 | layers += [ 189 | nn.Sigmoid(), 190 | ] 191 | self.model = nn.Sequential(*layers) 192 | 193 | def forward(self, d0, d1, eps=0.1): 194 | return self.model.forward( 195 | torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1) 196 | ) 197 | 198 | 199 | class BCERankingLoss(nn.Module): 200 | def __init__(self, chn_mid=32): 201 | super(BCERankingLoss, self).__init__() 202 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 203 | # self.parameters = list(self.net.parameters()) 204 | self.loss = torch.nn.BCELoss() 205 | 206 | def forward(self, d0, d1, judge): 207 | per = (judge + 1.0) / 2.0 208 | self.logit = self.net.forward(d0, d1) 209 | return self.loss(self.logit, per) 210 | 211 | 212 | # L2, DSSIM training 213 | class FakeNet(nn.Module): 214 | def __init__(self, use_gpu=True, colorspace="Lab"): 215 | super(FakeNet, self).__init__() 216 | self.use_gpu = use_gpu 217 | self.colorspace = colorspace 218 | 219 | 220 | class L2(FakeNet): 221 | def forward(self, in0, in1, retPerLayer=None): 222 | assert in0.size()[0] == 1 # currently only supports batchSize 1 223 | 224 | if self.colorspace == "RGB": 225 | (N, C, X, Y) = in0.size() 226 | value = torch.mean( 227 | torch.mean( 228 | torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2 229 | ).view(N, 1, 1, Y), 230 | dim=3, 231 | ).view(N) 232 | return value 233 | elif self.colorspace == "Lab": 234 | value = util.l2( 235 | util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 236 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), 237 | range=100.0, 238 | ).astype("float") 239 | ret_var = Variable(torch.Tensor((value,))) 240 | if self.use_gpu: 241 | ret_var = ret_var.cuda() 242 | return ret_var 243 | 244 | 245 | class DSSIM(FakeNet): 246 | def forward(self, in0, in1, retPerLayer=None): 247 | assert in0.size()[0] == 1 # currently only supports batchSize 1 248 | 249 | if self.colorspace == "RGB": 250 | value = util.dssim( 251 | 1.0 * util.tensor2im(in0.data), 252 | 1.0 * util.tensor2im(in1.data), 253 | range=255.0, 254 | 
).astype("float") 255 | elif self.colorspace == "Lab": 256 | value = util.dssim( 257 | util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 258 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), 259 | range=100.0, 260 | ).astype("float") 261 | ret_var = Variable(torch.Tensor((value,))) 262 | if self.use_gpu: 263 | ret_var = ret_var.cuda() 264 | return ret_var 265 | 266 | 267 | def print_network(net): 268 | num_params = 0 269 | for param in net.parameters(): 270 | num_params += param.numel() 271 | print("Network", net) 272 | print("Total number of parameters: %d" % num_params) 273 | -------------------------------------------------------------------------------- /training/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int tile_out_h, int tile_out_w, int kernel_h, int kernel_w> 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 |
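// The loop below stages the current input tile into shared memory: each thread
// copies a strided subset of the (tile_in_h x tile_in_w) window into sx[][],
// zero-filling out-of-bounds reads, so that the per-output-pixel accumulation
// further down can run entirely out of shared memory (sk[][] already holds the
// spatially flipped filter taps loaded above).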
int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 
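// Note on the output-size formula above: for the usual StyleGAN2 2x upsample
// setting (up=2, down=1, a 4-tap kernel, pad0=2, pad1=1 -- assumed here purely
// for illustration) a 64x64 input gives out_h = (64*2 + 2 + 1 - 4 + 1) / 1 = 128,
// i.e. exactly twice the input resolution.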
&& p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w > 0) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 16, 64, 4, 4><<<grid_size, block_size, 0, stream>>>( 229 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 16, 64, 3, 3><<<grid_size, block_size, 0, stream>>>( 236 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 16, 64, 4, 4><<<grid_size, block_size, 0, stream>>>( 243 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 16, 64, 2, 2><<<grid_size, block_size, 0, stream>>>( 250 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 8, 32, 4, 4><<<grid_size, block_size, 0, stream>>>( 257 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 8, 32, 2, 2><<<grid_size, block_size, 0, stream>>>( 264 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /metrics/inception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/inception.py 3 | Refer to https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision import models 10 | 11 | try: 12 | from torchvision.models.utils import load_state_dict_from_url 13 | except ImportError: 14 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 15 | 16 | # Inception weights ported to Pytorch from 17 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 18 | # FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 19 | FID_WEIGHTS_URL = "metrics/pt_inception-2015-12-05-6726825d.pth" 20 | 21 | 22 | class InceptionV3(nn.Module): 23 | """Pretrained InceptionV3 network returning feature maps""" 24 | 25 | # Index of default block of inception to return, 26 | # corresponds to output of final average pooling 27 | DEFAULT_BLOCK_INDEX = 3 28 | 29 | # Maps feature dimensionality to their output blocks indices 30 | BLOCK_INDEX_BY_DIM = { 31 | 64: 0, # First max pooling features 32 | 192: 1, # Second max pooling features 33 | 768: 2, # Pre-aux classifier features 34 | 2048: 3, # Final average pooling features 35 | } 36 | 37 | def __init__( 38 | self, 39 | output_blocks=[DEFAULT_BLOCK_INDEX], 40 | resize_input=True, 41 | normalize_input=True, 42 | requires_grad=False, 43 | use_fid_inception=True, 44 | ): 45 | """Build pretrained InceptionV3 46 | 47 | Parameters 48 | ---------- 49 | output_blocks : list of int 50 | Indices of blocks to return features of.
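For FID, the conventional choice is output_blocks=[3] (i.e. BLOCK_INDEX_BY_DIM[2048]), the 2048-dimensional final-average-pool features.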
Possible values are: 51 | - 0: corresponds to output of first max pooling 52 | - 1: corresponds to output of second max pooling 53 | - 2: corresponds to output which is fed to aux classifier 54 | - 3: corresponds to output of final average pooling 55 | resize_input : bool 56 | If true, bilinearly resizes input to width and height 299 before 57 | feeding input to model. As the network without fully connected 58 | layers is fully convolutional, it should be able to handle inputs 59 | of arbitrary size, so resizing might not be strictly needed 60 | normalize_input : bool 61 | If true, scales the input from range (0, 1) to the range the 62 | pretrained Inception network expects, namely (-1, 1) 63 | requires_grad : bool 64 | If true, parameters of the model require gradients. Possibly useful 65 | for finetuning the network 66 | use_fid_inception : bool 67 | If true, uses the pretrained Inception model used in Tensorflow's 68 | FID implementation. If false, uses the pretrained Inception model 69 | available in torchvision. The FID Inception model has different 70 | weights and a slightly different structure from torchvision's 71 | Inception model. If you want to compute FID scores, you are 72 | strongly advised to set this parameter to true to get comparable 73 | results. 74 | """ 75 | super(InceptionV3, self).__init__() 76 | 77 | self.resize_input = resize_input 78 | self.normalize_input = normalize_input 79 | self.output_blocks = sorted(output_blocks) 80 | self.last_needed_block = max(output_blocks) 81 | 82 | assert self.last_needed_block <= 3, "Last possible output block index is 3" 83 | 84 | self.blocks = nn.ModuleList() 85 | 86 | if use_fid_inception: 87 | inception = fid_inception_v3() 88 | else: 89 | inception = models.inception_v3(pretrained=True) 90 | 91 | # Block 0: input to maxpool1 92 | block0 = [ 93 | inception.Conv2d_1a_3x3, 94 | inception.Conv2d_2a_3x3, 95 | inception.Conv2d_2b_3x3, 96 | nn.MaxPool2d(kernel_size=3, stride=2), 97 | ] 98 | self.blocks.append(nn.Sequential(*block0)) 99 | 100 | # Block 1: maxpool1 to maxpool2 101 | if self.last_needed_block >= 1: 102 | block1 = [ 103 | inception.Conv2d_3b_1x1, 104 | inception.Conv2d_4a_3x3, 105 | nn.MaxPool2d(kernel_size=3, stride=2), 106 | ] 107 | self.blocks.append(nn.Sequential(*block1)) 108 | 109 | # Block 2: maxpool2 to aux classifier 110 | if self.last_needed_block >= 2: 111 | block2 = [ 112 | inception.Mixed_5b, 113 | inception.Mixed_5c, 114 | inception.Mixed_5d, 115 | inception.Mixed_6a, 116 | inception.Mixed_6b, 117 | inception.Mixed_6c, 118 | inception.Mixed_6d, 119 | inception.Mixed_6e, 120 | ] 121 | self.blocks.append(nn.Sequential(*block2)) 122 | 123 | # Block 3: aux classifier to final avgpool 124 | if self.last_needed_block >= 3: 125 | block3 = [ 126 | inception.Mixed_7a, 127 | inception.Mixed_7b, 128 | inception.Mixed_7c, 129 | nn.AdaptiveAvgPool2d(output_size=(1, 1)), 130 | ] 131 | self.blocks.append(nn.Sequential(*block3)) 132 | 133 | for param in self.parameters(): 134 | param.requires_grad = requires_grad 135 | 136 | def forward(self, inp): 137 | """Get Inception feature maps 138 | 139 | Parameters 140 | ---------- 141 | inp : torch.autograd.Variable 142 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 143 | range (0, 1) 144 | 145 | Returns 146 | ------- 147 | List of torch.autograd.Variable, corresponding to the selected output 148 | block, sorted ascending by index 149 | """ 150 | outp = [] 151 | x = inp 152 | 153 | if self.resize_input: 154 | x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False) 155 | 156 | if self.normalize_input: 157 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 158 | 159 | for idx, block in enumerate(self.blocks): 160 | x = block(x) 161 | if idx in self.output_blocks: 162 | outp.append(x) 163 | 164 | if idx == self.last_needed_block: 165 | break 166 | 167 | return outp 168 | 169 | 170 | def fid_inception_v3(): 171 | """Build pretrained Inception model for FID computation 172 | 173 | The Inception model for FID computation uses a different set of weights 174 | and has a slightly different structure than torchvision's Inception. 175 | 176 | This method first constructs torchvision's Inception and then patches the 177 | necessary parts that are different in the FID Inception model. 178 | """ 179 | 180 | # this takes too long time in scipy 1.4.1. so downgrade to 1.3.3 181 | inception = models.inception_v3( 182 | num_classes=1008, aux_logits=False, pretrained=False 183 | ) 184 | 185 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 186 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 187 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 188 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 189 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 190 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 191 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 192 | inception.Mixed_7b = FIDInceptionE_1(1280) 193 | inception.Mixed_7c = FIDInceptionE_2(2048) 194 | 195 | # change to local pth 196 | # state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 197 | state_dict = torch.load(FID_WEIGHTS_URL) 198 | inception.load_state_dict(state_dict) 199 | 200 | return inception 201 | 202 | 203 | class FIDInceptionA(models.inception.InceptionA): 204 | """InceptionA block patched for FID computation""" 205 | 206 | def __init__(self, in_channels, pool_features): 207 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 208 | 209 | def forward(self, x): 210 | branch1x1 = self.branch1x1(x) 211 | 212 | branch5x5 = self.branch5x5_1(x) 213 | branch5x5 = self.branch5x5_2(branch5x5) 214 | 215 | branch3x3dbl = self.branch3x3dbl_1(x) 216 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 217 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 218 | 219 | # Patch: Tensorflow's average pool does not use the padded zero's in 220 | # its average calculation 221 | branch_pool = F.avg_pool2d( 222 | x, kernel_size=3, stride=1, padding=1, count_include_pad=False 223 | ) 224 | branch_pool = self.branch_pool(branch_pool) 225 | 226 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 227 | return torch.cat(outputs, 1) 228 | 229 | 230 | class FIDInceptionC(models.inception.InceptionC): 231 | """InceptionC block patched for FID computation""" 232 | 233 | def __init__(self, in_channels, channels_7x7): 234 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 235 | 236 | def forward(self, x): 237 | branch1x1 = self.branch1x1(x) 238 | 239 | branch7x7 = self.branch7x7_1(x) 240 | branch7x7 = self.branch7x7_2(branch7x7) 241 | branch7x7 = self.branch7x7_3(branch7x7) 242 | 243 | branch7x7dbl = self.branch7x7dbl_1(x) 244 | branch7x7dbl = 
self.branch7x7dbl_2(branch7x7dbl) 245 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 246 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 247 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 248 | 249 | # Patch: Tensorflow's average pool does not use the padded zero's in 250 | # its average calculation 251 | branch_pool = F.avg_pool2d( 252 | x, kernel_size=3, stride=1, padding=1, count_include_pad=False 253 | ) 254 | branch_pool = self.branch_pool(branch_pool) 255 | 256 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 257 | return torch.cat(outputs, 1) 258 | 259 | 260 | class FIDInceptionE_1(models.inception.InceptionE): 261 | """First InceptionE block patched for FID computation""" 262 | 263 | def __init__(self, in_channels): 264 | super(FIDInceptionE_1, self).__init__(in_channels) 265 | 266 | def forward(self, x): 267 | branch1x1 = self.branch1x1(x) 268 | 269 | branch3x3 = self.branch3x3_1(x) 270 | branch3x3 = [ 271 | self.branch3x3_2a(branch3x3), 272 | self.branch3x3_2b(branch3x3), 273 | ] 274 | branch3x3 = torch.cat(branch3x3, 1) 275 | 276 | branch3x3dbl = self.branch3x3dbl_1(x) 277 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 278 | branch3x3dbl = [ 279 | self.branch3x3dbl_3a(branch3x3dbl), 280 | self.branch3x3dbl_3b(branch3x3dbl), 281 | ] 282 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 283 | 284 | # Patch: Tensorflow's average pool does not use the padded zero's in 285 | # its average calculation 286 | branch_pool = F.avg_pool2d( 287 | x, kernel_size=3, stride=1, padding=1, count_include_pad=False 288 | ) 289 | branch_pool = self.branch_pool(branch_pool) 290 | 291 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 292 | return torch.cat(outputs, 1) 293 | 294 | 295 | class FIDInceptionE_2(models.inception.InceptionE): 296 | """Second InceptionE block patched for FID computation""" 297 | 298 | def __init__(self, in_channels): 299 | super(FIDInceptionE_2, self).__init__(in_channels) 300 | 301 | def forward(self, x): 302 | branch1x1 = self.branch1x1(x) 303 | 304 | branch3x3 = self.branch3x3_1(x) 305 | branch3x3 = [ 306 | self.branch3x3_2a(branch3x3), 307 | self.branch3x3_2b(branch3x3), 308 | ] 309 | branch3x3 = torch.cat(branch3x3, 1) 310 | 311 | branch3x3dbl = self.branch3x3dbl_1(x) 312 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 313 | branch3x3dbl = [ 314 | self.branch3x3dbl_3a(branch3x3dbl), 315 | self.branch3x3dbl_3b(branch3x3dbl), 316 | ] 317 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 318 | 319 | # Patch: The FID Inception model uses max pooling instead of average 320 | # pooling. This is likely an error in this specific Inception 321 | # implementation, as other Inception models use average pooling here 322 | # (which matches the description in the paper). 323 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 324 | branch_pool = self.branch_pool(branch_pool) 325 | 326 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 327 | return torch.cat(outputs, 1) 328 | -------------------------------------------------------------------------------- /demo/static/components/js/main.js: -------------------------------------------------------------------------------- 1 | /* 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | */ 10 | 11 | // Refer to https://github.com/quolc/neural-collage/blob/master/static/demo_feature_blending/js/main.js 12 | 13 | max_colors = 3; 14 | colors = ["#FE2712", "#66B032", "#FEFE33"]; 15 | original_image = null; 16 | palette = []; 17 | palette_selected_index = null; 18 | 19 | ui_uninitialized = true; 20 | p5_input_original = null; 21 | p5_input_reference = null; 22 | p5_output = null; 23 | sync_flag = true; 24 | id = null; 25 | 26 | function ReferenceNameSpace() { 27 | return function (s) { 28 | s.setup = function () { 29 | s.pixelDensity(1); 30 | s.createCanvas(canvas_size, canvas_size); 31 | 32 | s.mask = []; 33 | for (var i = 0; i < max_colors; i++) { 34 | s.mask.push(s.createGraphics(canvas_size, canvas_size)); 35 | } 36 | 37 | s.body = null; 38 | s.cursor(s.HAND); 39 | 40 | } 41 | 42 | s.draw = function () { 43 | s.background(255); 44 | s.noTint(); 45 | if (s.body != null) { 46 | s.image(s.body, 0, 0, s.width, s.height); 47 | } 48 | s.tint(255, 127); 49 | 50 | if (palette_selected_index != null) 51 | s.image(s.mask[palette_selected_index], 0, 0); 52 | } 53 | 54 | s.mouseDragged = function () { 55 | if (ui_uninitialized) return; 56 | 57 | var c = $('.palette-item.selected').data('class'); 58 | if (c != -1) { 59 | var col = s.color(colors[palette.indexOf(c)]); 60 | s.mask[palette_selected_index].noStroke(); 61 | s.mask[palette_selected_index].fill(col); 62 | s.mask[palette_selected_index].ellipse(s.mouseX, s.mouseY, 20, 20); 63 | 64 | } else { // eraser 65 | if (sync_flag == true) { 66 | var col = s.color(0, 0); 67 | erase_size = 20; 68 | s.mask[palette_selected_index].loadPixels(); 69 | for (var x = Math.max(0, Math.floor(s.mouseX) - erase_size); x < Math.min(canvas_size, Math.floor(s.mouseX) + erase_size); x++) { 70 | for (var y = Math.max(0, Math.floor(s.mouseY) - erase_size); y < Math.min(canvas_size, Math.floor(s.mouseY) + erase_size); y++) { 71 | if (s.dist(s.mouseX, s.mouseY, x, y) < erase_size) { 72 | s.mask[palette_selected_index].set(x, y, col); 73 | } 74 | } 75 | } 76 | s.mask[palette_selected_index].updatePixels(); 77 | 78 | // p5.Graphics object should be re-created because of a bug related to updatePixels(). 
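// (Workaround: draw each existing mask layer onto a fresh p5.Graphics of the same
// size, remove() the stale object, and swap the new one in so that later
// set()/updatePixels() calls keep behaving as expected.)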
79 | for (var update_g = 0; update_g < max_colors; update_g++) { 80 | var new_g = s.createGraphics(canvas_size, canvas_size); 81 | new_g.image(s.mask[update_g], 0, 0); 82 | s.mask[update_g].remove(); 83 | s.mask[update_g] = new_g; 84 | } 85 | } 86 | } 87 | } 88 | 89 | s.clear_canvas = function () { 90 | for (var i = 0; i < max_colors; i++) { 91 | s.mask[i].clear(); 92 | } 93 | s.body = null; 94 | } 95 | 96 | s.updateImage = function (url) { 97 | s.body = s.loadImage(url); 98 | } 99 | } 100 | } 101 | 102 | function OriginalNameSpace() { 103 | return function (s) { 104 | s.setup = function () { 105 | s.pixelDensity(1); 106 | s.createCanvas(canvas_size, canvas_size); 107 | s.body = null; 108 | s.cursor(s.HAND); 109 | 110 | s.r_x = Array(max_colors).fill(0); 111 | s.r_y = Array(max_colors).fill(0); 112 | s.d_x = Array(max_colors).fill(0); 113 | s.d_y = Array(max_colors).fill(0); 114 | mousePressed_here = false; 115 | } 116 | 117 | 118 | s.draw = function () { 119 | s.background(255); 120 | s.noTint(); 121 | if (s.body != null) { 122 | s.image(s.body, 0, 0, s.width, s.height); 123 | } 124 | s.tint(255, 127); 125 | 126 | for (var i = 0; i < max_colors; i++) { 127 | s.image(p5_input_reference.mask[i], s.r_x[i], s.r_y[i]); 128 | } 129 | 130 | } 131 | 132 | s.mousePressed = function (e) { 133 | s.d_x[palette_selected_index] = s.mouseX; 134 | s.d_y[palette_selected_index] = s.mouseY; 135 | 136 | if (s.mouseX <= s.width && s.mouseX >= 0 && s.mouseY <= s.height && s.mouseY >= 0) { 137 | s.mousePressed_here = true; 138 | } 139 | } 140 | 141 | s.mouseReleased = function (e) { 142 | s.mousePressed_here = false; 143 | } 144 | 145 | s.mouseDragged = function (e) { 146 | if (ui_uninitialized || s.mousePressed_here == false) return; 147 | if (s.mouseX <= s.width && s.mouseX >= 0 && s.mouseY <= s.height && s.mouseY >= 0) { 148 | 149 | s.r_x[palette_selected_index] += s.mouseX - s.d_x[palette_selected_index]; 150 | s.r_y[palette_selected_index] += s.mouseY - s.d_y[palette_selected_index]; 151 | 152 | s.d_x[palette_selected_index] = s.mouseX; 153 | s.d_y[palette_selected_index] = s.mouseY; 154 | } 155 | } 156 | 157 | s.updateImage = function (url) { 158 | s.body = s.loadImage(url); 159 | } 160 | 161 | 162 | s.clear_canvas = function () { 163 | s.body = null; 164 | 165 | for (var i = 0; i < max_colors; i++) { 166 | s.r_x[i] = 0; 167 | s.r_y[i] = 0; 168 | s.d_x[i] = 0; 169 | s.d_y[i] = 0; 170 | } 171 | 172 | } 173 | 174 | } 175 | } 176 | 177 | function generateOutputNameSpace() { 178 | return function (s) { 179 | s.setup = function () { 180 | s.pixelDensity(1); 181 | s.createCanvas(canvas_size, canvas_size); 182 | 183 | s.images = []; 184 | s.currentImage = 0; 185 | s.frameRate(15); 186 | } 187 | 188 | s.draw = function () { 189 | s.background(255); 190 | if (s.images.length > s.currentImage) { 191 | s.background(255); 192 | s.image(s.images[s.currentImage], 0, 0, s.width, s.height); 193 | } 194 | } 195 | 196 | s.updateImages = function (urls) { 197 | for (var i = urls.length - 1; i >= 0; i--) { 198 | var img = s.loadImage(urls[i]); 199 | s.images[i] = img; 200 | } 201 | s.currentImage = urls.length - 1; 202 | } 203 | 204 | s.changeCurrentImage = function (index) { 205 | if (index < s.images.length) { 206 | s.currentImage = index; 207 | } 208 | } 209 | 210 | s.clear_canvas = function () { 211 | s.images = []; 212 | s.currentImage = 0; 213 | } 214 | } 215 | } 216 | 217 | function updateResult() { 218 | disableUI(); 219 | 220 | var canvas_reference = $('#p5-reference canvas').slice(1); 221 | var data_reference = 
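// Serialize each reference-mask canvas to a base64 PNG (slice(1) above presumably
// skips the p5 display canvas itself) and POST everything to /post: the chosen
// original image, the reference palette, the mask data, the per-layer drag offsets
// (r_x, r_y) and the palette colors. The server responds with a list of result
// image URLs that drive the output canvas and the interpolation slider.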
[]; 222 | 223 | for (var canvas_i = 0; canvas_i < max_colors; canvas_i++) { 224 | data_reference.push(canvas_reference[canvas_i].toDataURL('image/png').replace(/data:image\/png;base64,/, '')); 225 | } 226 | 227 | $.ajax({ 228 | type: "POST", 229 | url: "/post", 230 | data: JSON.stringify({ "id": id, "original": original_image, "references": palette, "data_reference": data_reference, "shift_original": [p5_input_original.r_x, p5_input_original.r_y], "colors": colors }), 231 | dataType: "json", 232 | contentType: "application/json", 233 | }).done(function (data, textStatus, jqXHR) { 234 | 235 | let urls = data['result']; 236 | 237 | $('#ex1').slider({ 'max': urls.length - 1, "setValue": urls.length - 1 }); 238 | p5_output.updateImages(urls); 239 | 240 | $("#ex1").attr('data-slider-value', urls.length - 1); 241 | $("#ex1").slider('refresh'); 242 | 243 | enableUI(); 244 | }); 245 | } 246 | 247 | function enableUI() { 248 | ui_uninitialized = false; 249 | $("button").removeAttr('disabled'); 250 | $('#ex1').slider('enable'); 251 | } 252 | 253 | function disableUI() { 254 | ui_uninitialized = true; 255 | $("button").attr('disabled', true); 256 | $('#ex1').slider('disable'); 257 | } 258 | 259 | 260 | $(function () { 261 | $("#main-ui-submit").click(function () { 262 | updateResult(); 263 | }); 264 | 265 | $("#sketch-clear").click(function () { 266 | p5_input_reference.clear_canvas(); 267 | p5_input_original.clear_canvas(); 268 | p5_output.clear_canvas() 269 | $('.palette-item-class').remove(); 270 | palette = []; 271 | original_image = null; 272 | palette_selected_index = null; 273 | $("#palette-eraser").click(); 274 | 275 | $("#sketch-clear").attr('disabled', true); 276 | $("#main-ui-submit").attr('disabled', true); 277 | }); 278 | 279 | for (var i = 0; i < image_paths.length; i++) { 280 | var image_name = image_paths[i]; 281 | 282 | $("#class-picker").append( 283 | '' 284 | ); 285 | } 286 | 287 | $("#class-picker").imagepicker({ 288 | hide_select: false, 289 | }); 290 | $('#class-picker').after( 291 | "" 292 | ); 293 | $('#class-picker').after( 294 | "" 295 | ); 296 | 297 | $("#class-picker-submit-reference").after( 298 | "
    " 299 | ) 300 | $("#class-picker").appendTo("#class-picker-ui"); 301 | $("#class-picker-submit-reference").appendTo("#class-picker-ui"); 302 | $("#class-picker-submit-original").appendTo("#class-picker-ui"); 303 | 304 | $("#class-picker-submit-reference").click(function () { 305 | const selected_class = $("#class-picker").val(); 306 | const image_name_without_ext = selected_class.split('.').slice(0, -1).join('.'); 307 | 308 | if (palette.length >= max_colors || palette.indexOf(selected_class) != -1) { 309 | return; 310 | } 311 | 312 | $("#palette-body").append( 313 | "
  • "); 316 | 317 | $("#palette-" + image_name_without_ext).append( 318 | "" 319 | ) 320 | 321 | palette.push((selected_class)); 322 | 323 | $("#palette-" + image_name_without_ext).click(function () { 324 | $(".palette-item.selected").removeClass('selected'); 325 | $(this).addClass('selected'); 326 | p5_input_reference.updateImage(base_path + selected_class); 327 | palette_selected_index = palette.indexOf(selected_class); 328 | 329 | }); 330 | $("#palette-" + image_name_without_ext).click(); 331 | palette_selected_index = palette.indexOf(selected_class); 332 | if (palette.length > 0 && original_image != null) { 333 | enableUI(); 334 | } 335 | }); 336 | 337 | $("#class-picker-submit-original").click(function () { 338 | selected_class = $("#class-picker").val(); 339 | p5_input_original.updateImage(base_path + selected_class); 340 | original_image = selected_class; 341 | if (palette.length > 0 && original_image != null) { 342 | enableUI(); 343 | } 344 | }); 345 | 346 | $("#palette-eraser").click(function () { 347 | $(".palette-item.selected").removeClass('selected'); 348 | $(this).addClass('selected'); 349 | }); 350 | 351 | $('#ex1').slider({ 352 | formatter: function (value) { 353 | return 'interpolation: ' + (value / (16 - 1)).toFixed(2); 354 | } 355 | }); 356 | $('#ex1').slider('disable'); 357 | $("#ex1").change(function () { 358 | p5_output.changeCurrentImage(parseInt($("#ex1").val())); 359 | }); 360 | 361 | p5_input_reference = new p5(ReferenceNameSpace(), "p5-reference"); 362 | p5_input_original = new p5(OriginalNameSpace(), "p5-original"); 363 | p5_output = new p5(generateOutputNameSpace(), "p5-right"); 364 | 365 | // https://cofs.tistory.com/363 366 | var getCookie = function (name) { 367 | var value = document.cookie.match('(^|;) ?' + name + '=([^;]*)(;|$)'); 368 | return value ? value[2] : null; 369 | }; 370 | 371 | var setCookie = function (name, value, day) { 372 | var date = new Date(); 373 | date.setTime(date.getTime() + day * 60 * 60 * 24 * 1000); 374 | document.cookie = name + '=' + value + ';expires=' + date.toUTCString() + ';path=/'; 375 | }; 376 | 377 | id = getCookie("id"); 378 | 379 | if (id == null) { 380 | // https://stackoverflow.com/questions/1349404/generate-random-string-characters-in-javascript 381 | length = 20; 382 | var result = ''; 383 | var characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'; 384 | var charactersLength = characters.length; 385 | for (var i = 0; i < length; i++) { 386 | result += characters.charAt(Math.floor(Math.random() * charactersLength)); 387 | } 388 | id = result; 389 | setCookie("id", result, 1); 390 | } 391 | }) 392 | --------------------------------------------------------------------------------
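Usage note (illustrative sketch only, not a file in the repository): metrics/fid.py presumably drives the InceptionV3 wrapper above, but extracting the 2048-dimensional pool features used for FID can be sketched as follows, assuming the ported Inception weights have already been placed at the local FID_WEIGHTS_URL path:

import torch
from metrics.inception import InceptionV3

# Block 3 = final average pooling (2048-d), the standard feature set for FID.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3(output_blocks=[block_idx]).eval()

images = torch.rand(8, 3, 256, 256)  # placeholder batch of RGB images in [0, 1]
with torch.no_grad():
    pool_features = model(images)[0]  # list entry for block 3, shape (8, 2048, 1, 1)
pool_features = pool_features.squeeze(-1).squeeze(-1)  # (8, 2048), ready for mean/covariance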