├── .gitignore ├── LICENSE ├── README.md ├── configs ├── im2im_multiple_base.yaml ├── im2im_single.yaml ├── td_multiple.yaml ├── td_multiple_base.yaml ├── td_single_car.yaml ├── td_single_cat.yaml ├── td_single_church.yaml ├── td_single_ffhq.yaml └── td_single_horse.yaml ├── core ├── __init__.py ├── dataset.py ├── loss.py ├── lpips │ ├── __init__.py │ ├── base_model.py │ ├── dist_model.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── mappers.py ├── parametrizations.py ├── style_embed_options.py ├── stylegan_patches.py ├── uda_models.py └── utils │ ├── II2S.py │ ├── __init__.py │ ├── arguments.py │ ├── class_registry.py │ ├── common.py │ ├── download_weights.py │ ├── example_utils.py │ ├── fid.py │ ├── image_utils.py │ ├── loggers.py │ ├── loss_utils.py │ ├── math_utils.py │ ├── text_templates.py │ └── train_log.py ├── download.py ├── examples ├── celeb_latents │ ├── Chris.npy │ ├── Gakki.npy │ ├── Green_Lantern.npy │ ├── Morgan.npy │ ├── Obama.npy │ ├── Oprah.npy │ ├── Pichai.npy │ ├── Rock.npy │ ├── Scarlett.npy │ ├── Su.npy │ └── Yui.npy ├── custom_images │ ├── .gitignore │ └── elon_musk.jpeg ├── editing_playground.ipynb ├── evaluation_example.ipynb └── inference_playground.ipynb ├── gan_models ├── BigGAN │ ├── BigGAN.py │ ├── __init__.py │ ├── generator_config.json │ ├── layers.py │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── batchnorm_reimpl.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ └── utils.py ├── ProgGAN │ ├── __init__.py │ └── model.py ├── SNGAN │ ├── __init__.py │ ├── distribution.py │ ├── load.py │ └── sn_gen_resnet.py ├── StyleGAN2 │ ├── convert_weight.py │ ├── model.py │ ├── offsets_model.py │ ├── op │ │ ├── __init__.py │ │ ├── fused_act.py │ │ ├── fused_act_torch_native.py │ │ ├── fused_bias_act.cpp │ │ ├── fused_bias_act_kernel.cu │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.py │ │ ├── upfirdn2d_kernel.cu │ │ └── upfirdn2d_torch_native.py │ └── prop_convert.py ├── __init__.py ├── gan_load.py ├── gan_with_shift.py └── imagenet_classes.json ├── image_domains ├── anastasia.png ├── anime.jpg ├── anime_full.jpg ├── brave.png ├── cached_latents │ └── .gitignore ├── detroit.png ├── digital_painting_jing.png ├── disney_princess.jpg ├── doc_brown.png ├── image_domains.txt ├── jojo.png ├── joker.png ├── mermaid.png ├── moana.png ├── picasso.png ├── pocahontas.png ├── room_girl.png ├── sketch.png ├── speed_paint.png ├── titan_armin.png ├── titan_erwin.png ├── titan_historia.png └── zbrush_girl.png ├── img ├── cover.jpg ├── domain_modulation.png ├── example_im2im_anastasia.jpg ├── example_td_anime.jpg ├── example_td_mapper_20.jpg ├── example_td_mapper_large.jpg └── hdn_diagram.png ├── main.py ├── requirements.txt ├── restyle_encoders ├── __init__.py ├── download.py ├── e4e.py ├── e4e_modules │ ├── __init__.py │ ├── discriminator.py │ └── latent_codes_pool.py ├── encoders │ ├── __init__.py │ ├── fpn_encoders.py │ ├── helpers.py │ ├── map2style.py │ ├── model_irse.py │ ├── restyle_e4e_encoders.py │ └── restyle_psp_encoders.py ├── mtcnn │ ├── __init__.py │ ├── mtcnn.py │ └── mtcnn_pytorch │ │ ├── __init__.py │ │ └── src │ │ ├── __init__.py │ │ ├── align_trans.py │ │ ├── box_utils.py │ │ ├── detector.py │ │ ├── first_stage.py │ │ ├── get_nets.py │ │ ├── matlab_cp2tform.py │ │ ├── visualization_utils.py │ │ └── weights │ │ ├── onet.npy │ │ ├── pnet.npy │ │ └── rnet.npy ├── psp.py └── stylegan2 │ ├── 
__init__.py │ ├── model.py │ └── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── text_domains ├── close_test_domains_2000.txt ├── domain_list_20.txt ├── domain_list_stress.txt ├── domain_list_style.txt ├── domain_list_test.txt ├── far_test_domains_2000.txt ├── mixed_content_train_domains.txt ├── mixed_launch_content_domains.txt ├── mixed_launch_style_domains.txt ├── mixed_style_train_domains.txt ├── mixed_train_domains.txt ├── much_domains.txt ├── train_domains_1000.txt ├── train_domains_1500.txt ├── train_domains_2000.txt └── train_synonyms_domains.txt └── trainers.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/.idea/* 2 | __pycache__/ 3 | .ipynb_checkpoints/ 4 | pretrained/ 5 | test_** 6 | wandb/ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Vadim Titov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/im2im_multiple_base.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: /Users/vadimtitov/CLIPResearch/configs/ 3 | config: multidomain.yaml 4 | project: Test 5 | name: Test 6 | tags: 7 | - mapper 8 | seed: 12 9 | root: . 
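  # Note: assuming main.py loads its config through core/utils/arguments.py (OmegaConf CLI
  # overrides), a run is launched roughly as `python main.py exp.config=im2im_multiple_base.yaml`;
  # exp.config_dir is reset there to the local `configs` directory, so the absolute path above is
  # effectively ignored, and any field in this file can be overridden from the command line.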
10 | notes: empty notes 11 | step_save: 10000 12 | trainer: im2im_multiple_base 13 | dump_metrics: false 14 | generalisation: 15 | mixing_noise: 0.9 16 | training: 17 | iter_num: 800 18 | batch_size: 8 19 | device: cuda:0 20 | generator: stylegan2 21 | patch_key: cin_mult 22 | train_styles: image_domains/image_domains.txt 23 | mixing_noise: 0.9 24 | mapper_config: 25 | backbone_type: shared 26 | mapper_type: residual_channelin 27 | activation: relu 28 | input_dimension: 512 29 | width: 512 30 | head_depth: 2 31 | backbone_depth: 6 32 | no_coarse: true 33 | no_fine: false 34 | no_medium: false 35 | optimization_setup: 36 | visual_encoders: 37 | - ViT-B/32 38 | - ViT-B/16 39 | loss_funcs: 40 | - direction 41 | - clip_ref 42 | - l2_rec 43 | - lpips_rec 44 | - tt_direction 45 | - offsets_l2 46 | loss_coefs: 47 | - 1.0 48 | - 2.0 49 | - 1.0 50 | - 1.0 51 | - 1.0 52 | - 2.0 53 | optimizer: 54 | weight_decay: 0.0 55 | lr: 0.00005 56 | betas: 57 | - 0.9 58 | - 0.999 59 | generator_args: 60 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 61 | logging: 62 | log_every: 10 63 | log_images: 20 64 | latents_to_edit: [] 65 | image_embedding_log: false 66 | truncation: 0.7 67 | num_grid_outputs: 1 68 | checkpointing: 69 | is_on: false 70 | start_from: false 71 | step_backup: 10000000 -------------------------------------------------------------------------------- /configs/im2im_single.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: Test 5 | name: Test 6 | seed: 0 7 | root: . 8 | notes: empty notes 9 | step_save: 10 10 | trainer: im2im_single 11 | tags: 12 | - test 13 | training: 14 | iter_num: 100 15 | batch_size: 4 16 | device: cuda:0 17 | generator: stylegan2 18 | phase: mapping 19 | patch_key: cin_mult 20 | source_class: Real Person 21 | target_class: image_domains/mermaid.png 22 | auto_layer_k: 16 23 | auto_layer_iters: 0 24 | auto_layer_batch: 8 25 | mixing_noise: 0.9 26 | optimization_setup: 27 | visual_encoders: 28 | - ViT-B/32 29 | - ViT-B/16 30 | loss_funcs: 31 | - direction 32 | - clip_within 33 | - clip_ref 34 | - l2_rec 35 | - lpips_rec 36 | loss_coefs: 37 | - 1.0 38 | - 0.5 39 | - 30.0 40 | - 10.0 41 | - 10.0 42 | g_reg_every: 4 43 | optimizer: 44 | weight_decay: 0.0 45 | lr: 0.01 46 | betas: 47 | - 0.9 48 | - 0.999 49 | generator_args: 50 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 51 | logging: 52 | log_every: 10 53 | log_images: 20 54 | latents_to_edit: [] 55 | truncation: 0.5 56 | num_grid_outputs: 1 57 | checkpointing: 58 | is_on: false 59 | start_from: false 60 | step_backup: 500 61 | -------------------------------------------------------------------------------- /configs/td_multiple.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: /Users/vadimtitov/CLIPResearch/configs/ 3 | config: multidomain.yaml 4 | project: TdMultiple 5 | name: Test 6 | tags: 7 | - mapper 8 | seed: 12 9 | root: . 
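  # Note: the `convex_hull` and `resample` blocks below presumably toggle the embedding
  # augmentations from core/utils/math_utils.py (convex combinations of target embeddings and
  # cone resampling); `divergence` likely corresponds to the `cos_lower_bound` argument of the
  # resample_* helpers.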
10 | notes: empty notes 11 | step_save: 1000 12 | trainer: td_multiple_resample_and_convex 13 | training: 14 | iter_num: 10000 15 | batch_size: 8 16 | device: cuda:0 17 | generator: stylegan2 18 | patch_key: cin_mult 19 | train_domain_list: text_domains/domain_list_20.txt 20 | test_domain_list: text_domains/domain_list_20.txt 21 | mixing_noise: 0.9 22 | convex_hull: 23 | do: true 24 | resample: 25 | do: true 26 | divergence: 0.95 27 | mapper_config: 28 | backbone_type: shared 29 | mapper_type: residual_channelin 30 | activation: relu 31 | input_dimension: 512 32 | width: 512 33 | head_depth: 2 34 | backbone_depth: 4 35 | no_coarse: false 36 | no_fine: false 37 | no_medium: false 38 | optimization_setup: 39 | visual_encoders: 40 | - ViT-B/32 41 | - ViT-B/16 42 | loss_funcs: 43 | - direction 44 | - tt_direction 45 | loss_coefs: 46 | - 1.0 47 | - 0.4 48 | optimizer: 49 | weight_decay: 0.0 50 | lr: 0.00005 51 | betas: 52 | - 0.9 53 | - 0.999 54 | generator_args: 55 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 56 | logging: 57 | log_every: 10 58 | log_images: 100 59 | latents_to_edit: [] 60 | image_embedding_log: false 61 | truncation: 0.7 62 | num_grid_outputs: 1 63 | checkpointing: 64 | is_on: false 65 | start_from: false 66 | step_backup: 10000000 -------------------------------------------------------------------------------- /configs/td_multiple_base.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: /Users/vadimtitov/CLIPResearch/configs/ 3 | config: multidomain.yaml 4 | project: Test 5 | name: Test 6 | tags: 7 | - mapper 8 | seed: 12 9 | root: . 10 | notes: empty notes 11 | step_save: 10000 12 | trainer: td_multiple_base 13 | training: 14 | iter_num: 800 15 | batch_size: 12 16 | device: cuda:0 17 | generator: stylegan2 18 | patch_key: cin_mult 19 | train_domain_list: text_domains/domain_list_20.txt 20 | test_domain_list: text_domains/domain_list_20.txt 21 | mixing_noise: 0.9 22 | mapper_config: 23 | backbone_type: shared 24 | mapper_type: residual_channelin 25 | activation: relu 26 | input_dimension: 512 27 | width: 512 28 | head_depth: 2 29 | backbone_depth: 4 30 | no_coarse: false 31 | no_fine: false 32 | no_medium: false 33 | optimization_setup: 34 | visual_encoders: 35 | - ViT-B/32 36 | - ViT-B/16 37 | loss_funcs: 38 | - direction 39 | loss_coefs: 40 | - 1.0 41 | optimizer: 42 | weight_decay: 0.0 43 | lr: 0.00005 44 | betas: 45 | - 0.9 46 | - 0.999 47 | generator_args: 48 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 49 | logging: 50 | log_every: 10 51 | log_images: 20 52 | latents_to_edit: [] 53 | image_embedding_log: false 54 | truncation: 0.7 55 | num_grid_outputs: 1 56 | checkpointing: 57 | is_on: false 58 | start_from: false 59 | step_backup: 10000000 -------------------------------------------------------------------------------- /configs/td_single_car.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: td_single_car.yaml 4 | project: Test 5 | tags: 6 | - after demodulation 7 | name: Test 8 | seed: 0 9 | root: . 
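  # Note: `patch_key: cin_mult` selects the ChannelINMult patch from core/stylegan_patches.py
  # (convolution weights are rescaled as weight * (1 + offsets['in'])), and `phase: mapping`
  # presumably limits training to the layers returned by
  # OffsetsTunningGenerator.get_training_layers('mapping') in core/uda_models.py.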
10 | notes: empty notes 11 | logging: true 12 | step_save: 100 13 | trainer: td_single 14 | training: 15 | iter_num: 400 16 | batch_size: 8 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: mapping 20 | patch_key: cin_mult 21 | source_class: Car 22 | target_class: Golden Car 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | optimization_setup: 28 | visual_encoders: 29 | - ViT-B/32 30 | - ViT-B/16 31 | loss_funcs: 32 | - direction 33 | - indomain 34 | loss_coefs: 35 | - 1.0 36 | - 1.5 37 | g_reg_every: 4 38 | optimizer: 39 | weight_decay: 0.0 40 | lr: 0.02 41 | betas: 42 | - 0.9 43 | - 0.999 44 | generator_args: 45 | checkpoint_path: pretrained/StyleGAN2/stylegan2-car-config-f.pt 46 | img_size: 512 47 | logging: 48 | log_every: 10 49 | log_images: 20 50 | latents_to_edit: [] 51 | truncation: 0.5 52 | num_grid_outputs: 1 53 | checkpointing: 54 | is_on: false 55 | start_from: false 56 | step_backup: 100000 -------------------------------------------------------------------------------- /configs/td_single_cat.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: td_single_cat.yaml 4 | project: Test 5 | tags: 6 | - after demodulation 7 | name: Test 8 | seed: 12 9 | root: . 10 | notes: empty notes 11 | logging: true 12 | step_save: 100 13 | trainer: td_single 14 | training: 15 | iter_num: 400 16 | batch_size: 8 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: mapping 20 | patch_key: cin_mult 21 | source_class: Cat 22 | target_class: Lion 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | optimization_setup: 28 | visual_encoders: 29 | - ViT-B/32 30 | - ViT-B/16 31 | loss_funcs: 32 | - direction 33 | - offsets_l2 34 | loss_coefs: 35 | - 1.0 36 | - 0.5 37 | g_reg_every: 4 38 | optimizer: 39 | weight_decay: 0.0 40 | lr: 0.02 41 | betas: 42 | - 0.9 43 | - 0.999 44 | generator_args: 45 | checkpoint_path: pretrained/StyleGAN2/stylegan2-cat-config-f.pt 46 | img_size: 256 47 | logging: 48 | log_every: 10 49 | log_images: 20 50 | latents_to_edit: [] 51 | truncation: 0.5 52 | num_grid_outputs: 1 53 | checkpointing: 54 | is_on: false 55 | start_from: false 56 | step_backup: 100000 -------------------------------------------------------------------------------- /configs/td_single_church.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: td_single_church.yaml 4 | project: Test 5 | tags: 6 | - after demodulation 7 | name: Test 8 | seed: 0 9 | root: . 
10 | notes: empty notes 11 | logging: true 12 | step_save: 100 13 | trainer: td_single 14 | training: 15 | iter_num: 100 16 | batch_size: 8 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: mapping 20 | patch_key: cin_mult 21 | source_class: Church 22 | target_class: Hut 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | optimization_setup: 28 | visual_encoders: 29 | - ViT-B/32 30 | - ViT-B/16 31 | loss_funcs: 32 | - direction 33 | - offsets_l2 34 | loss_coefs: 35 | - 1.0 36 | - 0.5 37 | g_reg_every: 4 38 | optimizer: 39 | weight_decay: 0.0 40 | lr: 0.02 41 | betas: 42 | - 0.0 43 | - 0.999 44 | generator_args: 45 | checkpoint_path: pretrained/StyleGAN2/stylegan2-church-config-f.pt 46 | img_size: 256 47 | logging: 48 | log_every: 10 49 | log_images: 20 50 | latents_to_edit: [] 51 | truncation: 0.7 52 | num_grid_outputs: 1 53 | checkpointing: 54 | is_on: false 55 | start_from: false 56 | step_backup: 100000 -------------------------------------------------------------------------------- /configs/td_single_ffhq.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: td_single_ffhq.yaml 4 | project: Test 5 | tags: 6 | - stylespace 7 | name: Test 8 | seed: 12 9 | root: . 10 | notes: empty notes 11 | logging: true 12 | step_save: 20 13 | trainer: td_single 14 | training: 15 | iter_num: 300 16 | batch_size: 4 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: mapping 20 | patch_key: cin_mult 21 | source_class: Photo 22 | target_class: 3D Render in the Style of Pixar 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | optimization_setup: 28 | visual_encoders: 29 | - ViT-B/32 30 | - ViT-B/16 31 | loss_funcs: 32 | - direction 33 | loss_coefs: 34 | - 1.0 35 | g_reg_every: 4 36 | optimizer: 37 | weight_decay: 0.0 38 | lr: 0.1 39 | betas: 40 | - 0.9 41 | - 0.999 42 | generator_args: 43 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 44 | logging: 45 | log_every: 10 46 | log_images: 20 47 | latents_to_edit: [] 48 | truncation: 0.7 49 | num_grid_outputs: 1 50 | checkpointing: 51 | is_on: false 52 | start_from: false 53 | step_backup: 100000 -------------------------------------------------------------------------------- /configs/td_single_horse.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: td_single_horse.yaml 4 | project: Test 5 | tags: 6 | - after demodulation 7 | name: Tests 8 | seed: 12 9 | root: . 
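  # Note: unlike the other td_single configs, this one uses `patch_key: original` (no modulation
  # patch is applied, see patch_layers in core/uda_models.py) together with `phase: conv_kernel`,
  # so the StyleGAN2 convolution kernels themselves are presumably optimized directly.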
10 | notes: empty notes 11 | logging: true 12 | step_save: 100 13 | trainer: td_single 14 | training: 15 | iter_num: 150 16 | batch_size: 8 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: conv_kernel 20 | patch_key: original 21 | source_class: Horse 22 | target_class: Pony 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | optimization_setup: 28 | visual_encoders: 29 | - ViT-B/32 30 | - ViT-B/16 31 | loss_funcs: 32 | - direction 33 | - indomain 34 | loss_coefs: 35 | - 1.0 36 | - 0.5 37 | g_reg_every: 4 38 | optimizer: 39 | weight_decay: 0.0 40 | lr: 0.02 41 | betas: 42 | - 0.9 43 | - 0.999 44 | generator_args: 45 | checkpoint_path: pretrained/StyleGAN2/stylegan2-horse-config-f.pt 46 | img_size: 256 47 | logging: 48 | log_every: 10 49 | log_images: 20 50 | latents_to_edit: [] 51 | truncation: 0.5 52 | num_grid_outputs: 1 53 | checkpointing: 54 | is_on: false 55 | start_from: false 56 | step_backup: 100000 -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/core/__init__.py -------------------------------------------------------------------------------- /core/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import dlib 4 | import PIL 5 | 6 | from PIL import Image 7 | from pathlib import Path 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms as transforms 10 | 11 | from core.utils.common import align_face 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | for root, _, fnames in sorted(os.walk(dir)): 27 | for fname in fnames: 28 | if is_image_file(fname): 29 | path = os.path.join(root, fname) 30 | images.append(path) 31 | return images 32 | 33 | 34 | class ImagesDataset(Dataset): 35 | def __init__(self, opts, image_path=None, align_input=False): 36 | if type(image_path) == list: 37 | self.image_paths = image_path 38 | elif os.path.isdir(image_path): 39 | self.image_paths = sorted(make_dataset(image_path)) 40 | elif os.path.isfile(image_path): 41 | self.image_paths = [image_path] 42 | else: 43 | raise ValueError(f"Incorrect 'image_path' argument in ImagesDataset, {image_path}") 44 | 45 | self.image_transform = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 48 | ]) 49 | 50 | self.opts = opts 51 | self.align_input = align_input 52 | 53 | if self.align_input: 54 | weight_path = str(Path(__file__).parent.parent / 'pretrained/shape_predictor_68_face_landmarks.dat') 55 | self.predictor = dlib.shape_predictor(weight_path) 56 | 57 | def __len__(self): 58 | return len(self.image_paths) 59 | 60 | def __getitem__(self, index): 61 | im_path = Path(self.image_paths[index]) 62 | 63 | if self.align_input: 64 | im_H = align_face(str(im_path), self.predictor, output_size=self.opts.size) 65 | else: 66 | im_H = Image.open(str(im_path)).convert('RGB') 67 | im_H = im_H.resize((self.opts.size, self.opts.size)) 68 | 69 | im_L = 
im_H.resize((256, 256), PIL.Image.LANCZOS) 70 | 71 | return { 72 | "image_high_res": im_H, 73 | "image_low_res": im_L, 74 | "image_high_res_torch": self.image_transform(im_H), 75 | "image_low_res_torch": self.image_transform(im_L), 76 | "image_name": im_path.stem 77 | } 78 | -------------------------------------------------------------------------------- /core/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from core.lpips import dist_model 9 | from skimage.metrics import structural_similarity 10 | 11 | 12 | class PerceptualLoss(torch.nn.Module): 13 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): 14 | # VGG using our perceptually-learned weights (LPIPS metric) 15 | super(PerceptualLoss, self).__init__() 16 | print('Setting up Perceptual loss...') 17 | self.use_gpu = use_gpu 18 | self.spatial = spatial 19 | self.gpu_ids = gpu_ids 20 | self.model = dist_model.DistModel() 21 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 22 | print('...[%s] initialized'%self.model.name()) 23 | print('...Done') 24 | 25 | def forward(self, pred, target, normalize=False): 26 | """ 27 | Pred and target are Variables. 28 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 29 | If normalize is False, assumes the images are already between [-1,+1] 30 | 31 | Inputs pred and target are Nx3xHxW 32 | Output pytorch Variable N long 33 | """ 34 | 35 | if normalize: 36 | target = 2 * target - 1 37 | pred = 2 * pred - 1 38 | 39 | return self.model.forward(target, pred) 40 | 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | 47 | def l2(p0, p1, range=255.): 48 | return .5*np.mean((p0 / range - p1 / range)**2) 49 | 50 | 51 | def psnr(p0, p1, peak=255.): 52 | return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) **2)) 53 | 54 | 55 | def dssim(p0, p1, range=255.): 56 | return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2. 57 | 58 | 59 | def rgb2lab(in_img, mean_cent=False): 60 | from skimage import color 61 | img_lab = color.rgb2lab(in_img) 62 | if mean_cent: 63 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 64 | return img_lab 65 | 66 | 67 | def tensor2np(tensor_obj): 68 | # change dimension of a tensor object into a numpy array 69 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 70 | 71 | 72 | def np2tensor(np_obj): 73 | # change dimension of np array into tensor array 74 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 75 | 76 | 77 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 78 | # image tensor to lab tensor 79 | from skimage import color 80 | 81 | img = tensor2im(image_tensor) 82 | img_lab = color.rgb2lab(img) 83 | if mc_only: 84 | img_lab[:, :, 0] = img_lab[:, :, 0]-50 85 | if to_norm and not mc_only: 86 | img_lab[:, :, 0] = img_lab[:, :, 0]-50 87 | img_lab = img_lab/100. 
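    # L was shifted by -50 and all channels divided by 100 above, roughly normalizing the Lab
    # values; tensorlab2tensor below applies the exact inverse (multiply by 100, shift L by +50).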
88 | 89 | return np2tensor(img_lab) 90 | 91 | 92 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 93 | from skimage import color 94 | import warnings 95 | warnings.filterwarnings("ignore") 96 | 97 | lab = tensor2np(lab_tensor)*100. 98 | lab[:, :, 0] = lab[:, :, 0]+50 99 | 100 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')), 0, 1) 101 | if return_inbnd: 102 | # convert back to lab, see if we match 103 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 104 | mask = 1.*np.isclose(lab_back, lab, atol=2.) 105 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 106 | return im2tensor(rgb_back), mask 107 | 108 | return im2tensor(rgb_back) 109 | 110 | 111 | def rgb2lab(input): 112 | from skimage import color 113 | return color.rgb2lab(input / 255.) 114 | 115 | 116 | def tensor2vec(vector_tensor): 117 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 118 | 119 | 120 | def voc_ap(rec, prec, use_07_metric=False): 121 | """ ap = voc_ap(rec, prec, [use_07_metric]) 122 | Compute VOC AP given precision and recall. 123 | If use_07_metric is true, uses the 124 | VOC 07 11 point method (default:False). 125 | """ 126 | if use_07_metric: 127 | # 11 point metric 128 | ap = 0. 129 | for t in np.arange(0., 1.1, 0.1): 130 | if np.sum(rec >= t) == 0: 131 | p = 0 132 | else: 133 | p = np.max(prec[rec >= t]) 134 | ap = ap + p / 11. 135 | else: 136 | # correct AP calculation 137 | # first append sentinel values at the end 138 | mrec = np.concatenate(([0.], rec, [1.])) 139 | mpre = np.concatenate(([0.], prec, [0.])) 140 | 141 | # compute the precision envelope 142 | for i in range(mpre.size - 1, 0, -1): 143 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 144 | 145 | # to calculate area under PR curve, look for points 146 | # where X axis (recall) changes value 147 | i = np.where(mrec[1:] != mrec[:-1])[0] 148 | 149 | # and sum (\Delta recall) * prec 150 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 151 | return ap 152 | 153 | 154 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 155 | image_numpy = image_tensor[0].cpu().float().numpy() 156 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 157 | return image_numpy.astype(imtype) 158 | 159 | 160 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 161 | return torch.Tensor((image / factor - cent) 162 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 163 | -------------------------------------------------------------------------------- /core/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from pdb import set_trace as st 6 | from IPython import embed 7 | 8 | class BaseModel(): 9 | def __init__(self): 10 | pass; 11 | 12 | def name(self): 13 | return 'BaseModel' 14 | 15 | def initialize(self, use_gpu=True, gpu_ids=[0]): 16 | self.use_gpu = use_gpu 17 | self.gpu_ids = gpu_ids 18 | 19 | def forward(self): 20 | pass 21 | 22 | def get_image_paths(self): 23 | pass 24 | 25 | def optimize_parameters(self): 26 | pass 27 | 28 | def get_current_visuals(self): 29 | return self.input 30 | 31 | def get_current_errors(self): 32 | return {} 33 | 34 | def save(self, label): 35 | pass 36 | 37 | # helper saving function that can be used by subclasses 38 | def save_network(self, network, path, network_label, epoch_label): 39 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 40 | save_path = os.path.join(path, save_filename) 41 | 
torch.save(network.state_dict(), save_path) 42 | 43 | # helper loading function that can be used by subclasses 44 | def load_network(self, network, network_label, epoch_label): 45 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 46 | save_path = os.path.join(self.save_dir, save_filename) 47 | print('Loading network from %s'%save_path) 48 | network.load_state_dict(torch.load(save_path)) 49 | 50 | def update_learning_rate(): 51 | pass 52 | 53 | def get_image_paths(self): 54 | return self.image_paths 55 | 56 | def save_done(self, flag=False): 57 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 58 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 59 | -------------------------------------------------------------------------------- /core/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/core/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /core/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/core/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /core/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/core/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /core/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/core/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /core/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/core/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /core/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/core/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /core/style_embed_options.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | ################# II2S options for style image embedding Note: p_norm_lambda = 1e-2 not 1e-3 4 | opts = Namespace() 5 | 6 | # StyleGAN2 setting 7 | opts.size = 1024 8 | opts.ckpt = "pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt" 9 | opts.channel_multiplier = 2 10 | opts.latent = 512 11 | opts.n_mlp = 8 12 | 13 | # loss options 14 | opts.percept_lambda = 1.0 15 | opts.l2_lambda = 1.0 16 | opts.p_norm_lambda = 1e-3 17 | 18 | # arguments 19 | opts.device = 'cuda' 20 | opts.seed = 2 21 | opts.tile_latent = False 22 | opts.opt_name = 'adam' 23 | opts.learning_rate = 0.01 24 | opts.lr_schedule = 'fixed' 25 | opts.steps = 1000 26 | 
opts.save_intermediate = False 27 | opts.save_interval = 300 28 | opts.verbose = False 29 | 30 | II2S_s_opts = opts -------------------------------------------------------------------------------- /core/stylegan_patches.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from core.utils.class_registry import ClassRegistry 6 | 7 | modulation_patches = ClassRegistry() 8 | decomposition_patches = ClassRegistry() 9 | 10 | 11 | class Patch(nn.Module): 12 | @property 13 | def device(self): 14 | return next(self.parameters()).device 15 | 16 | def to(self, device): 17 | super().to(device) 18 | 19 | 20 | class BaseModulationPatch(Patch): 21 | def __init__(self, conv_weight: torch.Tensor): 22 | super().__init__() 23 | self.shape = conv_weight.shape 24 | self.register_buffer('ones', torch.ones(self.shape)) 25 | _, self.c_out, self.c_in, k_x, k_y = self.shape 26 | 27 | def forward(self, weight, offsets): 28 | raise NotImplementedError() 29 | 30 | 31 | @modulation_patches.add_to_registry("csep_delta") 32 | class ChannelSeparateDelta(BaseModulationPatch): 33 | def forward(self, weight, offsets): 34 | return weight + offsets['in'] + offsets['out'] 35 | 36 | 37 | @modulation_patches.add_to_registry("csep_mult") 38 | class ChannelwiseSepMult(BaseModulationPatch): 39 | def forward(self, weight, offsets): 40 | mult = self.ones + offsets['in'] + offsets['out'] 41 | return weight * mult 42 | 43 | 44 | @modulation_patches.add_to_registry("cfull_mult") 45 | class ChannelwiseFullMult(BaseModulationPatch): 46 | def forward(self, weight, offsets): 47 | mult = self.ones + offsets['shift'] 48 | return weight * mult 49 | 50 | 51 | @modulation_patches.add_to_registry("cfull_delta") 52 | class ChannelwiseFullDelta(BaseModulationPatch): 53 | def forward(self, weight, offsets): 54 | return weight + offsets['shift'] 55 | 56 | 57 | @modulation_patches.add_to_registry("aff_cout") 58 | class ChannelwiseFullDelta(BaseModulationPatch): 59 | def forward(self, weight, offsets): 60 | return offsets['gamma'] * weight + offsets['beta'] 61 | 62 | 63 | @modulation_patches.add_to_registry("aff_cout_no_beta") 64 | class ChannelwiseFullDelta(BaseModulationPatch): 65 | def forward(self, weight, offsets): 66 | return offsets['gamma'] * weight 67 | 68 | 69 | @modulation_patches.add_to_registry("coutk_mult") 70 | class ChanneloutKernelMult(BaseModulationPatch): 71 | def forward(self, weight, offsets): 72 | mult = self.ones + offsets['out'] + offsets['kernel'] 73 | return weight * mult 74 | 75 | 76 | @modulation_patches.add_to_registry("cout_mult") 77 | class ChannelOutMult(BaseModulationPatch): 78 | def forward(self, weight, offsets): 79 | mult = self.ones + offsets['out'] 80 | return weight * mult 81 | 82 | 83 | @modulation_patches.add_to_registry("cin_mult") 84 | class ChannelINMult(BaseModulationPatch): 85 | def forward(self, weight, offsets): 86 | mult = self.ones + offsets['in'] 87 | return weight * mult 88 | 89 | 90 | @modulation_patches.add_to_registry("cink_mult") 91 | class ChannelINKernelMult(BaseModulationPatch): 92 | def forward(self, weight, offsets): 93 | mult = self.ones + offsets['in'] + offsets['kernel'] 94 | return weight * mult 95 | 96 | 97 | class BaseDecompositionPatch(nn.Module): 98 | def __init__(self, weight): 99 | super().__init__() 100 | weight_matrix = weight.cpu().detach().numpy().reshape((weight.shape[-4:])) 101 | self.c_out, self.c_in, self.k_x, self.k_y = weight_matrix.shape 102 | weight_matrix = 
np.transpose(weight_matrix, (2, 3, 1, 0)) 103 | weight_matrix = np.reshape(weight_matrix, (self.k_x * self.k_y * self.c_in, self.c_out)) 104 | self._decompose_weight(weight_matrix) 105 | 106 | def _decompose_weight(self, weight_matrix: np.ndarray): 107 | raise NotImplementedError() 108 | 109 | def reconstruct(self, offsets): 110 | raise NotImplementedError() 111 | 112 | def forward(self, offsets): 113 | weight = self.reconstruct(offsets) 114 | weight = weight.view(1, self.k_x, self.k_y, self.c_in, self.c_out) 115 | weight = weight.permute(0, 4, 3, 1, 2).contiguous() 116 | return weight 117 | 118 | 119 | class SvdDecompositionPatch(BaseDecompositionPatch): 120 | def _decompose_weight(self, weight: np.ndarray): 121 | u, s, vh = np.linalg.svd(weight, full_matrices=False) 122 | u = torch.FloatTensor(u) 123 | vh = torch.FloatTensor(vh) 124 | s = torch.FloatTensor(s) 125 | 126 | self.register_buffer('s', s) 127 | self.register_buffer('u', u) 128 | self.register_buffer('vh', vh) 129 | 130 | def reconstruct(self, offsets): 131 | raise NotImplementedError() 132 | 133 | 134 | @decomposition_patches.add_to_registry("svd_s") 135 | class SvdSingularDecomposePatch(SvdDecompositionPatch): 136 | def reconstruct(self, offsets): 137 | if offsets is None: 138 | return self.u @ torch.diag_embed(self.s) @ self.vh 139 | 140 | shifted_s = (self.s + offsets['singular']) 141 | return self.u @ torch.diag_embed(shifted_s) @ self.vh 142 | 143 | 144 | @decomposition_patches.add_to_registry("svd_u_k") 145 | class USVDFirstK(SvdDecompositionPatch): 146 | def reconstruct(self, offsets): 147 | # offsets is like [vector_dim, k] 148 | # return torch.cat([self.u_trainable, self.u_frozen], dim=1) @ torch.diag_embed(self.s) @ self.vh 149 | ... 150 | -------------------------------------------------------------------------------- /core/uda_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from core.utils.common import requires_grad 4 | from core.utils.class_registry import ClassRegistry 5 | 6 | from gan_models.StyleGAN2.offsets_model import ( 7 | OffsetsGenerator, 8 | ModModulatedConv2d, 9 | DecModulatedConv2d, 10 | StyleModulatedConv2d 11 | ) 12 | 13 | from core.stylegan_patches import decomposition_patches, modulation_patches 14 | 15 | uda_models = ClassRegistry() 16 | 17 | # default_arguments = Omegaconf.structured(uda_models.make_dataclass_from_args("GenArgs")) 18 | # default_arguments.GenArgs.stylegan2.size ... 
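# core/utils/arguments.py builds structured generator defaults from this registry:
#   generator_args = uda_models.make_dataclass_from_args()
#   args.add_to_registry("generator_args")(generator_args)
# so every constructor argument of a registered generator (img_size, latent_size, ...) becomes a
# config field that the `generator_args` section of the YAML configs overrides.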
19 | 20 | 21 | @uda_models.add_to_registry("stylegan2") 22 | class OffsetsTunningGenerator(torch.nn.Module): 23 | def __init__(self, img_size=1024, latent_size=512, map_layers=8, 24 | channel_multiplier=2, device='cuda:0', checkpoint_path=None): 25 | super().__init__() 26 | 27 | self.generator = OffsetsGenerator( 28 | img_size, latent_size, map_layers, channel_multiplier=channel_multiplier 29 | ).to(device) 30 | 31 | if checkpoint_path is not None: 32 | checkpoint = torch.load(checkpoint_path, map_location=device) 33 | self.generator.load_state_dict(checkpoint["g_ema"], strict=False) 34 | 35 | self.generator.eval() 36 | 37 | with torch.no_grad(): 38 | self.mean_latent = self.generator.mean_latent(4096) 39 | 40 | def patch_layers(self, patch_key): 41 | """ 42 | Modify ModulatedConv2d Layers with <> patch 43 | """ 44 | if patch_key in decomposition_patches: 45 | self._patch_modconv_key(patch_key, DecModulatedConv2d) 46 | elif patch_key in modulation_patches: 47 | self._patch_modconv_key(patch_key, ModModulatedConv2d) 48 | elif patch_key in style_patches: 49 | self._patch_modconv_key(patch_key, StyleModulatedConv2d) 50 | elif patch_key == 'original': 51 | ... 52 | else: 53 | raise ValueError( 54 | f''' 55 | Incorrect patch_key. Got {patch_key}, possible are { 56 | {decomposition_patches}, {modulation_patches}, {style_patches} 57 | } 58 | ''' 59 | ) 60 | return self 61 | 62 | def _patch_modconv_key(self, patch_key, mod_conv_class): 63 | self.generator.conv1.conv = mod_conv_class( 64 | patch_key, self.generator.conv1.conv 65 | ) 66 | 67 | for conv_layer_ix in range(len(self.generator.convs)): 68 | self.generator.convs[conv_layer_ix].conv = mod_conv_class( 69 | patch_key, self.generator.convs[conv_layer_ix].conv 70 | ) 71 | 72 | def get_all_layers(self): 73 | return list(self.generator.children()) 74 | 75 | def get_training_layers(self, phase): 76 | if phase == 'texture': 77 | # learned constant + first convolution + layers 3-10 78 | return list(self.get_all_layers())[1:3] + list(self.get_all_layers()[4][2:10]) 79 | if phase == 'shape': 80 | # layers 1-2 81 | return list(self.get_all_layers())[1:3] + list(self.get_all_layers()[4][0:2]) 82 | if phase == 'no_fine': 83 | # const + layers 1-10 84 | return list(self.get_all_layers())[1:3] + list(self.get_all_layers()[4][:10]) 85 | if phase == 'shape_expanded': 86 | # const + layers 1-3 87 | return list(self.get_all_layers())[1:3] + list(self.get_all_layers()[4][0:3]) 88 | if phase == 'mapping': 89 | return list(self.get_all_layers())[0] 90 | if phase == 'affine': 91 | styled_convs = list(self.get_all_layers())[4] 92 | return [s.conv.modulation for s in styled_convs] 93 | if phase == 'conv_kernel': 94 | styled_convs = list(self.get_all_layers())[4] 95 | return [s.conv.weight for s in styled_convs] 96 | if phase == 'all': 97 | # everything, including mapping and ToRGB 98 | return self.get_all_layers() 99 | else: 100 | # everything except mapping and ToRGB 101 | return list(self.get_all_layers())[1:3] + list(self.get_all_layers()[4][:]) 102 | 103 | def freeze_layers(self, layer_list=None): 104 | """ 105 | Disable training for all layers in list. 106 | """ 107 | if layer_list is None: 108 | self.freeze_layers(self.get_all_layers()) 109 | else: 110 | for layer in layer_list: 111 | requires_grad(layer, False) 112 | 113 | def unfreeze_layers(self, layer_list=None): 114 | """ 115 | Enable training for all layers in list. 
116 | """ 117 | if layer_list is None: 118 | self.unfreeze_layers(self.get_all_layers()) 119 | else: 120 | for layer in layer_list: 121 | requires_grad(layer, True) 122 | 123 | def style(self, styles): 124 | """ 125 | Convert z codes to w codes. 126 | """ 127 | styles = [self.generator.style(s) for s in styles] 128 | return styles 129 | 130 | def get_s_code(self, styles, input_is_latent=False): 131 | return self.generator.get_s_code(styles, input_is_latent) 132 | 133 | def modulation_layers(self): 134 | return self.generator.modulation_layers 135 | 136 | def forward(self, 137 | styles, 138 | offsets=None, 139 | return_latents=False, 140 | inject_index=None, 141 | truncation=1, 142 | truncation_latent=None, 143 | input_is_latent=False, 144 | noise=None, 145 | randomize_noise=True): 146 | return self.generator(styles, 147 | offsets=offsets, 148 | return_latents=return_latents, 149 | truncation=truncation, 150 | truncation_latent=self.mean_latent, 151 | noise=noise, 152 | randomize_noise=randomize_noise, 153 | input_is_latent=input_is_latent) 154 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/core/utils/__init__.py -------------------------------------------------------------------------------- /core/utils/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from omegaconf import OmegaConf 4 | from core.uda_models import uda_models 5 | from core.utils.class_registry import ClassRegistry 6 | 7 | 8 | args = ClassRegistry() 9 | generator_args = uda_models.make_dataclass_from_args() 10 | args.add_to_registry("generator_args")(generator_args) 11 | additional_arguments = args.make_dataclass_from_classes() 12 | 13 | DEFAULT_CONFIG_DIR = 'configs' 14 | 15 | 16 | def get_generator_args(generator_name, base_args, conf_args): 17 | return OmegaConf.create( 18 | { 19 | generator_name: OmegaConf.merge(base_args, conf_args) 20 | } 21 | ) 22 | 23 | 24 | def load_config(): 25 | base_gen_args_config = OmegaConf.structured(additional_arguments) 26 | 27 | conf_cli = OmegaConf.from_cli() 28 | conf_cli.exp.config_dir = DEFAULT_CONFIG_DIR 29 | 30 | if not conf_cli.get('exp', False): 31 | raise ValueError("No config") 32 | 33 | config_path = os.path.join(conf_cli.exp.config_dir, conf_cli.exp.config) 34 | config_file = OmegaConf.load(config_path) 35 | 36 | generator_args = get_generator_args( 37 | config_file.training.generator, 38 | base_gen_args_config.generator_args[config_file.training.generator], 39 | config_file.generator_args 40 | ) 41 | 42 | gen_args = OmegaConf.create({ 43 | 'generator_args': generator_args 44 | }) 45 | 46 | config = OmegaConf.merge(config_file, conf_cli) 47 | config = OmegaConf.merge( 48 | config, 49 | gen_args 50 | ) 51 | 52 | return config 53 | -------------------------------------------------------------------------------- /core/utils/class_registry.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing 3 | import omegaconf 4 | import dataclasses 5 | import typing as tp 6 | 7 | 8 | class ClassRegistry: 9 | def __init__(self): 10 | self.classes = dict() 11 | self.args = dict() 12 | self.arg_keys = None 13 | 14 | def __getitem__(self, item): 15 | return self.classes[item] 16 | 17 | def make_dataclass_from_func(self, func, 
name, arg_keys): 18 | args = inspect.signature(func).parameters 19 | args = [ 20 | (k, typing.Any, omegaconf.MISSING) 21 | if v.default is inspect.Parameter.empty 22 | else (k, typing.Optional[typing.Any], None) 23 | if v.default is None 24 | else ( 25 | k, 26 | type(v.default), 27 | dataclasses.field(default=v.default), 28 | ) 29 | for k, v in args.items() 30 | ] 31 | args = [ 32 | arg 33 | for arg in args 34 | if (arg[0] != "self" and arg[0] != "args" and arg[0] != "kwargs") 35 | ] 36 | if arg_keys: 37 | self.arg_keys = arg_keys 38 | arg_classes = dict() 39 | for key in arg_keys: 40 | arg_classes[key] = dataclasses.make_dataclass(key, args) 41 | return dataclasses.make_dataclass( 42 | name, 43 | [ 44 | (k, v, dataclasses.field(default=v())) 45 | for k, v in arg_classes.items() 46 | ], 47 | ) 48 | return dataclasses.make_dataclass(name, args) 49 | 50 | def make_dataclass_from_classes(self): 51 | return dataclasses.make_dataclass( 52 | 'Name', 53 | [ 54 | (k, v, dataclasses.field(default=v())) 55 | for k, v in self.classes.items() 56 | ], 57 | ) 58 | 59 | def make_dataclass_from_args(self): 60 | return dataclasses.make_dataclass( 61 | 'Name', 62 | [ 63 | (k, v, dataclasses.field(default=v())) 64 | for k, v in self.args.items() 65 | ], 66 | ) 67 | 68 | def _add_single_obj(self, obj, name, arg_keys): 69 | self.classes[name] = obj 70 | if inspect.isfunction(obj): 71 | self.args[name] = self.make_dataclass_from_func( 72 | obj, name, arg_keys 73 | ) 74 | elif inspect.isclass(obj): 75 | self.args[name] = self.make_dataclass_from_func( 76 | obj.__init__, name, arg_keys 77 | ) 78 | 79 | def add_to_registry(self, names: tp.Union[str, tp.List[str]], arg_keys=None): 80 | if not isinstance(names, list): 81 | names = [names] 82 | 83 | def decorator(obj): 84 | for name in names: 85 | self._add_single_obj(obj, name, arg_keys) 86 | 87 | return obj 88 | return decorator 89 | 90 | def __contains__(self, name: str): 91 | return name in self.args.keys() 92 | 93 | def __repr__(self): 94 | return f"{list(self.args.keys())}" 95 | -------------------------------------------------------------------------------- /core/utils/download_weights.py: -------------------------------------------------------------------------------- 1 | import tarfile 2 | import argparse 3 | import subprocess 4 | from pathlib import Path 5 | 6 | 7 | SOURCES = { 8 | 'mnist': 'https://www.dropbox.com/s/rzurpt5gzb14a1q/pretrained_mnist.tar', 9 | 'anime': 'https://www.dropbox.com/s/9aveavgbluvjeu6/pretrained_anime.tar', 10 | 'biggan': 'https://www.dropbox.com/s/zte4oein08ajsij/pretrained_biggan.tar', 11 | 'proggan': 'https://www.dropbox.com/s/707xjn1rla8nwqc/pretrained_proggan.tar', 12 | 'stylegan2': 'https://www.dropbox.com/s/c3aaq7i6soxmpzu/pretrained_stylegan2_ffhq.tar', 13 | } 14 | 15 | 16 | def download(source: str, destination: Path) -> None: 17 | tmp_tar = str(destination / '.tmp.tar') 18 | # urllib has troubles with dropbox 19 | subprocess.run( 20 | ['curl', '-L', '-k', source, '-o', tmp_tar] 21 | ) 22 | tar_file = tarfile.open(tmp_tar, mode='r') 23 | tar_file.extractall(destination) 24 | subprocess.run( 25 | ['rm', tmp_tar] 26 | ) 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser(description='Pretrained models loader') 31 | parser.add_argument('--models', nargs='+', type=str, 32 | choices=list(SOURCES.keys()) + ['all'], default=['all']) 33 | parser.add_argument('--out', type=str, help='root out dir') 34 | args = parser.parse_args() 35 | 36 | if args.out is None: 37 | args.out = Path(__file__).parent / 'models' 38 | else: 
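        # an explicit output root was given on the command line, e.g.
        #   python core/utils/download_weights.py --models stylegan2 --out pretrained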
39 | args.out = Path(args.out) 40 | 41 | if not Path(args.out).exists(): 42 | Path(args.out).mkdir() 43 | 44 | models = args.models 45 | if 'all' in models: 46 | models = list(SOURCES.keys()) 47 | 48 | for model in set(models): 49 | source = SOURCES[model] 50 | print(f'downloading {model}\nfrom {source}') 51 | download(source, args.out) 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /core/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | from torchvision.transforms import Resize 7 | 8 | 9 | class BicubicDownSample(nn.Module): 10 | def bicubic_kernel(self, x, a=-0.50): 11 | """ 12 | This equation is exactly copied from the website below: 13 | https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic 14 | """ 15 | abs_x = torch.abs(x) 16 | if abs_x <= 1.: 17 | return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1 18 | elif 1. < abs_x < 2.: 19 | return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a 20 | else: 21 | return 0.0 22 | 23 | def __init__(self, factor=4, cuda=True, padding='reflect'): 24 | super().__init__() 25 | self.factor = factor 26 | size = factor * 4 27 | k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) 28 | for i in range(size)], dtype=torch.float32) 29 | k = k / torch.sum(k) 30 | # k = torch.einsum('i,j->ij', (k, k)) 31 | k1 = torch.reshape(k, shape=(1, 1, size, 1)) 32 | self.k1 = torch.cat([k1, k1, k1], dim=0) 33 | k2 = torch.reshape(k, shape=(1, 1, 1, size)) 34 | self.k2 = torch.cat([k2, k2, k2], dim=0) 35 | self.cuda = '.cuda' if cuda else '' 36 | self.padding = padding 37 | for param in self.parameters(): 38 | param.requires_grad = False 39 | 40 | def forward(self, x, nhwc=False, clip_round=False, byte_output=False): 41 | # x = torch.from_numpy(x).type('torch.FloatTensor') 42 | filter_height = self.factor * 4 43 | filter_width = self.factor * 4 44 | stride = self.factor 45 | 46 | pad_along_height = max(filter_height - stride, 0) 47 | pad_along_width = max(filter_width - stride, 0) 48 | filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda)) 49 | filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda)) 50 | 51 | # compute actual padding values for each side 52 | pad_top = pad_along_height // 2 53 | pad_bottom = pad_along_height - pad_top 54 | pad_left = pad_along_width // 2 55 | pad_right = pad_along_width - pad_left 56 | 57 | # apply mirror padding 58 | if nhwc: 59 | x = torch.transpose(torch.transpose( 60 | x, 2, 3), 1, 2) # NHWC to NCHW 61 | 62 | # downscaling performed by 1-d convolution 63 | x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding) 64 | x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3) 65 | if clip_round: 66 | x = torch.clamp(torch.round(x), 0.0, 255.) 67 | 68 | x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) 69 | x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) 70 | if clip_round: 71 | x = torch.clamp(torch.round(x), 0.0, 255.) 
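        # x has now been low-pass filtered and strided along both spatial dims (downsampled by
        # `factor`); the two 1-d convolutions above implement the separable bicubic kernel
        # constructed in __init__.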
72 | 73 | if nhwc: 74 | x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) 75 | if byte_output: 76 | return x.type('torch.ByteTensor'.format(self.cuda)) 77 | else: 78 | return x 79 | 80 | 81 | def t2im(img_t: torch.Tensor, size: int = 512): 82 | """ 83 | process torch image with shape (3, h, w) to numpy image 84 | 85 | Parameters 86 | ---------- 87 | img_t : torch.Tensor 88 | Image batch with shape (16, 3, H, W) 89 | 90 | size : int 91 | Size for which smaller edge will be resized 92 | 93 | Returns 94 | ------- 95 | img : np.ndarray 96 | Image with shape (3, H, W) with smaller edge resized to parameter 'size' 97 | """ 98 | img = Resize(size)(img_t).permute(1, 2, 0).cpu().detach().numpy() 99 | img = np.round((np.clip(img, -1, 1) + 1) / 2 * 255).astype(np.uint8) 100 | return img 101 | 102 | 103 | def resize_img(img: torch.Tensor, size: int) -> torch.Tensor: 104 | return F.interpolate(img.unsqueeze(0), (size, size))[0] 105 | 106 | 107 | @torch.no_grad() 108 | def construct_paper_image_grid(img: torch.Tensor): 109 | """ 110 | process torch batch image to paper image 111 | 112 | Parameters 113 | ---------- 114 | img : torch.Tensor 115 | Image batch with shape (16, 3, H, W) 116 | 117 | Returns 118 | ------- 119 | base_fig : np.ndarray 120 | Image with shape (3, H, W) with smaller edge resized to 512 121 | """ 122 | half_size = img.size()[-1] // 2 123 | quarter_size = half_size // 2 124 | 125 | base_fig = torch.cat([img[0], img[1]], dim=2) 126 | sub_cols = [torch.cat([resize_img(img[i + j], half_size) for j in range(2)], dim=1) for i in range(2, 8, 2)] 127 | base_fig = torch.cat([base_fig, *sub_cols], dim=2) 128 | 129 | sub_cols = [torch.cat([resize_img(img[i + j], quarter_size) for j in range(4)], dim=1) for i in range(8, 16, 4)] 130 | base_fig = torch.cat([base_fig, *sub_cols], dim=2) 131 | 132 | base_fig = Resize(512)(base_fig).permute(1, 2, 0).cpu().detach().numpy() 133 | base_fig = np.round((np.clip(base_fig, -1, 1) + 1) / 2 * 255).astype(np.uint8) 134 | return base_fig 135 | 136 | 137 | def crop_augmentation(image: torch.Tensor, size=1024, alpha=0.8): 138 | max_ = int(size * (1 - alpha)) 139 | len_ = int(size * alpha) 140 | x, y = np.random.randint(max_, size=2) 141 | return image[..., x:x + len_, y:y + len_] 142 | -------------------------------------------------------------------------------- /core/utils/loggers.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import wandb 3 | import collections 4 | import logging 5 | import os 6 | 7 | from glob import glob 8 | from pathlib import Path 9 | from core.utils.common import get_valid_exp_dir_name 10 | from omegaconf import OmegaConf 11 | 12 | 13 | class LoggingManager: 14 | def __init__(self, trainer_config): 15 | self.config = trainer_config 16 | config_for_logger = OmegaConf.to_container(self.config) 17 | config_for_logger["PID"] = os.getpid() 18 | 19 | self.cached_latents_local_path = None # processed in _self._init_local_dir 20 | self._init_local_dir() 21 | config_for_logger['local_dir'] = self.local_dir 22 | 23 | self.exp_logger = WandbLogger( 24 | project=trainer_config.exp.project, 25 | name=trainer_config.exp.name, 26 | dir=trainer_config.exp.root, 27 | tags=tuple(trainer_config.exp.tags) if trainer_config.exp.tags else None, 28 | notes=trainer_config.exp.notes, 29 | config=config_for_logger, 30 | ) 31 | self.run_dir = self.exp_logger.run_dir 32 | self.console_logger = ConsoleLogger(trainer_config.exp.name) 33 | 34 | def log_values(self, iter_num, num_iters, iter_info, **kwargs): 
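        # forwards each logging step to both sinks: pretty-printed through ConsoleLogger.log_iter
        # and as a flat scalar dict (keyed by `itern`) to the wandb run.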
35 | self.console_logger.log_iter( 36 | iter_num, num_iters, iter_info, **kwargs 37 | ) 38 | self.exp_logger.log(dict(itern=iter_num, **iter_info.to_dict())) 39 | 40 | def log_images(self, iter_num, images): 41 | self.exp_logger.log_images(iter_num, images) 42 | 43 | def log_info(self, output_info): 44 | self.console_logger.logger.info(output_info) 45 | 46 | def _init_local_dir(self): 47 | cached_latents_dir = Path('image_domains/cached_latents') 48 | cached_latents_dir.mkdir(exist_ok=True) 49 | self.cached_latents_local_path = cached_latents_dir 50 | 51 | project_root = Path(__file__).resolve().parent.parent.parent 52 | exp_path = get_valid_exp_dir_name(project_root, self.config.exp.name) 53 | print("Experiment dir: ", exp_path) 54 | self.local_dir = str(exp_path) 55 | os.makedirs(self.local_dir) 56 | 57 | with open(os.path.join(self.local_dir, 'config.yaml'), 'w') as f: 58 | OmegaConf.save(config=self.config, f=f.name) 59 | 60 | self.checkpoint_dir = os.path.join(self.local_dir, "checkpoints") 61 | os.mkdir(self.checkpoint_dir) 62 | self.models_dir = os.path.join(self.local_dir, "models") 63 | os.mkdir(self.models_dir) 64 | 65 | 66 | class WandbLogger: 67 | def __init__(self, **kwargs): 68 | wandb.init(**kwargs) 69 | self.run_dir = wandb.run.dir 70 | code = wandb.Artifact("project-source", type="code") 71 | dirs = [ 72 | 'core', 73 | 'utils', 74 | 'gan_models', 75 | ] 76 | 77 | pathes = [] 78 | 79 | for dir_p in dirs: 80 | pathes.extend(glob(f"{dir_p}/*.py")) 81 | 82 | for path in pathes + ['trainers.py', 'main.py', 'main_multi.py']: 83 | if Path(path).exists(): 84 | code.add_file(path, name=path) 85 | wandb.run.log_artifact(code) 86 | 87 | def finish(self): 88 | wandb.finish() 89 | 90 | def log(self, data): 91 | wandb.log(data) 92 | 93 | def log_images(self, iter_num: int, images: dict): 94 | data = {k: wandb.Image(v, caption=f"iter = {iter_num}") for k, v in images.items()} 95 | wandb.log(data) 96 | 97 | 98 | class ConsoleLogger: 99 | def __init__(self, name): 100 | self.logger = logging.getLogger(name) 101 | self.logger.handlers = [] 102 | self.logger.setLevel(logging.INFO) 103 | log_formatter = logging.Formatter( 104 | "%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 105 | ) 106 | console_handler = logging.StreamHandler() 107 | console_handler.setFormatter(log_formatter) 108 | self.logger.addHandler(console_handler) 109 | 110 | self.logger.propagate = False 111 | 112 | @staticmethod 113 | def format_info(info): 114 | if not info: 115 | return str(info) 116 | log_groups = collections.defaultdict(dict) 117 | for k, v in info.to_dict().items(): 118 | prefix, suffix = k.split("/", 1) 119 | log_groups[prefix][suffix] = ( 120 | f"{v:.3f}" if isinstance(v, float) else str(v) 121 | ) 122 | formatted_info = "" 123 | max_group_size = len(max(log_groups, key=len)) + 2 124 | max_k_size = ( 125 | max([len(max(g, key=len)) for g in log_groups.values()]) + 1 126 | ) 127 | max_v_size = ( 128 | max([len(max(g.values(), key=len)) for g in log_groups.values()]) 129 | + 1 130 | ) 131 | for group, group_info in log_groups.items(): 132 | group_str = [ 133 | f"{k:<{max_k_size}}={v:>{max_v_size}}" 134 | for k, v in group_info.items() 135 | ] 136 | max_g_size = len(max(group_str, key=len)) + 2 137 | group_str = "".join([f"{g:>{max_g_size}}" for g in group_str]) 138 | formatted_info += f"\n{group + ':':<{max_group_size}}{group_str}" 139 | return formatted_info 140 | 141 | def log_iter( 142 | self, iter_num, num_iters, iter_info, event="epoch" 143 | ): 144 | output_info = ( 145 | f"{event.upper()} ITER 
{iter_num}/{num_iters}:" 146 | ) 147 | 148 | output_info += self.format_info(iter_info) 149 | self.logger.info(output_info) 150 | -------------------------------------------------------------------------------- /core/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def cosine_loss(x, y): 7 | return 1.0 - F.cosine_similarity(x, y) 8 | 9 | 10 | def mse_loss(x, y): 11 | return F.mse_loss(x, y) 12 | 13 | 14 | def mae_loss(x, y): 15 | return F.l1_loss(x, y) 16 | 17 | 18 | def get_tril_elemets(matrix_torch: torch.Tensor): 19 | flat = torch.tril(matrix_torch, diagonal=-1).flatten() 20 | return flat[torch.nonzero(flat)] 21 | 22 | 23 | def get_tril_elements_mask(linear_size): 24 | mask = np.zeros((linear_size, linear_size), dtype=bool) 25 | mask[np.tril_indices_from(mask)] = True 26 | np.fill_diagonal(mask, False) 27 | return mask 28 | 29 | 30 | def flatten_with_non_diagonal(input_matrix: torch.Tensor): 31 | linear_matrix_size = input_matrix.size(0) 32 | 33 | non_diag = input_matrix.flatten()[1:].view(linear_matrix_size - 1, linear_matrix_size + 1)[:, :-1] 34 | return non_diag.flatten() 35 | -------------------------------------------------------------------------------- /core/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def resample_single_vector(target_vector, cos_lower_bound, n_vectors=1): 5 | """ 6 | Resample one vector 'n_vectors' times with lower bound of cos 'cos_lower_bound' 7 | 8 | Parameters 9 | ---------- 10 | target_vector : torch.Tensor with size() == (1, dim) or (dim,) 11 | center of resampling 12 | cos_lower_bound : float 13 | lower bound of cos of resampled vectors 14 | n_vectors : int 15 | number of resampled vectors 16 | 17 | Returns 18 | ------- 19 | omega : torch.Tensor, size [n_vectors, dim] 20 | resampled vectors whose cos with target_vector is higher than cos_lower_bound 21 | """ 22 | 23 | if target_vector.ndim == 1: 24 | target_vector = target_vector.unsqueeze(0) 25 | 26 | _, dim = target_vector.size() 27 | 28 | u = target_vector / target_vector.norm(dim=-1, keepdim=True) 29 | u = u.repeat(n_vectors, 1) 30 | r = torch.rand_like(u) * 2 - 1 31 | uperp = torch.stack([r[i] - (torch.dot(r[i], u[i]) * u[i]) for i in range(u.size(0))]) 32 | uperp = uperp / uperp.norm(dim=1, keepdim=True) 33 | 34 | cos_theta = torch.rand(n_vectors, device=target_vector.device) * (1 - cos_lower_bound) + cos_lower_bound 35 | cos_theta = cos_theta.unsqueeze(1).repeat(1, target_vector.size(1)) 36 | omega = cos_theta * u + torch.sqrt(1 - cos_theta ** 2) * uperp 37 | 38 | return omega 39 | 40 | 41 | def resample_batch_vectors(target_vector, cos_lower_bound): 42 | """ 43 | Resample each of the 'b' vectors once with lower bound of cos 'cos_lower_bound' 44 | 45 | Parameters 46 | ---------- 47 | target_vector : torch.Tensor with size() == (b, dim) 48 | centers of resampling 49 | cos_lower_bound : float 50 | lower bound of cos of resampled vectors 51 | 52 | Returns 53 | ------- 54 | omega : torch.Tensor, size [b, dim] 55 | resampled vectors whose cos with target_vector is higher than cos_lower_bound 56 | """ 57 | 58 | b, dim = target_vector.size() 59 | u = target_vector / target_vector.norm(dim=-1, keepdim=True) 60 | r = torch.rand_like(u) * 2 - 1 61 | uperp = torch.stack([r[i] - (torch.dot(r[i], u[i]) * u[i]) for i in range(u.size(0))]) 62 | uperp = uperp / uperp.norm(dim=1, keepdim=True) 
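# uperp is r with its projection onto u removed and then renormalized, i.e. a unit vector orthogonal to u;
# the combination cos_theta * u + sqrt(1 - cos_theta**2) * uperp below is therefore a unit vector whose cosine with the original direction equals cos_theta.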
63 | 64 | cos_theta = torch.rand(b, device=target_vector.device) * (1 - cos_lower_bound) + cos_lower_bound 65 | cos_theta = cos_theta.unsqueeze(1).repeat(1, target_vector.size(1)) 66 | omega = cos_theta * u + torch.sqrt(1 - cos_theta ** 2) * uperp 67 | 68 | return omega 69 | 70 | 71 | def resample_batch_templated_embeddings(embeddings, cos_lower_bound): 72 | if embeddings.ndim == 2: 73 | return resample_batch_vectors(embeddings, cos_lower_bound) 74 | 75 | batch, templates, dim = embeddings.shape 76 | embeddings = embeddings.view(-1, dim) 77 | resampled_embeddings = resample_batch_vectors(embeddings, cos_lower_bound) 78 | 79 | resampled_embeddings = resampled_embeddings.view(batch, templates, dim).contiguous() 80 | return resampled_embeddings 81 | 82 | 83 | def convex_hull(target_vectors, alphas): 84 | """ 85 | calculate convex hull with 'alphas' (1 > alpha > 0, \sum alphas = 1) for target vectors 86 | 87 | Parameters 88 | ---------- 89 | target_vectors : torch.Tensor 90 | set of vectors for which convex hull element is calculated. 91 | Size: [b, dim1, dim2] 92 | 93 | alphas : torch.Tensor 94 | appropriate alphas for which element from convex hull will be calculated. 95 | Size: [b, b] 96 | 97 | Returns 98 | ------- 99 | convex_hull_element : torch.Tensor 100 | single element from convex hull 101 | 102 | """ 103 | 104 | convex_hull_element = (target_vectors.unsqueeze(0) * alphas.unsqueeze(2).unsqueeze(3)).sum(dim=1) 105 | convex_hull_element /= convex_hull_element.clone().norm(dim=-1, keepdim=True) 106 | return convex_hull_element 107 | 108 | 109 | def convex_hull_small(target_vectors, alphas): 110 | """ 111 | calculate convex hull with 'alphas' (1 > alpha > 0, \sum alphas = 1) for target vectors 112 | 113 | Parameters 114 | ---------- 115 | target_vectors : torch.Tensor 116 | set of vectors for which convex hull element is calculated. 117 | Size: [b, dim1, dim2] 118 | 119 | alphas : torch.Tensor 120 | appropriate alphas for which element from convex hull will be calculated. 
121 | Size: [b, b] 122 | 123 | Returns 124 | ------- 125 | convex_hull_element : torch.Tensor 126 | single element from convex hull 127 | 128 | """ 129 | 130 | convex_hull_element = (target_vectors.unsqueeze(0) * alphas.unsqueeze(2)).sum(dim=1) 131 | convex_hull_element /= convex_hull_element.clone().norm(dim=-1, keepdim=True) 132 | return convex_hull_element 133 | -------------------------------------------------------------------------------- /core/utils/text_templates.py: -------------------------------------------------------------------------------- 1 | imagenet_templates = [ 2 | 'a bad photo of a {}.', 3 | 'a sculpture of a {}.', 4 | 'a photo of the hard to see {}.', 5 | 'a low resolution photo of the {}.', 6 | 'a rendering of a {}.', 7 | 'graffiti of a {}.', 8 | 'a bad photo of the {}.', 9 | 'a cropped photo of the {}.', 10 | 'a tattoo of a {}.', 11 | 'the embroidered {}.', 12 | 'a photo of a hard to see {}.', 13 | 'a bright photo of a {}.', 14 | 'a photo of a clean {}.', 15 | 'a photo of a dirty {}.', 16 | 'a dark photo of the {}.', 17 | 'a drawing of a {}.', 18 | 'a photo of my {}.', 19 | 'the plastic {}.', 20 | 'a photo of the cool {}.', 21 | 'a close-up photo of a {}.', 22 | 'a black and white photo of the {}.', 23 | 'a painting of the {}.', 24 | 'a painting of a {}.', 25 | 'a pixelated photo of the {}.', 26 | 'a sculpture of the {}.', 27 | 'a bright photo of the {}.', 28 | 'a cropped photo of a {}.', 29 | 'a plastic {}.', 30 | 'a photo of the dirty {}.', 31 | 'a jpeg corrupted photo of a {}.', 32 | 'a blurry photo of the {}.', 33 | 'a photo of the {}.', 34 | 'a good photo of the {}.', 35 | 'a rendering of the {}.', 36 | 'a {} in a video game.', 37 | 'a photo of one {}.', 38 | 'a doodle of a {}.', 39 | 'a close-up photo of the {}.', 40 | 'a photo of a {}.', 41 | 'the origami {}.', 42 | 'the {} in a video game.', 43 | 'a sketch of a {}.', 44 | 'a doodle of the {}.', 45 | 'a origami {}.', 46 | 'a low resolution photo of a {}.', 47 | 'the toy {}.', 48 | 'a rendition of the {}.', 49 | 'a photo of the clean {}.', 50 | 'a photo of a large {}.', 51 | 'a rendition of a {}.', 52 | 'a photo of a nice {}.', 53 | 'a photo of a weird {}.', 54 | 'a blurry photo of a {}.', 55 | 'a cartoon {}.', 56 | 'art of a {}.', 57 | 'a sketch of the {}.', 58 | 'a embroidered {}.', 59 | 'a pixelated photo of a {}.', 60 | 'itap of the {}.', 61 | 'a jpeg corrupted photo of the {}.', 62 | 'a good photo of a {}.', 63 | 'a plushie {}.', 64 | 'a photo of the nice {}.', 65 | 'a photo of the small {}.', 66 | 'a photo of the weird {}.', 67 | 'the cartoon {}.', 68 | 'art of the {}.', 69 | 'a drawing of the {}.', 70 | 'a photo of the large {}.', 71 | 'a black and white photo of a {}.', 72 | 'the plushie {}.', 73 | 'a dark photo of a {}.', 74 | 'itap of a {}.', 75 | 'graffiti of the {}.', 76 | 'a toy {}.', 77 | 'itap of my {}.', 78 | 'a photo of a cool {}.', 79 | 'a photo of a small {}.', 80 | 'a tattoo of the {}.', 81 | ] 82 | 83 | part_templates = [ 84 | 'the paw of a {}.', 85 | 'the nose of a {}.', 86 | 'the eye of the {}.', 87 | 'the ears of a {}.', 88 | 'an eye of a {}.', 89 | 'the tongue of a {}.', 90 | 'the fur of the {}.', 91 | 'colorful {} fur.', 92 | 'a snout of a {}.', 93 | 'the teeth of the {}.', 94 | 'the {}s fangs.', 95 | 'a claw of the {}.', 96 | 'the face of the {}', 97 | 'a neck of a {}', 98 | 'the head of the {}', 99 | ] 100 | 101 | imagenet_templates_small = [ 102 | 'a photo of a {}.', 103 | 'a rendering of a {}.', 104 | 'a cropped photo of the {}.', 105 | 'the photo of a {}.', 106 | 'a photo of a 
clean {}.', 107 | 'a photo of a dirty {}.', 108 | 'a dark photo of the {}.', 109 | 'a photo of my {}.', 110 | 'a photo of the cool {}.', 111 | 'a close-up photo of a {}.', 112 | 'a bright photo of the {}.', 113 | 'a cropped photo of a {}.', 114 | 'a photo of the {}.', 115 | 'a good photo of the {}.', 116 | 'a photo of one {}.', 117 | 'a close-up photo of the {}.', 118 | 'a rendition of the {}.', 119 | 'a photo of the clean {}.', 120 | 'a rendition of a {}.', 121 | 'a photo of a nice {}.', 122 | 'a good photo of a {}.', 123 | 'a photo of the nice {}.', 124 | 'a photo of the small {}.', 125 | 'a photo of the weird {}.', 126 | 'a photo of the large {}.', 127 | 'a photo of a cool {}.', 128 | 'a photo of a small {}.', 129 | ] -------------------------------------------------------------------------------- /core/utils/train_log.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import collections 3 | import torch 4 | import time 5 | import datetime 6 | 7 | 8 | def strf_time_delta(td): 9 | td_str = "" 10 | if td.days > 0: 11 | td_str += f"{td.days} days, " if td.days > 1 else f"{td.days} day, " 12 | hours = td.seconds // 3600 13 | if hours > 0: 14 | td_str += f"{hours}h " 15 | minutes = (td.seconds // 60) % 60 16 | if minutes > 0: 17 | td_str += f"{minutes}m " 18 | seconds = td.seconds % 60 + td.microseconds * 1e-6 19 | td_str += f"{seconds:.1f}s" 20 | return td_str 21 | 22 | 23 | class Timer: 24 | def __init__(self, info=None, log_event=None): 25 | self.info = info 26 | self.log_event = log_event 27 | 28 | def __enter__(self): 29 | self.start = torch.cuda.Event(enable_timing=True) 30 | self.end = torch.cuda.Event(enable_timing=True) 31 | self.start.record() 32 | return self 33 | 34 | def __exit__(self, exc_type, exc_val, exc_tb): 35 | self.end.record() 36 | torch.cuda.synchronize() 37 | self.duration = self.start.elapsed_time(self.end) / 1000 38 | if self.info: 39 | self.info[f"duration/{self.log_event}"] = self.duration 40 | 41 | 42 | class TimeLog: 43 | def __init__(self, logger, total_num, event): 44 | self.logger = logger 45 | self.total_num = total_num 46 | self.event = event.upper() 47 | self.start = time.time() 48 | 49 | def now(self, current_num): 50 | elapsed = time.time() - self.start 51 | left = self.total_num * elapsed / (current_num + 1) - elapsed 52 | elapsed = strf_time_delta(datetime.timedelta(seconds=elapsed)) 53 | left = strf_time_delta(datetime.timedelta(seconds=left)) 54 | self.logger.log_info( 55 | f"TIME ELAPSED SINCE {self.event} START: {elapsed}" 56 | ) 57 | self.logger.log_info(f"TIME LEFT UNTIL {self.event} END: {left}") 58 | 59 | def end(self): 60 | elapsed = time.time() - self.start 61 | elapsed = strf_time_delta(datetime.timedelta(seconds=elapsed)) 62 | self.logger.log_info( 63 | f"TIME ELAPSED SINCE {self.event} START: {elapsed}" 64 | ) 65 | self.logger.log_info(f"{self.event} ENDS") 66 | 67 | 68 | class MeanTracker(object): 69 | def __init__(self, name): 70 | self.values = [] 71 | self.name = name 72 | 73 | def add(self, val): 74 | self.values.append(float(val)) 75 | 76 | def mean(self): 77 | return np.mean(self.values) 78 | 79 | def flush(self): 80 | mean = self.mean() 81 | self.values = [] 82 | return self.name, mean 83 | 84 | 85 | class _StreamingMean: 86 | def __init__(self, val=None, counts=None): 87 | if val is None: 88 | self.mean = 0.0 89 | self.counts = 0 90 | else: 91 | if isinstance(val, torch.Tensor): 92 | val = val.data.cpu().numpy() 93 | self.mean = val 94 | if counts is not None: 95 | 
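# an explicit `counts` seeds the running mean with that weight; otherwise the initial value is treated as a single observation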
self.counts = counts 96 | else: 97 | self.counts = 1 98 | 99 | def update(self, mean, counts=1): 100 | if isinstance(mean, torch.Tensor): 101 | mean = mean.data.cpu().numpy() 102 | elif isinstance(mean, _StreamingMean): 103 | mean, counts = mean.mean, mean.counts * counts 104 | assert counts >= 0 105 | if counts == 0: 106 | return 107 | total = self.counts + counts 108 | self.mean = self.counts / total * self.mean + counts / total * mean 109 | self.counts = total 110 | 111 | def __add__(self, other): 112 | new = self.__class__(self.mean, self.counts) 113 | if isinstance(other, _StreamingMean): 114 | if other.counts == 0: 115 | return new 116 | else: 117 | new.update(other.mean, other.counts) 118 | else: 119 | new.update(other) 120 | return new 121 | 122 | 123 | class StreamingMeans(collections.defaultdict): 124 | def __init__(self): 125 | super().__init__(_StreamingMean) 126 | 127 | def __setitem__(self, key, value): 128 | if isinstance(value, _StreamingMean): 129 | super().__setitem__(key, value) 130 | else: 131 | super().__setitem__(key, _StreamingMean(value)) 132 | 133 | def update(self, *args, **kwargs): 134 | for_update = dict(*args, **kwargs) 135 | for k, v in for_update.items(): 136 | self[k].update(v) 137 | 138 | def to_dict(self, prefix=""): 139 | return dict((prefix + k, v.mean) for k, v in self.items()) 140 | 141 | def to_str(self): 142 | return ", ".join([f"{k} = {v:.3f}" for k, v in self.to_dict().items()]) -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | import click 2 | import subprocess 3 | import os  # used by untar() below 4 | from pathlib import Path 5 | 6 | 7 | def download_curl(source: str, destination: str) -> None: 8 | subprocess.run( 9 | ['curl', '-L', '-k', source, '-o', destination], 10 | stdout=subprocess.DEVNULL 11 | ) 12 | 13 | 14 | def untar(path: str, destination: str = None): 15 | command = ['tar', '-xzvf', path] 16 | 17 | if destination is not None: 18 | command += ['-C', destination] 19 | if not os.path.exists(destination): 20 | os.makedirs(destination) 21 | subprocess.run(command, stdout=subprocess.DEVNULL) 22 | 23 | 24 | def download_gdrive(file_id: str, destination: str) -> None: 25 | subprocess.run(['gdown', '--id', file_id, '-O', destination]) 26 | 27 | 28 | def unzip(path: str, res_path: str = None): 29 | command = ['unzip', path] 30 | 31 | if res_path is not None: 32 | command += ['-d', res_path] 33 | subprocess.run(command, stdout=subprocess.DEVNULL) 34 | 35 | 36 | def bzip2(path: str): 37 | subprocess.run(['bzip2', '-d', path]) 38 | 39 | 40 | def rm_file(path: str): 41 | subprocess.run(['rm', path]) 42 | 43 | 44 | class Setup: 45 | def __init__(self): 46 | self.pretrained_root = Path(__file__).parent / 'pretrained' 47 | self.pretrained_root.mkdir(exist_ok=True) 48 | 49 | def _download(self, data): 50 | 51 | file_dest = str(self.pretrained_root / data['name']) 52 | 53 | if 'link' in data: 54 | download_curl(data['link'], file_dest) 55 | elif 'id' in data: 56 | download_gdrive(data['id'], file_dest) 57 | 58 | if file_dest.endswith('bz2'): 59 | bzip2(file_dest) 60 | rm_file(file_dest) 61 | elif file_dest.endswith('tar.gz'): 62 | untar(file_dest, str(self.pretrained_root / data['uncompressed_dir'])) 63 | rm_file(file_dest) 64 | elif file_dest.endswith('.zip'): 65 | unzip(file_dest, str(self.pretrained_root / data['uncompressed_dir'])) 66 | rm_file(file_dest) 67 | 68 | def setup(self, values): 69 | for value in values: 70 | self._download(SOURCES[value]) 71 
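# Usage sketch (illustrative, not part of the original file): running `python download.py` with no
# arguments fetches every entry of SOURCES below into ./pretrained/, while e.g. `python download.py sg2 e4e`
# restricts the download to the listed keys.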
| 72 | 73 | SOURCES = { 74 | 'sg2': { 75 | 'link': 'https://nxt.2a2i.org/index.php/s/2K3jbFD3Tg7QmHA/download/StyleGAN2.zip', 76 | 'name': 'StyleGAN2.zip', 77 | 'uncompressed_dir': '' 78 | }, 79 | 'dlib': { 80 | 'link': 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', 81 | 'name': 'shape_predictor_68_face_landmarks.dat.bz2' 82 | }, 83 | 'restyle_psp': { 84 | 'id': '1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd', 85 | 'name': 'restyle_psp_ffhq_encode.pt' 86 | }, 87 | 'e4e': { 88 | 'id': '1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7', 89 | 'name': 'e4e_ffhq_encode.pt' 90 | }, 91 | 'checkpoints': { 92 | 'link': 'https://www.dropbox.com/s/r8816i09t9n94hy/checkpoints.zip?dl=0', 93 | 'name': 'checkpoints.zip', 94 | 'uncompressed_dir': '' 95 | } 96 | } 97 | 98 | 99 | @click.command() 100 | @click.argument('value', default=None, nargs=-1) 101 | def main(value): 102 | setuper = Setup() 103 | values = value if value else SOURCES.keys() 104 | setuper.setup(values) 105 | 106 | 107 | if __name__ == '__main__': 108 | main() 109 | -------------------------------------------------------------------------------- /examples/celeb_latents/Chris.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Chris.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Gakki.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Gakki.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Green_Lantern.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Green_Lantern.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Morgan.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Morgan.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Obama.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Obama.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Oprah.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Oprah.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Pichai.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Pichai.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Rock.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Rock.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Scarlett.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Scarlett.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Su.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Su.npy -------------------------------------------------------------------------------- /examples/celeb_latents/Yui.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/celeb_latents/Yui.npy -------------------------------------------------------------------------------- /examples/custom_images/.gitignore: -------------------------------------------------------------------------------- 1 | *.jpg 2 | *.png 3 | *.JPG 4 | -------------------------------------------------------------------------------- /examples/custom_images/elon_musk.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/examples/custom_images/elon_musk.jpeg -------------------------------------------------------------------------------- /gan_models/BigGAN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/gan_models/BigGAN/__init__.py -------------------------------------------------------------------------------- /gan_models/BigGAN/generator_config.json: -------------------------------------------------------------------------------- 1 | {"num_epochs": 100, 2 | "BN_eps": 1e-05, 3 | "G_batch_size": 512, 4 | "sv_log_interval": 10, 5 | "shuffle": true, 6 | "batch_size": 256, 7 | "G_mixed_precision": false, 8 | "toggle_grads": true, 9 | "mybn": false, 10 | "augment": false, 11 | "D_B2": 0.999, 12 | "D_attn": "64", 13 | "log_G_spectra": false, 14 | "G_shared": true, 15 | "num_D_steps": 1, 16 | "num_best_copies": 5, 17 | "load_in_mem": false, 18 | "split_D": false, 19 | "sample_npz": true, 20 | "D_B1": 0.0, 21 | "cross_replica": false, 22 | "SN_eps": 1e-06, 23 | "G_lr": 0.0001, 24 | "num_G_SV_itrs": 1, 25 | "pin_memory": true, 26 | "D_mixed_precision": false, 27 | "num_G_SVs": 1, 28 | "G_fp16": false, 29 | "sample_interps": true, 30 | "test_every": 2000, 31 | "sample_random": true, 32 | "num_D_SV_itrs": 1, 33 | "config_from_name": false, 34 | "G_eval_mode": true, 35 | "D_nl": "inplace_relu", 36 | "G_param": "SN", 37 | "num_inception_images": 50000, 38 | "save_every": 1000, 39 | "D_lr": 0.0004, 40 | "sample_inception_metrics": true, 41 | "G_attn": "64", 42 | "G_depth": 1, 43 | "which_train_fn": "GAN", 44 | "norm_style": "bn", 45 | "sample_num_npz": 50000, 46 | "hashname": false, 47 | "sample_sheet_folder_num": -1, 48 | "resume": false, 49 | "D_ortho": 0.0, 50 | "ema_start": 20000, 51 | "num_workers": 8, 52 | "dataset": "I128_hdf5", 53 | "ema": true, 54 | 
"num_D_accumulations": 8, 55 | "no_fid": false, 56 | "D_fp16": false, 57 | "G_init": "ortho", 58 | "D_init": "ortho", 59 | "D_ch": 96, 60 | "dim_z": 120, 61 | "D_wide": true, 62 | "accumulate_stats": false, 63 | "num_D_SVs": 1, 64 | "G_B1": 0.0, 65 | "use_ema": true, 66 | "pbar": "mine", 67 | "sample_trunc_curves": "0.05_0.05_1.0", 68 | "use_multiepoch_sampler": true, 69 | "num_G_accumulations": 8, 70 | "G_ch": 96, 71 | "G_B2": 0.999, 72 | "D_depth": 1, 73 | "D_param": "SN", 74 | "G_ortho": 0.0, 75 | "seed": 0, 76 | "log_D_spectra": false, 77 | "num_save_copies": 2, 78 | "hier": true, 79 | "G_nl": "inplace_relu", 80 | "skip_init": true, 81 | "sample_sheets": true, 82 | "z_var": 1.0, 83 | "adam_eps": 1e-06, 84 | "experiment_name": "", 85 | "ema_decay": 0.9999, 86 | "model": "BigGAN", 87 | "shared_dim": 128, 88 | "which_best": "IS", 89 | "parallel": true, 90 | "num_standing_accumulations": 16 91 | } -------------------------------------------------------------------------------- /gan_models/BigGAN/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import ( 12 | SynchronizedBatchNorm1d, 13 | SynchronizedBatchNorm2d, 14 | SynchronizedBatchNorm3d, 15 | ) 16 | from .replicate import DataParallelWithCallback, patch_replication_callback 17 | -------------------------------------------------------------------------------- /gan_models/BigGAN/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ["BatchNormReimpl"] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | 28 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 29 | super().__init__() 30 | 31 | self.num_features = num_features 32 | self.eps = eps 33 | self.momentum = momentum 34 | self.weight = nn.Parameter(torch.empty(num_features)) 35 | self.bias = nn.Parameter(torch.empty(num_features)) 36 | self.register_buffer("running_mean", torch.zeros(num_features)) 37 | self.register_buffer("running_var", torch.ones(num_features)) 38 | self.reset_parameters() 39 | 40 | def reset_running_stats(self): 41 | self.running_mean.zero_() 42 | self.running_var.fill_(1) 43 | 44 | def reset_parameters(self): 45 | self.reset_running_stats() 46 | init.uniform_(self.weight) 47 | init.zeros_(self.bias) 48 | 49 | def forward(self, input_): 50 | batchsize, channels, height, width = input_.size() 51 | numel = batchsize * height * width 52 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 53 | sum_ = input_.sum(1) 54 | sum_of_square = input_.pow(2).sum(1) 55 | mean = sum_ / numel 56 | sumvar = sum_of_square - sum_ * mean 57 | 58 | self.running_mean = ( 59 | 1 - self.momentum 60 | ) * self.running_mean + self.momentum * mean.detach() 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | 1 - self.momentum 64 | ) * self.running_var + self.momentum * unbias_var.detach() 65 | 66 | bias_var = sumvar / numel 67 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 68 | output = (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze( 69 | 1 70 | ) * self.weight.unsqueeze(1) + self.bias.unsqueeze(1) 71 | 72 | return ( 73 | output.view(channels, batchsize, height, width) 74 | .permute(1, 0, 2, 3) 75 | .contiguous() 76 | ) 77 | -------------------------------------------------------------------------------- /gan_models/BigGAN/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ["FutureResult", "SlavePipe", "SyncMaster"] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, "Previous result has't been fetched." 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple("MasterRegistry", ["result"]) 43 | _SlavePipeBase = collections.namedtuple( 44 | "_SlavePipeBase", ["identifier", "queue", "result"] 45 | ) 46 | 47 | 48 | class SlavePipe(_SlavePipeBase): 49 | """Pipe for master-slave communication.""" 50 | 51 | def run_slave(self, msg): 52 | self.queue.put((self.identifier, msg)) 53 | ret = self.result.get() 54 | self.queue.put(True) 55 | return ret 56 | 57 | 58 | class SyncMaster(object): 59 | """An abstract `SyncMaster` object. 
60 | 61 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 62 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 63 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 64 | and passed to a registered callback. 65 | - After receiving the messages, the master device should gather the information and determine to message passed 66 | back to each slave devices. 67 | """ 68 | 69 | def __init__(self, master_callback): 70 | """ 71 | 72 | Args: 73 | master_callback: a callback to be invoked after having collected messages from slave devices. 74 | """ 75 | self._master_callback = master_callback 76 | self._queue = queue.Queue() 77 | self._registry = collections.OrderedDict() 78 | self._activated = False 79 | 80 | def __getstate__(self): 81 | return {"master_callback": self._master_callback} 82 | 83 | def __setstate__(self, state): 84 | self.__init__(state["master_callback"]) 85 | 86 | def register_slave(self, identifier): 87 | """ 88 | Register an slave device. 89 | 90 | Args: 91 | identifier: an identifier, usually is the device id. 92 | 93 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 94 | 95 | """ 96 | if self._activated: 97 | assert self._queue.empty(), "Queue is not clean before next initialization." 98 | self._activated = False 99 | self._registry.clear() 100 | future = FutureResult() 101 | self._registry[identifier] = _MasterRegistry(future) 102 | return SlavePipe(identifier, self._queue, future) 103 | 104 | def run_master(self, master_msg): 105 | """ 106 | Main entry for the master device in each forward pass. 107 | The messages were first collected from each devices (including the master device), and then 108 | an callback will be invoked to compute the message to be sent back to each devices 109 | (including the master device). 110 | 111 | Args: 112 | master_msg: the message that the master want to send to itself. This will be placed as the first 113 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 114 | 115 | Returns: the message to be sent back to the master device. 116 | 117 | """ 118 | self._activated = True 119 | 120 | intermediates = [(0, master_msg)] 121 | for i in range(self.nr_slaves): 122 | intermediates.append(self._queue.get()) 123 | 124 | results = self._master_callback(intermediates) 125 | assert results[0][0] == 0, "The first result should belongs to the master." 126 | 127 | for i, res in results: 128 | if i == 0: 129 | continue 130 | self._registry[i].result.put(res) 131 | 132 | for i in range(self.nr_slaves): 133 | assert self._queue.get() is True 134 | 135 | return results[0][1] 136 | 137 | @property 138 | def nr_slaves(self): 139 | return len(self._registry) 140 | -------------------------------------------------------------------------------- /gan_models/BigGAN/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | "CallbackContext", 17 | "execute_replication_callbacks", 18 | "DataParallelWithCallback", 19 | "patch_replication_callback", 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, "__data_parallel_replicate__"): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /gan_models/BigGAN/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = "NaN" 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ("Tensor close check failed\n" "adiff={}\n" "rdiff={}\n").format( 24 | adiff, rdiff 25 | ) 26 | self.assertTrue(torch.allclose(x, y), message) 27 | -------------------------------------------------------------------------------- /gan_models/BigGAN/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function 5 | import torch.nn as nn 6 | 7 | # Convenience dicts 8 | imsize_dict = { 9 | "I32": 32, 10 | "I32_hdf5": 32, 11 | "I64": 64, 12 | "I64_hdf5": 64, 13 | "I128": 128, 14 | "I128_hdf5": 128, 15 | "I256": 256, 16 | "I256_hdf5": 256, 17 | "C10": 32, 18 | "C100": 32, 19 | } 20 | nclass_dict = { 21 | "I32": 1000, 22 | "I32_hdf5": 1000, 23 | "I64": 1000, 24 | "I64_hdf5": 1000, 25 | "I128": 1000, 26 | "I128_hdf5": 1000, 27 | "I256": 1000, 28 | "I256_hdf5": 1000, 29 | "C10": 10, 30 | "C100": 100, 31 | } 32 | activation_dict = { 33 | "inplace_relu": nn.ReLU(inplace=True), 34 | "relu": nn.ReLU(inplace=False), 35 | "ir": nn.ReLU(inplace=True), 36 | } 37 | -------------------------------------------------------------------------------- /gan_models/ProgGAN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/gan_models/ProgGAN/__init__.py -------------------------------------------------------------------------------- /gan_models/ProgGAN/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | This work is based on the Theano/Lasagne implementation of 5 | Progressive Growing of GANs paper from tkarras: 6 | https://github.com/tkarras/progressive_growing_of_gans 7 | 8 | PyTorch Model definition 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from collections import OrderedDict 16 | 17 | 18 | class PixelNormLayer(nn.Module): 19 | def __init__(self): 20 | super(PixelNormLayer, self).__init__() 21 | 22 | def forward(self, x): 23 | return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) 24 | 25 | 26 | class WScaleLayer(nn.Module): 27 | def __init__(self, size): 28 | super(WScaleLayer, self).__init__() 29 | self.scale = nn.Parameter(torch.randn([1])) 30 | self.b = nn.Parameter(torch.randn(size)) 31 | self.size = size 32 | 33 | def forward(self, x): 34 | x_size = x.size() 35 | x = x * self.scale + self.b.view(1, -1, 1, 1).expand( 36 | x_size[0], self.size, x_size[2], x_size[3]) 37 | 38 | return x 39 | 40 | 41 | class NormConvBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels, kernel_size, padding): 43 | super(NormConvBlock, self).__init__() 44 | self.norm = PixelNormLayer() 45 | self.conv = nn.Conv2d( 46 | in_channels, out_channels, kernel_size, 1, padding, bias=False) 47 | self.wscale = WScaleLayer(out_channels) 48 | 49 | def forward(self, x): 50 | x = self.norm(x) 51 | x = self.conv(x) 52 | x = F.leaky_relu(self.wscale(x), negative_slope=0.2) 53 | return x 54 | 55 | 56 
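# Same structure as NormConvBlock above, but with a 2x nearest-neighbour upsample inserted between the pixel norm and the convolution.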
| class NormUpscaleConvBlock(nn.Module): 57 | def __init__(self, in_channels, out_channels, kernel_size, padding): 58 | super(NormUpscaleConvBlock, self).__init__() 59 | self.norm = PixelNormLayer() 60 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 61 | self.conv = nn.Conv2d( 62 | in_channels, out_channels, kernel_size, 1, padding, bias=False) 63 | self.wscale = WScaleLayer(out_channels) 64 | 65 | def forward(self, x): 66 | x = self.norm(x) 67 | x = self.up(x) 68 | x = self.conv(x) 69 | x = F.leaky_relu(self.wscale(x), negative_slope=0.2) 70 | return x 71 | 72 | 73 | class Generator(nn.Module): 74 | def __init__(self): 75 | super(Generator, self).__init__() 76 | 77 | self.features = nn.Sequential( 78 | NormConvBlock(512, 512, kernel_size=4, padding=3), 79 | NormConvBlock(512, 512, kernel_size=3, padding=1), 80 | NormUpscaleConvBlock(512, 512, kernel_size=3, padding=1), 81 | NormConvBlock(512, 512, kernel_size=3, padding=1), 82 | NormUpscaleConvBlock(512, 512, kernel_size=3, padding=1), 83 | NormConvBlock(512, 512, kernel_size=3, padding=1), 84 | NormUpscaleConvBlock(512, 512, kernel_size=3, padding=1), 85 | NormConvBlock(512, 512, kernel_size=3, padding=1), 86 | NormUpscaleConvBlock(512, 256, kernel_size=3, padding=1), 87 | NormConvBlock(256, 256, kernel_size=3, padding=1), 88 | NormUpscaleConvBlock(256, 128, kernel_size=3, padding=1), 89 | NormConvBlock(128, 128, kernel_size=3, padding=1), 90 | NormUpscaleConvBlock(128, 64, kernel_size=3, padding=1), 91 | NormConvBlock(64, 64, kernel_size=3, padding=1), 92 | NormUpscaleConvBlock(64, 32, kernel_size=3, padding=1), 93 | NormConvBlock(32, 32, kernel_size=3, padding=1), 94 | NormUpscaleConvBlock(32, 16, kernel_size=3, padding=1), 95 | NormConvBlock(16, 16, kernel_size=3, padding=1)) 96 | 97 | self.output = nn.Sequential(OrderedDict([ 98 | ('norm', PixelNormLayer()), 99 | ('conv', nn.Conv2d(16, 100 | 3, 101 | kernel_size=1, 102 | padding=0, 103 | bias=False)), 104 | ('wscale', WScaleLayer(3)) 105 | ])) 106 | 107 | def forward(self, x): 108 | x = self.features(x) 109 | x = self.output(x) 110 | return x 111 | -------------------------------------------------------------------------------- /gan_models/SNGAN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/gan_models/SNGAN/__init__.py -------------------------------------------------------------------------------- /gan_models/SNGAN/distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class BaseDistribution(nn.Module): 6 | def __init__(self, dim, device='cuda'): 7 | super(BaseDistribution, self).__init__() 8 | self.device = device 9 | self.dim = dim 10 | 11 | def cuda(self, device=None): 12 | super(BaseDistribution, self).cuda(device) 13 | self.device = 'cuda' if device is None else device 14 | 15 | def cpu(self): 16 | super(BaseDistribution, self).cpu() 17 | self.device='cpu' 18 | 19 | def to(self, device): 20 | super(BaseDistribution, self).to(device) 21 | self.device = device 22 | 23 | def forward(self, batch_size): 24 | raise NotImplementedError 25 | 26 | 27 | class NormalDistribution(BaseDistribution): 28 | def __init__(self, dim): 29 | super(NormalDistribution, self).__init__(dim) 30 | 31 | def forward(self, batch_size): 32 | return torch.randn([batch_size, self.dim]).to(self.device) 33 | 
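# Minimal usage sketch (assumption for illustration, not from the repo):
#   dist = NormalDistribution(dim=128)
#   z = dist(batch_size=4)   # -> tensor of shape (4, 128) on dist.device (defaults to 'cuda')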
-------------------------------------------------------------------------------- /gan_models/SNGAN/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from gan_models.SNGAN.sn_gen_resnet import SN_RES_GEN_CONFIGS, make_resnet_generator 5 | from gan_models.SNGAN.distribution import NormalDistribution 6 | 7 | 8 | MODELS = { 9 | 'sn_resnet32': 32, 10 | 'sn_resnet64': 64, 11 | } 12 | 13 | 14 | DISTRIBUTIONS = { 15 | 'normal': NormalDistribution, 16 | } 17 | 18 | 19 | class Args: 20 | def __init__(self, **kwargs): 21 | self.nonfixed_noise = False 22 | self.noises_count = 1 23 | self.equal_split = False 24 | self.generator_batch_norm = False 25 | self.gen_sn = False 26 | self.distribution_params = "{}" 27 | 28 | self.__dict__.update(kwargs) 29 | 30 | 31 | def load_model_from_state_dict(root_dir): 32 | args = Args(**json.load(open(os.path.join(root_dir, 'args.json')))) 33 | generator_model_path = os.path.join(root_dir, 'generator.pt') 34 | 35 | try: 36 | image_channels = args.image_channels 37 | except Exception: 38 | image_channels = 3 39 | 40 | gen_config = SN_RES_GEN_CONFIGS[args.model] 41 | generator= make_resnet_generator(gen_config, channels=image_channels, 42 | distribution=NormalDistribution(args.latent_dim), 43 | img_size=MODELS[args.model]) 44 | 45 | generator.load_state_dict( 46 | torch.load(generator_model_path, map_location=torch.device('cpu')), strict=False) 47 | return generator 48 | -------------------------------------------------------------------------------- /gan_models/SNGAN/sn_gen_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | from gan_models.SNGAN.distribution import NormalDistribution 6 | 7 | 8 | ResNetGenConfig = namedtuple('ResNetGenConfig', ['channels', 'seed_dim']) 9 | SN_RES_GEN_CONFIGS = { 10 | 'sn_resnet32': ResNetGenConfig([256, 256, 256, 256], 4), 11 | 'sn_resnet64': ResNetGenConfig([16 * 64, 8 * 64, 4 * 64, 2 * 64, 64], 4), 12 | } 13 | 14 | 15 | class Reshape(nn.Module): 16 | def __init__(self, target_shape): 17 | super(Reshape, self).__init__() 18 | self.target_shape = target_shape 19 | 20 | def forward(self, input): 21 | return input.view(self.target_shape) 22 | 23 | 24 | class ResBlockGenerator(nn.Module): 25 | def __init__(self, in_channels, out_channels): 26 | super(ResBlockGenerator, self).__init__() 27 | 28 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 29 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 30 | 31 | nn.init.xavier_uniform_(self.conv1.weight.data, np.sqrt(2)) 32 | nn.init.xavier_uniform_(self.conv2.weight.data, np.sqrt(2)) 33 | 34 | self.model = nn.Sequential( 35 | nn.BatchNorm2d(in_channels), 36 | nn.ReLU(inplace=True), 37 | nn.Upsample(scale_factor=2), 38 | self.conv1, 39 | nn.BatchNorm2d(out_channels), 40 | nn.ReLU(inplace=True), 41 | self.conv2 42 | ) 43 | 44 | if in_channels == out_channels: 45 | self.bypass = nn.Upsample(scale_factor=2) 46 | else: 47 | self.bypass = nn.Sequential( 48 | nn.Upsample(scale_factor=2), 49 | nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 50 | ) 51 | nn.init.xavier_uniform_(self.bypass[1].weight.data, 1.0) 52 | 53 | def forward(self, x): 54 | return self.model(x) + self.bypass(x) 55 | 56 | 57 | class GenWrapper(nn.Module): 58 | def __init__(self, model, out_img_shape, distribution): 59 | super(GenWrapper, self).__init__() 
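# The wrapper keeps the latent distribution itself, so forward(batch_size) can sample z internally;
# setting force_no_grad=True samples images without building an autograd graph.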
60 | 61 | self.model = model 62 | self.out_img_shape = out_img_shape 63 | self.distribution = distribution 64 | self.force_no_grad = False 65 | 66 | def cuda(self, device=None): 67 | super(GenWrapper, self).cuda(device) 68 | self.distribution.cuda() 69 | 70 | def forward(self, batch_size): 71 | if self.force_no_grad: 72 | with torch.no_grad(): 73 | img = self.model(self.distribution(batch_size)) 74 | else: 75 | img = self.model(self.distribution(batch_size)) 76 | 77 | img = img.view(img.shape[0], *self.out_img_shape) 78 | return img 79 | 80 | 81 | def make_resnet_generator(resnet_gen_config, img_size=128, channels=3, 82 | distribution=NormalDistribution(128)): 83 | def make_dense(): 84 | dense = nn.Linear( 85 | distribution.dim, resnet_gen_config.seed_dim**2 * resnet_gen_config.channels[0]) 86 | nn.init.xavier_uniform_(dense.weight.data, 1.) 87 | return dense 88 | 89 | def make_final(): 90 | final = nn.Conv2d(resnet_gen_config.channels[-1], channels, 3, stride=1, padding=1) 91 | nn.init.xavier_uniform_(final.weight.data, 1.) 92 | return final 93 | 94 | model_channels = resnet_gen_config.channels 95 | 96 | input_layers = [ 97 | make_dense(), 98 | Reshape([-1, model_channels[0], 4, 4]) 99 | ] 100 | res_blocks = [ 101 | ResBlockGenerator(model_channels[i], model_channels[i + 1]) 102 | for i in range(len(model_channels) - 1) 103 | ] 104 | out_layers = [ 105 | nn.BatchNorm2d(model_channels[-1]), 106 | nn.ReLU(inplace=True), 107 | make_final(), 108 | nn.Tanh() 109 | ] 110 | 111 | model = nn.Sequential(*(input_layers + res_blocks + out_layers)) 112 | 113 | return GenWrapper(model, [channels, img_size, img_size], distribution) 114 | -------------------------------------------------------------------------------- /gan_models/StyleGAN2/op/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 3 | from .upfirdn2d import upfirdn2d 4 | except: 5 | from .upfirdn2d_torch_native import upfirdn2d 6 | from .fused_act_torch_native import FusedLeakyReLU, fused_leaky_relu 7 | -------------------------------------------------------------------------------- /gan_models/StyleGAN2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | 'fused', 12 | sources=[ 13 | os.path.join(module_path, 'fused_bias_act.cpp'), 14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 15 | ], 16 | ) 17 | 18 | 19 | class FusedLeakyReLUFunctionBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, out, negative_slope, scale): 22 | ctx.save_for_backward(out) 23 | ctx.negative_slope = negative_slope 24 | ctx.scale = scale 25 | 26 | empty = grad_output.new_empty(0) 27 | 28 | grad_input = fused.fused_bias_act( 29 | grad_output, empty, out, 3, 1, negative_slope, scale 30 | ) 31 | 32 | dim = [0] 33 | 34 | if grad_input.ndim > 2: 35 | dim += list(range(2, grad_input.ndim)) 36 | 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | return grad_input, grad_bias 40 | 41 | @staticmethod 42 | def backward(ctx, gradgrad_input, gradgrad_bias): 43 | out, = ctx.saved_tensors 44 | gradgrad_out = fused.fused_bias_act( 45 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 46 | ) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 
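# FusedLeakyReLUFunction below is the user-facing autograd Function; its backward delegates to
# FusedLeakyReLUFunctionBackward above, so second-order gradients also run through the fused CUDA kernel.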
| class FusedLeakyReLUFunction(Function): 52 | @staticmethod 53 | def forward(ctx, input, bias, negative_slope, scale): 54 | empty = input.new_empty(0) 55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 56 | ctx.save_for_backward(out) 57 | ctx.negative_slope = negative_slope 58 | ctx.scale = scale 59 | 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | out, = ctx.saved_tensors 65 | 66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 67 | grad_output, out, ctx.negative_slope, ctx.scale 68 | ) 69 | 70 | return grad_input, grad_bias, None, None 71 | 72 | 73 | class FusedLeakyReLU(nn.Module): 74 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | 81 | def forward(self, input): 82 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 83 | 84 | 85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 87 | -------------------------------------------------------------------------------- /gan_models/StyleGAN2/op/fused_act_torch_native.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | module_path = os.path.dirname(__file__) 8 | 9 | 10 | class FusedLeakyReLU(nn.Module): 11 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 12 | super().__init__() 13 | 14 | self.bias = nn.Parameter(torch.zeros(channel)) 15 | self.negative_slope = negative_slope 16 | self.scale = scale 17 | 18 | def forward(self, input): 19 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 20 | 21 | 22 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 23 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 24 | if input.ndim == 3: 25 | return ( 26 | F.leaky_relu( 27 | input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope 28 | ) 29 | * scale 30 | ) 31 | else: 32 | return ( 33 | F.leaky_relu( 34 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope 35 | ) 36 | * scale 37 | ) 38 | -------------------------------------------------------------------------------- /gan_models/StyleGAN2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- 
/gan_models/StyleGAN2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /gan_models/StyleGAN2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /gan_models/StyleGAN2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | 'upfirdn2d', 12 | sources=[ 13 | os.path.join(module_path, 'upfirdn2d.cpp'), 14 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | 
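# double backward: re-apply the forward upfirdn2d (original, unflipped kernel with the same up/down/pad) to the incoming gradient-of-gradient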
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | out = UpFirDn2d.apply( 147 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 148 | ) 149 | 150 | return out 151 | 152 | 153 | def upfirdn2d_native( 154 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 155 | ): 156 | _, in_h, in_w, minor = input.shape 157 | kernel_h, kernel_w = kernel.shape 158 | 159 | out = input.view(-1, in_h, 1, in_w, 1, minor) 160 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 161 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 162 | 163 | out = F.pad( 164 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 165 | ) 166 | out = out[ 167 | :, 168 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 169 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 170 | :, 171 | ] 172 | 173 | out = out.permute(0, 3, 1, 2) 174 | out = out.reshape( 175 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 176 | ) 177 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 178 | out = F.conv2d(out, w) 179 | out = out.reshape( 180 | -1, 181 | minor, 182 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 183 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 184 | ) 185 | out = 
out.permute(0, 2, 3, 1) 186 | 187 | return out[:, ::down_y, ::down_x, :] 188 | 189 | -------------------------------------------------------------------------------- /gan_models/StyleGAN2/op/upfirdn2d_torch_native.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | module_path = os.path.dirname(__file__) 8 | 9 | 10 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 11 | out = upfirdn2d_native( 12 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 13 | ) 14 | 15 | return out 16 | 17 | 18 | def upfirdn2d_native( 19 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 20 | ): 21 | _, channel, in_h, in_w = input.shape 22 | input = input.reshape(-1, in_h, in_w, 1) 23 | 24 | _, in_h, in_w, minor = input.shape 25 | kernel_h, kernel_w = kernel.shape 26 | 27 | out = input.view(-1, in_h, 1, in_w, 1, minor) 28 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 29 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 30 | 31 | out = F.pad( 32 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 33 | ) 34 | out = out[ 35 | :, 36 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 37 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 38 | :, 39 | ] 40 | 41 | out = out.permute(0, 3, 1, 2) 42 | out = out.reshape( 43 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 44 | ) 45 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 46 | out = F.conv2d(out, w) 47 | out = out.reshape( 48 | -1, 49 | minor, 50 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 51 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 52 | ) 53 | out = out.permute(0, 2, 3, 1) 54 | out = out[:, ::down_y, ::down_x, :] 55 | 56 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 57 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 58 | 59 | return out.view(-1, channel, out_h, out_w) 60 | -------------------------------------------------------------------------------- /gan_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/gan_models/__init__.py -------------------------------------------------------------------------------- /gan_models/gan_load.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from gan_models.BigGAN import BigGAN, utils 7 | from gan_models.ProgGAN.model import Generator as ProgGenerator 8 | from gan_models.SNGAN.load import load_model_from_state_dict 9 | 10 | try: 11 | from gan_models.StyleGAN2.model import Discriminator as StyleGan2Discriminator 12 | from gan_models.StyleGAN2.model import Generator as StyleGAN2Generator 13 | except Exception as e: 14 | print('StyleGAN2 load fail: {}'.format(e)) 15 | 16 | from core.utils.class_registry import ClassRegistry 17 | 18 | generator_registry = ClassRegistry() 19 | 20 | 21 | class ConditionedBigGAN(nn.Module): 22 | def __init__(self, big_gan, target_classes=(239, )): 23 | super(ConditionedBigGAN, self).__init__() 24 | self.big_gan = big_gan 25 | self.target_classes = nn.Parameter(torch.tensor(target_classes, dtype=torch.int64), 26 | requires_grad=False) 27 | 28 | self.dim_z = self.big_gan.dim_z 29 | 30 | def set_classes(self, cl): 31 | try: 32 | cl[0] 33 | 
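# cl may be a single class id or a sequence of ids: indexing it probes for a sequence,
# and the except branch below wraps a bare scalar into a one-element list.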
except Exception: 34 | cl = [cl] 35 | self.target_classes.data = torch.tensor(cl, dtype=torch.int64) 36 | 37 | def mixed_classes(self, batch_size): 38 | device = next(self.parameters()).device 39 | if len(self.target_classes.data.shape) == 0: 40 | return self.target_classes.repeat(batch_size).cuda() 41 | else: 42 | return torch.from_numpy( 43 | np.random.choice(self.target_classes.cpu(), [batch_size])).to(device) 44 | 45 | def forward(self, z, classes=None): 46 | if classes is None: 47 | classes = self.mixed_classes(z.shape[0]).to(z.device) 48 | 49 | cl_emb = self.big_gan.shared(classes).to(z.device) 50 | return self.big_gan(z, cl_emb) 51 | 52 | 53 | class StyleGAN2Wrapper(nn.Module): 54 | def __init__(self, g, shift_in_w): 55 | super(StyleGAN2Wrapper, self).__init__() 56 | self.style_gan2 = g 57 | self.shift_in_w = shift_in_w 58 | self.dim_z = 512 59 | self.dim_shift = self.style_gan2.style_dim if shift_in_w else self.dim_z 60 | 61 | def forward(self, input, input_is_latent=False): 62 | return self.style_gan2([input], input_is_latent=input_is_latent)[0] 63 | 64 | def gen_shifted(self, z, shift): 65 | if self.shift_in_w: 66 | w = self.style_gan2.get_latent(z) 67 | return self.forward(w + shift, input_is_latent=True) 68 | else: 69 | return self.forward(z + shift, input_is_latent=False) 70 | 71 | 72 | @generator_registry.add_func_to_registry("stylegan2") 73 | def make_style_gan2(size, weights, latent_dim=512, n_layers_mlp=8, shift_in_w=True): 74 | G = StyleGAN2Generator(size, latent_dim, n_layers_mlp) 75 | G.load_state_dict(torch.load(weights, map_location='cpu')['g_ema']) 76 | G.cuda().eval() 77 | 78 | return StyleGAN2Wrapper(G, shift_in_w=shift_in_w) 79 | 80 | 81 | def make_style_gan2_discriminator(size, weights_path): 82 | D = StyleGan2Discriminator(size) 83 | D.load_state_dict(torch.load(weights_path, map_location='cpu')['d']) 84 | return D 85 | 86 | 87 | @generator_registry.add_func_to_registry("biggan") 88 | def make_big_gan(config_path, weights_path, target_classes): 89 | with open(config_path, 'r') as f: 90 | config = json.load(f) 91 | 92 | config['resolution'] = utils.imsize_dict[config['dataset']] 93 | config['n_classes'] = utils.nclass_dict[config['dataset']] 94 | config['G_activation'] = utils.activation_dict[config['G_nl']] 95 | config['D_activation'] = utils.activation_dict[config['D_nl']] 96 | config['skip_init'] = True 97 | config['no_optim'] = True 98 | 99 | G = BigGAN.Generator(**config) 100 | G.load_state_dict(torch.load(weights_path, map_location='cpu'), strict=True) 101 | 102 | return ConditionedBigGAN(G, target_classes).eval() 103 | 104 | 105 | @generator_registry.add_func_to_registry("proggan") 106 | def make_proggan(weights_root): 107 | model = ProgGenerator() 108 | model.load_state_dict(torch.load(weights_root, map_location='cpu')) 109 | model.cuda() 110 | 111 | setattr(model, 'dim_z', [512, 1, 1]) 112 | return model 113 | 114 | 115 | @generator_registry.add_func_to_registry("sn_anime") 116 | def make_sngan(gan_dir): 117 | gan = load_model_from_state_dict(gan_dir) 118 | G = gan.model.eval() 119 | setattr(G, 'dim_z', gan.distribution.dim) 120 | return G 121 | 122 | 123 | @generator_registry.add_func_to_registry("sn_mnist") 124 | def make_sngan(gan_dir): 125 | gan = load_model_from_state_dict(gan_dir) 126 | G = gan.model.eval() 127 | setattr(G, 'dim_z', gan.distribution.dim) 128 | return G 129 | -------------------------------------------------------------------------------- /gan_models/gan_with_shift.py: 
-------------------------------------------------------------------------------- 1 | import types 2 | from functools import wraps 3 | 4 | 5 | def add_forward_with_shift(generator): 6 | def gen_shifted(self, z, shift, *args, **kwargs): 7 | return self.forward(z + shift, *args, **kwargs) 8 | 9 | generator.gen_shifted = types.MethodType(gen_shifted, generator) 10 | generator.dim_shift = generator.dim_z 11 | 12 | 13 | def gan_with_shift(gan_factory): 14 | @wraps(gan_factory) 15 | def wrapper(*args, **kwargs): 16 | gan = gan_factory(*args, **kwargs) 17 | add_forward_with_shift(gan) 18 | return gan 19 | 20 | return wrapper 21 | -------------------------------------------------------------------------------- /image_domains/anastasia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/anastasia.png -------------------------------------------------------------------------------- /image_domains/anime.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/anime.jpg -------------------------------------------------------------------------------- /image_domains/anime_full.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/anime_full.jpg -------------------------------------------------------------------------------- /image_domains/brave.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/brave.png -------------------------------------------------------------------------------- /image_domains/cached_latents/.gitignore: -------------------------------------------------------------------------------- 1 | *.npy 2 | -------------------------------------------------------------------------------- /image_domains/detroit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/detroit.png -------------------------------------------------------------------------------- /image_domains/digital_painting_jing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/digital_painting_jing.png -------------------------------------------------------------------------------- /image_domains/disney_princess.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/disney_princess.jpg -------------------------------------------------------------------------------- /image_domains/doc_brown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/doc_brown.png -------------------------------------------------------------------------------- /image_domains/image_domains.txt: 
-------------------------------------------------------------------------------- 1 | ./image_domains/sketch.png 2 | ./image_domains/anastasia.png 3 | ./image_domains/digital_painting_jing.png 4 | ./image_domains/mermaid.png 5 | ./image_domains/speed_paint.png 6 | ./image_domains/titan_armin.png 7 | ./image_domains/titan_erwin.png -------------------------------------------------------------------------------- /image_domains/jojo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/jojo.png -------------------------------------------------------------------------------- /image_domains/joker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/joker.png -------------------------------------------------------------------------------- /image_domains/mermaid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/mermaid.png -------------------------------------------------------------------------------- /image_domains/moana.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/moana.png -------------------------------------------------------------------------------- /image_domains/picasso.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/picasso.png -------------------------------------------------------------------------------- /image_domains/pocahontas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/pocahontas.png -------------------------------------------------------------------------------- /image_domains/room_girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/room_girl.png -------------------------------------------------------------------------------- /image_domains/sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/sketch.png -------------------------------------------------------------------------------- /image_domains/speed_paint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/speed_paint.png -------------------------------------------------------------------------------- /image_domains/titan_armin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/titan_armin.png 
-------------------------------------------------------------------------------- /image_domains/titan_erwin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/titan_erwin.png -------------------------------------------------------------------------------- /image_domains/titan_historia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/titan_historia.png -------------------------------------------------------------------------------- /image_domains/zbrush_girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/image_domains/zbrush_girl.png -------------------------------------------------------------------------------- /img/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/img/cover.jpg -------------------------------------------------------------------------------- /img/domain_modulation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/img/domain_modulation.png -------------------------------------------------------------------------------- /img/example_im2im_anastasia.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/img/example_im2im_anastasia.jpg -------------------------------------------------------------------------------- /img/example_td_anime.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/img/example_td_anime.jpg -------------------------------------------------------------------------------- /img/example_td_mapper_20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/img/example_td_mapper_20.jpg -------------------------------------------------------------------------------- /img/example_td_mapper_large.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/img/example_td_mapper_large.jpg -------------------------------------------------------------------------------- /img/hdn_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/img/hdn_diagram.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from trainers import trainer_registry 3 | 4 | from core.utils.common import setup_seed 5 | from core.utils.arguments import load_config 6 | from pprint import pprint 7 | 
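# Entry point: load_config() (core/utils/arguments.py) assembles the OmegaConf experiment config;
# run_experiment() below pretty-prints it, fixes the random seed, then builds the trainer
# registered under exp.trainer and runs its setup() and train_loop().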
8 | 9 | def run_experiment(exp_config): 10 | pprint(OmegaConf.to_container(exp_config)) 11 | setup_seed(exp_config.exp.seed) 12 | trainer = trainer_registry[exp_config.exp.trainer](exp_config) 13 | trainer.setup() 14 | trainer.train_loop() 15 | 16 | 17 | def run_experiment_from_ckpt(): 18 | ... 19 | 20 | 21 | if __name__ == '__main__': 22 | base_config = load_config() 23 | 24 | if base_config.get('checkpoint'): 25 | ... 26 | 27 | run_experiment(base_config) 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | omegaconf 2 | wandb 3 | ftfy 4 | regex 5 | tqdm 6 | git+https://github.com/openai/CLIP.git 7 | scikit-image 8 | dlib 9 | click -------------------------------------------------------------------------------- /restyle_encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/restyle_encoders/__init__.py -------------------------------------------------------------------------------- /restyle_encoders/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | from pathlib import Path 5 | 6 | 7 | def get_download_model_command(save_path, file_id, file_name): 8 | """ Build the wget command for downloading the desired model and saving it under save_path. """ 9 | if not os.path.exists(save_path): 10 | os.makedirs(save_path) 11 | url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path) 12 | return url 13 | 14 | 15 | MODEL_PATHS = { 16 | "ffhq_encode": {"id": "1sw6I2lRIB0MpuJkpc8F5BJiSZrc0hjfE", "name": "restyle_psp_ffhq_encode.pt"}, 17 | "cars_encode": {"id": "1zJHqHRQ8NOnVohVVCGbeYMMr6PDhRpPR", "name": "restyle_psp_cars_encode.pt"}, 18 | "church_encode": {"id": "1bcxx7mw-1z7dzbJI_z7oGpWG1oQAvMaD", "name": "restyle_psp_church_encode.pt"}, 19 | "horse_encode": {"id": "19_sUpTYtJmhSAolKLm3VgI-ptYqd-hgY", "name": "restyle_e4e_horse_encode.pt"}, 20 | "afhq_wild_encode": {"id": "1GyFXVTNDUw3IIGHmGS71ChhJ1Rmslhk7", "name": "restyle_psp_afhq_wild_encode.pt"}, 21 | "toonify": {"id": "1GtudVDig59d4HJ_8bGEniz5huaTSGO_0", "name": "restyle_psp_toonify.pt"} 22 | } 23 | 24 | 25 | if __name__ == "__main__": 26 | exp = 'ffhq_encode' 27 | path = MODEL_PATHS[exp] 28 | path_to_save = Path(os.getcwd()).resolve() / 'pretrained' 29 | download_command = get_download_model_command(str(path_to_save), file_id=path["id"], file_name=path["name"]) 30 | 31 | if not os.path.exists(path_to_save / path['name']) or os.path.getsize(path_to_save / path['name']) < 1000000: 32 | print(f'Downloading ReStyle model for {exp}...') 33 | subprocess.run(download_command, shell=True, check=True)  # download_command is already a full wget invocation 34 | 35 | # if google drive receives too many requests, we'll reach the quota limit and be unable to download the model 36 | if os.path.getsize(path_to_save / path['name']) < 1000000: 37 | raise ValueError("Pretrained model was unable to be downloaded correctly!") 38 | else: 39 |
print('Done.') 40 | else: 41 | print(f'ReStyle model for {exp} already exists!') -------------------------------------------------------------------------------- /restyle_encoders/e4e.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the core research contribution 3 | """ 4 | import math 5 | import torch 6 | from torch import nn 7 | 8 | from models.stylegan2.model import Generator 9 | from configs.paths_config import model_paths 10 | from models.encoders import restyle_e4e_encoders 11 | from utils.model_utils import RESNET_MAPPING 12 | 13 | 14 | class e4e(nn.Module): 15 | 16 | def __init__(self, opts): 17 | super(e4e, self).__init__() 18 | self.set_opts(opts) 19 | self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 20 | # Define architecture 21 | self.encoder = self.set_encoder() 22 | self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2) 23 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 24 | # Load weights if needed 25 | self.load_weights() 26 | 27 | def set_encoder(self): 28 | if self.opts.encoder_type == 'ProgressiveBackboneEncoder': 29 | encoder = restyle_e4e_encoders.ProgressiveBackboneEncoder(50, 'ir_se', self.n_styles, self.opts) 30 | elif self.opts.encoder_type == 'ResNetProgressiveBackboneEncoder': 31 | encoder = restyle_e4e_encoders.ResNetProgressiveBackboneEncoder(self.n_styles, self.opts) 32 | else: 33 | raise Exception(f'{self.opts.encoder_type} is not a valid encoders') 34 | return encoder 35 | 36 | def load_weights(self): 37 | if self.opts.checkpoint_path is not None: 38 | print(f'Loading ReStyle e4e from checkpoint: {self.opts.checkpoint_path}') 39 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 40 | self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=False) 41 | self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True) 42 | self.__load_latent_avg(ckpt) 43 | else: 44 | encoder_ckpt = self.__get_encoder_checkpoint() 45 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 46 | print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}') 47 | ckpt = torch.load(self.opts.stylegan_weights) 48 | self.decoder.load_state_dict(ckpt['g_ema'], strict=True) 49 | self.__load_latent_avg(ckpt, repeat=self.n_styles) 50 | 51 | def forward(self, x, latent=None, resize=True, latent_mask=None, input_code=False, randomize_noise=True, 52 | inject_latent=None, return_latents=False, alpha=None, average_code=False, input_is_full=False): 53 | if input_code: 54 | codes = x 55 | else: 56 | codes = self.encoder(x) 57 | # residual step 58 | if x.shape[1] == 6 and latent is not None: 59 | # learn error with respect to previous iteration 60 | codes = codes + latent 61 | else: 62 | # first iteration is with respect to the avg latent code 63 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 64 | 65 | if latent_mask is not None: 66 | for i in latent_mask: 67 | if inject_latent is not None: 68 | if alpha is not None: 69 | codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] 70 | else: 71 | codes[:, i] = inject_latent[:, i] 72 | else: 73 | codes[:, i] = 0 74 | 75 | if average_code: 76 | input_is_latent = True 77 | else: 78 | input_is_latent = (not input_code) or (input_is_full) 79 | 80 | images, result_latent = self.decoder([codes], 81 | input_is_latent=input_is_latent, 82 | randomize_noise=randomize_noise, 83 | return_latents=return_latents) 84 | 85 | if resize: 86 | images = 
self.face_pool(images) 87 | 88 | if return_latents: 89 | return images, result_latent 90 | else: 91 | return images 92 | 93 | def set_opts(self, opts): 94 | self.opts = opts 95 | 96 | def __load_latent_avg(self, ckpt, repeat=None): 97 | if 'latent_avg' in ckpt: 98 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 99 | if repeat is not None: 100 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 101 | else: 102 | self.latent_avg = None 103 | 104 | def __get_encoder_checkpoint(self): 105 | if "ffhq" in self.opts.dataset_type: 106 | print('Loading encoders weights from irse50!') 107 | encoder_ckpt = torch.load(model_paths['ir_se50']) 108 | # Transfer the RGB input of the irse50 network to the first 3 input channels of pSp's encoder 109 | if self.opts.input_nc != 3: 110 | shape = encoder_ckpt['input_layer.0.weight'].shape 111 | altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) 112 | altered_input_layer[:, :3, :, :] = encoder_ckpt['input_layer.0.weight'] 113 | encoder_ckpt['input_layer.0.weight'] = altered_input_layer 114 | return encoder_ckpt 115 | else: 116 | print('Loading encoders weights from resnet34!') 117 | encoder_ckpt = torch.load(model_paths['resnet34']) 118 | # Transfer the RGB input of the resnet34 network to the first 3 input channels of pSp's encoder 119 | if self.opts.input_nc != 3: 120 | shape = encoder_ckpt['conv1.weight'].shape 121 | altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) 122 | altered_input_layer[:, :3, :, :] = encoder_ckpt['conv1.weight'] 123 | encoder_ckpt['conv1.weight'] = altered_input_layer 124 | mapped_encoder_ckpt = dict(encoder_ckpt) 125 | for p, v in encoder_ckpt.items(): 126 | for original_name, psp_name in RESNET_MAPPING.items(): 127 | if original_name in p: 128 | mapped_encoder_ckpt[p.replace(original_name, psp_name)] = v 129 | mapped_encoder_ckpt.pop(p) 130 | return encoder_ckpt 131 | 132 | @staticmethod 133 | def __get_keys(d, name): 134 | if 'state_dict' in d: 135 | d = d['state_dict'] 136 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 137 | return d_filt 138 | -------------------------------------------------------------------------------- /restyle_encoders/e4e_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/restyle_encoders/e4e_modules/__init__.py -------------------------------------------------------------------------------- /restyle_encoders/e4e_modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LatentCodesDiscriminator(nn.Module): 5 | def __init__(self, style_dim, n_mlp): 6 | super().__init__() 7 | 8 | self.style_dim = style_dim 9 | 10 | layers = [] 11 | for i in range(n_mlp-1): 12 | layers.append( 13 | nn.Linear(style_dim, style_dim) 14 | ) 15 | layers.append(nn.LeakyReLU(0.2)) 16 | layers.append(nn.Linear(512, 1)) 17 | self.mlp = nn.Sequential(*layers) 18 | 19 | def forward(self, w): 20 | return self.mlp(w) 21 | -------------------------------------------------------------------------------- /restyle_encoders/e4e_modules/latent_codes_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class LatentCodesPool: 6 | """This class implements latent codes buffer 
that stores previously generated w latent codes. 7 | This buffer enables us to update discriminators using a history of generated w's 8 | rather than the ones produced by the latest encoder. 9 | """ 10 | 11 | def __init__(self, pool_size): 12 | """Initialize the ImagePool class 13 | Parameters: 14 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 15 | """ 16 | self.pool_size = pool_size 17 | if self.pool_size > 0: # create an empty pool 18 | self.num_ws = 0 19 | self.ws = [] 20 | 21 | def query(self, ws): 22 | """Return w's from the pool. 23 | Parameters: 24 | ws: the latest generated w's from the generator 25 | Returns w's from the buffer. 26 | By 50/100, the buffer will return input w's. 27 | By 50/100, the buffer will return w's previously stored in the buffer, 28 | and insert the current w's to the buffer. 29 | """ 30 | if self.pool_size == 0: # if the buffer size is 0, do nothing 31 | return ws 32 | return_ws = [] 33 | for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) 34 | # w = torch.unsqueeze(image.data, 0) 35 | if w.ndim == 2: 36 | i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate 37 | w = w[i] 38 | self.handle_w(w, return_ws) 39 | return_ws = torch.stack(return_ws, 0) # collect all the images and return 40 | return return_ws 41 | 42 | def handle_w(self, w, return_ws): 43 | if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer 44 | self.num_ws = self.num_ws + 1 45 | self.ws.append(w) 46 | return_ws.append(w) 47 | else: 48 | p = random.uniform(0, 1) 49 | if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer 50 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 51 | tmp = self.ws[random_id].clone() 52 | self.ws[random_id] = w 53 | return_ws.append(tmp) 54 | else: # by another 50% chance, the buffer will return the current image 55 | return_ws.append(w) 56 | -------------------------------------------------------------------------------- /restyle_encoders/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/restyle_encoders/encoders/__init__.py -------------------------------------------------------------------------------- /restyle_encoders/encoders/fpn_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 5 | from torchvision.models.resnet import resnet34 6 | 7 | from restyle_encoders.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 8 | from restyle_encoders.encoders.map2style import GradualStyleBlock 9 | 10 | 11 | class GradualStyleEncoder(Module): 12 | """ 13 | Original encoder architecture from pixel2style2pixel. This classes uses an FPN-based architecture applied over 14 | an ResNet IRSE-50 backbone. 15 | Note this class is designed to be used for the human facial domain. 
16 | """ 17 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 18 | super(GradualStyleEncoder, self).__init__() 19 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 20 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 21 | blocks = get_blocks(num_layers) 22 | if mode == 'ir': 23 | unit_module = bottleneck_IR 24 | elif mode == 'ir_se': 25 | unit_module = bottleneck_IR_SE 26 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 27 | BatchNorm2d(64), 28 | PReLU(64)) 29 | modules = [] 30 | for block in blocks: 31 | for bottleneck in block: 32 | modules.append(unit_module(bottleneck.in_channel, 33 | bottleneck.depth, 34 | bottleneck.stride)) 35 | self.body = Sequential(*modules) 36 | 37 | self.styles = nn.ModuleList() 38 | self.style_count = n_styles 39 | self.coarse_ind = 3 40 | self.middle_ind = 7 41 | for i in range(self.style_count): 42 | if i < self.coarse_ind: 43 | style = GradualStyleBlock(512, 512, 16) 44 | elif i < self.middle_ind: 45 | style = GradualStyleBlock(512, 512, 32) 46 | else: 47 | style = GradualStyleBlock(512, 512, 64) 48 | self.styles.append(style) 49 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 50 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 51 | 52 | def _upsample_add(self, x, y): 53 | _, _, H, W = y.size() 54 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 55 | 56 | def forward(self, x): 57 | x = self.input_layer(x) 58 | 59 | latents = [] 60 | modulelist = list(self.body._modules.values()) 61 | for i, l in enumerate(modulelist): 62 | x = l(x) 63 | if i == 6: 64 | c1 = x 65 | elif i == 20: 66 | c2 = x 67 | elif i == 23: 68 | c3 = x 69 | 70 | for j in range(self.coarse_ind): 71 | latents.append(self.styles[j](c3)) 72 | 73 | p2 = self._upsample_add(c3, self.latlayer1(c2)) 74 | for j in range(self.coarse_ind, self.middle_ind): 75 | latents.append(self.styles[j](p2)) 76 | 77 | p1 = self._upsample_add(p2, self.latlayer2(c1)) 78 | for j in range(self.middle_ind, self.style_count): 79 | latents.append(self.styles[j](p1)) 80 | 81 | out = torch.stack(latents, dim=1) 82 | return out 83 | 84 | 85 | class ResNetGradualStyleEncoder(Module): 86 | """ 87 | Original encoder architecture from pixel2style2pixel. This classes uses an FPN-based architecture applied over 88 | an ResNet34 backbone. 
89 | """ 90 | def __init__(self, n_styles=18, opts=None): 91 | super(ResNetGradualStyleEncoder, self).__init__() 92 | 93 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 94 | self.bn1 = BatchNorm2d(64) 95 | self.relu = PReLU(64) 96 | 97 | resnet_basenet = resnet34(pretrained=True) 98 | blocks = [ 99 | resnet_basenet.layer1, 100 | resnet_basenet.layer2, 101 | resnet_basenet.layer3, 102 | resnet_basenet.layer4 103 | ] 104 | 105 | modules = [] 106 | for block in blocks: 107 | for bottleneck in block: 108 | modules.append(bottleneck) 109 | 110 | self.body = Sequential(*modules) 111 | 112 | self.styles = nn.ModuleList() 113 | self.style_count = n_styles 114 | self.coarse_ind = 3 115 | self.middle_ind = 7 116 | for i in range(self.style_count): 117 | if i < self.coarse_ind: 118 | style = GradualStyleBlock(512, 512, 16) 119 | elif i < self.middle_ind: 120 | style = GradualStyleBlock(512, 512, 32) 121 | else: 122 | style = GradualStyleBlock(512, 512, 64) 123 | self.styles.append(style) 124 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 125 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 126 | 127 | def _upsample_add(self, x, y): 128 | _, _, H, W = y.size() 129 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 130 | 131 | def forward(self, x): 132 | x = self.conv1(x) 133 | x = self.bn1(x) 134 | x = self.relu(x) 135 | 136 | latents = [] 137 | modulelist = list(self.body._modules.values()) 138 | for i, l in enumerate(modulelist): 139 | x = l(x) 140 | if i == 6: 141 | c1 = x 142 | elif i == 12: 143 | c2 = x 144 | elif i == 15: 145 | c3 = x 146 | 147 | for j in range(self.coarse_ind): 148 | latents.append(self.styles[j](c3)) 149 | 150 | p2 = self._upsample_add(c3, self.latlayer1(c2)) 151 | for j in range(self.coarse_ind, self.middle_ind): 152 | latents.append(self.styles[j](p2)) 153 | 154 | p1 = self._upsample_add(p2, self.latlayer2(c1)) 155 | for j in range(self.middle_ind, self.style_count): 156 | latents.append(self.styles[j](p1)) 157 | 158 | out = torch.stack(latents, dim=1) 159 | return out 160 | -------------------------------------------------------------------------------- /restyle_encoders/encoders/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. 
""" 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | -------------------------------------------------------------------------------- /restyle_encoders/encoders/map2style.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from torch.nn 
import Conv2d, Module 4 | 5 | from gan_models.StyleGAN2.model import EqualLinear 6 | 7 | 8 | class GradualStyleBlock(Module): 9 | def __init__(self, in_c, out_c, spatial): 10 | super(GradualStyleBlock, self).__init__() 11 | self.out_c = out_c 12 | self.spatial = spatial 13 | num_pools = int(np.log2(spatial)) 14 | modules = [] 15 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 16 | nn.LeakyReLU()] 17 | for i in range(num_pools - 1): 18 | modules += [ 19 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 20 | nn.LeakyReLU() 21 | ] 22 | self.convs = nn.Sequential(*modules) 23 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 24 | 25 | def forward(self, x): 26 | x = self.convs(x) 27 | x = x.view(-1, self.out_c) 28 | x = self.linear(x) 29 | return x 30 | -------------------------------------------------------------------------------- /restyle_encoders/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = 
Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /restyle_encoders/encoders/restyle_e4e_encoders.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from torch import nn 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 4 | from torchvision.models import resnet34 5 | 6 | from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 7 | from models.encoders.map2style import GradualStyleBlock 8 | 9 | 10 | class ProgressiveStage(Enum): 11 | WTraining = 0 12 | Delta1Training = 1 13 | Delta2Training = 2 14 | Delta3Training = 3 15 | Delta4Training = 4 16 | Delta5Training = 5 17 | Delta6Training = 6 18 | Delta7Training = 7 19 | Delta8Training = 8 20 | Delta9Training = 9 21 | Delta10Training = 10 22 | Delta11Training = 11 23 | Delta12Training = 12 24 | Delta13Training = 13 25 | Delta14Training = 14 26 | Delta15Training = 15 27 | Delta16Training = 16 28 | Delta17Training = 17 29 | Inference = 18 30 | 31 | 32 | class ProgressiveBackboneEncoder(Module): 33 | """ 34 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 35 | map of the encoder. This classes uses the simplified architecture applied over an ResNet IRSE50 backbone with the 36 | progressive training scheme from e4e_modules. 37 | Note this class is designed to be used for the human facial domain. 38 | """ 39 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 40 | super(ProgressiveBackboneEncoder, self).__init__() 41 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 42 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 43 | blocks = get_blocks(num_layers) 44 | if mode == 'ir': 45 | unit_module = bottleneck_IR 46 | elif mode == 'ir_se': 47 | unit_module = bottleneck_IR_SE 48 | 49 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 50 | BatchNorm2d(64), 51 | PReLU(64)) 52 | modules = [] 53 | for block in blocks: 54 | for bottleneck in block: 55 | modules.append(unit_module(bottleneck.in_channel, 56 | bottleneck.depth, 57 | bottleneck.stride)) 58 | self.body = Sequential(*modules) 59 | 60 | self.styles = nn.ModuleList() 61 | self.style_count = n_styles 62 | for i in range(self.style_count): 63 | style = GradualStyleBlock(512, 512, 16) 64 | self.styles.append(style) 65 | self.progressive_stage = ProgressiveStage.Inference 66 | 67 | def get_deltas_starting_dimensions(self): 68 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 69 | return list(range(self.style_count)) # Each dimension has a delta applied to 70 | 71 | def set_progressive_stage(self, new_stage: ProgressiveStage): 72 | # In this encoder we train all the pyramid (At least as a first stage experiment 73 | self.progressive_stage = new_stage 74 | print('Changed progressive stage to: ', new_stage) 75 | 76 | def forward(self, x): 77 | x = self.input_layer(x) 78 | x = self.body(x) 79 | 80 | # get initial w0 from first map2style layer 81 | w0 = self.styles[0](x) 82 | w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) 83 | 84 | # learn the deltas up to the current stage 85 | 
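# styles[0] produces the base code w0, broadcast to every layer above; styles[1..stage] predict
# per-layer deltas added on top of it, so earlier progressive stages only modify the first layers.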
stage = self.progressive_stage.value 86 | for i in range(1, min(stage + 1, self.style_count)): 87 | delta_i = self.styles[i](x) 88 | w[:, i] += delta_i 89 | return w 90 | 91 | 92 | class ResNetProgressiveBackboneEncoder(Module): 93 | """ 94 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 95 | map of the encoder. This classes uses the simplified architecture applied over an ResNet34 backbone with the 96 | progressive training scheme from e4e_modules. 97 | """ 98 | def __init__(self, n_styles=18, opts=None): 99 | super(ResNetProgressiveBackboneEncoder, self).__init__() 100 | 101 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 102 | self.bn1 = BatchNorm2d(64) 103 | self.relu = PReLU(64) 104 | 105 | resnet_basenet = resnet34(pretrained=True) 106 | blocks = [ 107 | resnet_basenet.layer1, 108 | resnet_basenet.layer2, 109 | resnet_basenet.layer3, 110 | resnet_basenet.layer4 111 | ] 112 | modules = [] 113 | for block in blocks: 114 | for bottleneck in block: 115 | modules.append(bottleneck) 116 | self.body = Sequential(*modules) 117 | 118 | self.styles = nn.ModuleList() 119 | self.style_count = n_styles 120 | for i in range(self.style_count): 121 | style = GradualStyleBlock(512, 512, 16) 122 | self.styles.append(style) 123 | self.progressive_stage = ProgressiveStage.Inference 124 | 125 | def get_deltas_starting_dimensions(self): 126 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 127 | return list(range(self.style_count)) # Each dimension has a delta applied to 128 | 129 | def set_progressive_stage(self, new_stage: ProgressiveStage): 130 | # In this encoder we train all the pyramid (At least as a first stage experiment 131 | self.progressive_stage = new_stage 132 | print('Changed progressive stage to: ', new_stage) 133 | 134 | def forward(self, x): 135 | x = self.conv1(x) 136 | x = self.bn1(x) 137 | x = self.relu(x) 138 | x = self.body(x) 139 | 140 | # get initial w0 from first map2style layer 141 | w0 = self.styles[0](x) 142 | w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) 143 | 144 | # learn the deltas up to the current stage 145 | stage = self.progressive_stage.value 146 | for i in range(1, min(stage + 1, self.style_count)): 147 | delta_i = self.styles[i](x) 148 | w[:, i] += delta_i 149 | return w 150 | -------------------------------------------------------------------------------- /restyle_encoders/encoders/restyle_psp_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 4 | from torchvision.models.resnet import resnet34 5 | 6 | from restyle_encoders.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 7 | from restyle_encoders.encoders.map2style import GradualStyleBlock 8 | 9 | 10 | class BackboneEncoder(Module): 11 | """ 12 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 13 | map of the encoder. This classes uses the simplified architecture applied over an ResNet IRSE-50 backbone. 14 | Note this class is designed to be used for the human facial domain. 
15 | """ 16 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 17 | super(BackboneEncoder, self).__init__() 18 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 19 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 20 | blocks = get_blocks(num_layers) 21 | if mode == 'ir': 22 | unit_module = bottleneck_IR 23 | elif mode == 'ir_se': 24 | unit_module = bottleneck_IR_SE 25 | 26 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 27 | BatchNorm2d(64), 28 | PReLU(64)) 29 | modules = [] 30 | for block in blocks: 31 | for bottleneck in block: 32 | modules.append(unit_module(bottleneck.in_channel, 33 | bottleneck.depth, 34 | bottleneck.stride)) 35 | self.body = Sequential(*modules) 36 | 37 | self.styles = nn.ModuleList() 38 | self.style_count = n_styles 39 | for i in range(self.style_count): 40 | style = GradualStyleBlock(512, 512, 16) 41 | self.styles.append(style) 42 | 43 | def forward(self, x): 44 | x = self.input_layer(x) 45 | x = self.body(x) 46 | latents = [] 47 | for j in range(self.style_count): 48 | latents.append(self.styles[j](x)) 49 | out = torch.stack(latents, dim=1) 50 | return out 51 | 52 | 53 | class ResNetBackboneEncoder(Module): 54 | """ 55 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 56 | map of the encoder. This classes uses the simplified architecture applied over an ResNet34 backbone. 57 | """ 58 | def __init__(self, n_styles=18, opts=None): 59 | super(ResNetBackboneEncoder, self).__init__() 60 | 61 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 62 | self.bn1 = BatchNorm2d(64) 63 | self.relu = PReLU(64) 64 | 65 | resnet_basenet = resnet34(pretrained=True) 66 | blocks = [ 67 | resnet_basenet.layer1, 68 | resnet_basenet.layer2, 69 | resnet_basenet.layer3, 70 | resnet_basenet.layer4 71 | ] 72 | modules = [] 73 | for block in blocks: 74 | for bottleneck in block: 75 | modules.append(bottleneck) 76 | self.body = Sequential(*modules) 77 | 78 | self.styles = nn.ModuleList() 79 | self.style_count = n_styles 80 | for i in range(self.style_count): 81 | style = GradualStyleBlock(512, 512, 16) 82 | self.styles.append(style) 83 | 84 | def forward(self, x): 85 | x = self.conv1(x) 86 | x = self.bn1(x) 87 | x = self.relu(x) 88 | x = self.body(x) 89 | latents = [] 90 | for j in range(self.style_count): 91 | latents.append(self.styles[j](x)) 92 | out = torch.stack(latents, dim=1) 93 | return out 94 | -------------------------------------------------------------------------------- /restyle_encoders/mtcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/restyle_encoders/mtcnn/__init__.py -------------------------------------------------------------------------------- /restyle_encoders/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/restyle_encoders/mtcnn/mtcnn_pytorch/__init__.py -------------------------------------------------------------------------------- /restyle_encoders/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization_utils import show_bboxes 2 | from .detector import detect_faces 3 | 
-------------------------------------------------------------------------------- /restyle_encoders/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .get_nets import PNet, RNet, ONet 4 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 5 | from .first_stage import run_first_stage 6 | 7 | 8 | def detect_faces(image, min_face_size=20.0, 9 | thresholds=[0.6, 0.7, 0.8], 10 | nms_thresholds=[0.7, 0.7, 0.7]): 11 | """ 12 | Arguments: 13 | image: an instance of PIL.Image. 14 | min_face_size: a float number. 15 | thresholds: a list of length 3. 16 | nms_thresholds: a list of length 3. 17 | 18 | Returns: 19 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 20 | bounding boxes and facial landmarks. 21 | """ 22 | 23 | # LOAD MODELS 24 | pnet = PNet() 25 | rnet = RNet() 26 | onet = ONet() 27 | onet.eval() 28 | 29 | # BUILD AN IMAGE PYRAMID 30 | width, height = image.size 31 | min_length = min(height, width) 32 | 33 | min_detection_size = 12 34 | factor = 0.707 # sqrt(0.5) 35 | 36 | # scales for scaling the image 37 | scales = [] 38 | 39 | # scales the image so that 40 | # minimum size that we can detect equals to 41 | # minimum face size that we want to detect 42 | m = min_detection_size / min_face_size 43 | min_length *= m 44 | 45 | factor_count = 0 46 | while min_length > min_detection_size: 47 | scales.append(m * factor ** factor_count) 48 | min_length *= factor 49 | factor_count += 1 50 | 51 | # STAGE 1 52 | 53 | # it will be returned 54 | bounding_boxes = [] 55 | 56 | with torch.no_grad(): 57 | # run P-Net on different scales 58 | for s in scales: 59 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 60 | bounding_boxes.append(boxes) 61 | 62 | # collect boxes (and offsets, and scores) from different scales 63 | bounding_boxes = [i for i in bounding_boxes if i is not None] 64 | bounding_boxes = np.vstack(bounding_boxes) 65 | 66 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 67 | bounding_boxes = bounding_boxes[keep] 68 | 69 | # use offsets predicted by pnet to transform bounding boxes 70 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 71 | # shape [n_boxes, 5] 72 | 73 | bounding_boxes = convert_to_square(bounding_boxes) 74 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 75 | 76 | # STAGE 2 77 | 78 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 79 | img_boxes = torch.FloatTensor(img_boxes) 80 | 81 | output = rnet(img_boxes) 82 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 83 | probs = output[1].data.numpy() # shape [n_boxes, 2] 84 | 85 | keep = np.where(probs[:, 1] > thresholds[1])[0] 86 | bounding_boxes = bounding_boxes[keep] 87 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 88 | offsets = offsets[keep] 89 | 90 | keep = nms(bounding_boxes, nms_thresholds[1]) 91 | bounding_boxes = bounding_boxes[keep] 92 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 93 | bounding_boxes = convert_to_square(bounding_boxes) 94 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 95 | 96 | # STAGE 3 97 | 98 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 99 | if len(img_boxes) == 0: 100 | return [], [] 101 | img_boxes = torch.FloatTensor(img_boxes) 102 | output = onet(img_boxes) 103 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 104 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 105 | probs = 
output[2].data.numpy() # shape [n_boxes, 2] 106 | 107 | keep = np.where(probs[:, 1] > thresholds[2])[0] 108 | bounding_boxes = bounding_boxes[keep] 109 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 110 | offsets = offsets[keep] 111 | landmarks = landmarks[keep] 112 | 113 | # compute landmark points 114 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 115 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 116 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 117 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 118 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 119 | 120 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 121 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 122 | bounding_boxes = bounding_boxes[keep] 123 | landmarks = landmarks[keep] 124 | 125 | return bounding_boxes, landmarks 126 | -------------------------------------------------------------------------------- /restyle_encoders/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from PIL import Image 4 | import numpy as np 5 | from .box_utils import nms, _preprocess 6 | 7 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | device = 'cuda:0' 9 | 10 | 11 | def run_first_stage(image, net, scale, threshold): 12 | """Run P-Net, generate bounding boxes, and do NMS. 13 | 14 | Arguments: 15 | image: an instance of PIL.Image. 16 | net: an instance of pytorch's nn.Module, P-Net. 17 | scale: a float number, 18 | scale width and height of the image by this number. 19 | threshold: a float number, 20 | threshold on the probability of a face when generating 21 | bounding boxes from predictions of the net. 22 | 23 | Returns: 24 | a float numpy array of shape [n_boxes, 9], 25 | bounding boxes with scores and offsets (4 + 1 + 4). 26 | """ 27 | 28 | # scale the image and convert it to a float array 29 | width, height = image.size 30 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 31 | img = image.resize((sw, sh), Image.BILINEAR) 32 | img = np.asarray(img, 'float32') 33 | 34 | img = torch.FloatTensor(_preprocess(img)).to(device) 35 | with torch.no_grad(): 36 | output = net(img) 37 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 38 | offsets = output[0].cpu().data.numpy() 39 | # probs: probability of a face at each sliding window 40 | # offsets: transformations to true bounding boxes 41 | 42 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 43 | if len(boxes) == 0: 44 | return None 45 | 46 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 47 | return boxes[keep] 48 | 49 | 50 | def _generate_bboxes(probs, offsets, scale, threshold): 51 | """Generate bounding boxes at places 52 | where there is probably a face. 53 | 54 | Arguments: 55 | probs: a float numpy array of shape [n, m]. 56 | offsets: a float numpy array of shape [1, 4, n, m]. 57 | scale: a float number, 58 | width and height of the image were scaled by this number. 59 | threshold: a float number. 
60 | 61 | Returns: 62 | a float numpy array of shape [n_boxes, 9] 63 | """ 64 | 65 | # applying P-Net is equivalent, in some sense, to 66 | # moving 12x12 window with stride 2 67 | stride = 2 68 | cell_size = 12 69 | 70 | # indices of boxes where there is probably a face 71 | inds = np.where(probs > threshold) 72 | 73 | if inds[0].size == 0: 74 | return np.array([]) 75 | 76 | # transformations of bounding boxes 77 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 78 | # they are defined as: 79 | # w = x2 - x1 + 1 80 | # h = y2 - y1 + 1 81 | # x1_true = x1 + tx1*w 82 | # x2_true = x2 + tx2*w 83 | # y1_true = y1 + ty1*h 84 | # y2_true = y2 + ty2*h 85 | 86 | offsets = np.array([tx1, ty1, tx2, ty2]) 87 | score = probs[inds[0], inds[1]] 88 | 89 | # P-Net is applied to scaled images 90 | # so we need to rescale bounding boxes back 91 | bounding_boxes = np.vstack([ 92 | np.round((stride * inds[1] + 1.0) / scale), 93 | np.round((stride * inds[0] + 1.0) / scale), 94 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 95 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 96 | score, offsets 97 | ]) 98 | # why one is added? 99 | 100 | return bounding_boxes.T 101 | -------------------------------------------------------------------------------- /restyle_encoders/mtcnn/mtcnn_pytorch/src/get_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | from configs.paths_config import model_paths 8 | PNET_PATH = model_paths["mtcnn_pnet"] 9 | ONET_PATH = model_paths["mtcnn_onet"] 10 | RNET_PATH = model_paths["mtcnn_rnet"] 11 | 12 | 13 | class Flatten(nn.Module): 14 | 15 | def __init__(self): 16 | super(Flatten, self).__init__() 17 | 18 | def forward(self, x): 19 | """ 20 | Arguments: 21 | x: a float tensor with shape [batch_size, c, h, w]. 22 | Returns: 23 | a float tensor with shape [batch_size, c*h*w]. 24 | """ 25 | 26 | # without this pretrained model isn't working 27 | x = x.transpose(3, 2).contiguous() 28 | 29 | return x.view(x.size(0), -1) 30 | 31 | 32 | class PNet(nn.Module): 33 | 34 | def __init__(self): 35 | super().__init__() 36 | 37 | # suppose we have input with size HxW, then 38 | # after first layer: H - 2, 39 | # after pool: ceil((H - 2)/2), 40 | # after second conv: ceil((H - 2)/2) - 2, 41 | # after last conv: ceil((H - 2)/2) - 4, 42 | # and the same for W 43 | 44 | self.features = nn.Sequential(OrderedDict([ 45 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 46 | ('prelu1', nn.PReLU(10)), 47 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), 48 | 49 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 50 | ('prelu2', nn.PReLU(16)), 51 | 52 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 53 | ('prelu3', nn.PReLU(32)) 54 | ])) 55 | 56 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 57 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 58 | 59 | weights = np.load(PNET_PATH, allow_pickle=True)[()] 60 | for n, p in self.named_parameters(): 61 | p.data = torch.FloatTensor(weights[n]) 62 | 63 | def forward(self, x): 64 | """ 65 | Arguments: 66 | x: a float tensor with shape [batch_size, 3, h, w]. 67 | Returns: 68 | b: a float tensor with shape [batch_size, 4, h', w']. 69 | a: a float tensor with shape [batch_size, 2, h', w']. 
70 | """ 71 | x = self.features(x) 72 | a = self.conv4_1(x) 73 | b = self.conv4_2(x) 74 | a = F.softmax(a, dim=-1) 75 | return b, a 76 | 77 | 78 | class RNet(nn.Module): 79 | 80 | def __init__(self): 81 | super().__init__() 82 | 83 | self.features = nn.Sequential(OrderedDict([ 84 | ('conv1', nn.Conv2d(3, 28, 3, 1)), 85 | ('prelu1', nn.PReLU(28)), 86 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 87 | 88 | ('conv2', nn.Conv2d(28, 48, 3, 1)), 89 | ('prelu2', nn.PReLU(48)), 90 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 91 | 92 | ('conv3', nn.Conv2d(48, 64, 2, 1)), 93 | ('prelu3', nn.PReLU(64)), 94 | 95 | ('flatten', Flatten()), 96 | ('conv4', nn.Linear(576, 128)), 97 | ('prelu4', nn.PReLU(128)) 98 | ])) 99 | 100 | self.conv5_1 = nn.Linear(128, 2) 101 | self.conv5_2 = nn.Linear(128, 4) 102 | 103 | weights = np.load(RNET_PATH, allow_pickle=True)[()] 104 | for n, p in self.named_parameters(): 105 | p.data = torch.FloatTensor(weights[n]) 106 | 107 | def forward(self, x): 108 | """ 109 | Arguments: 110 | x: a float tensor with shape [batch_size, 3, h, w]. 111 | Returns: 112 | b: a float tensor with shape [batch_size, 4]. 113 | a: a float tensor with shape [batch_size, 2]. 114 | """ 115 | x = self.features(x) 116 | a = self.conv5_1(x) 117 | b = self.conv5_2(x) 118 | a = F.softmax(a, dim=-1) 119 | return b, a 120 | 121 | 122 | class ONet(nn.Module): 123 | 124 | def __init__(self): 125 | super().__init__() 126 | 127 | self.features = nn.Sequential(OrderedDict([ 128 | ('conv1', nn.Conv2d(3, 32, 3, 1)), 129 | ('prelu1', nn.PReLU(32)), 130 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 131 | 132 | ('conv2', nn.Conv2d(32, 64, 3, 1)), 133 | ('prelu2', nn.PReLU(64)), 134 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 135 | 136 | ('conv3', nn.Conv2d(64, 64, 3, 1)), 137 | ('prelu3', nn.PReLU(64)), 138 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), 139 | 140 | ('conv4', nn.Conv2d(64, 128, 2, 1)), 141 | ('prelu4', nn.PReLU(128)), 142 | 143 | ('flatten', Flatten()), 144 | ('conv5', nn.Linear(1152, 256)), 145 | ('drop5', nn.Dropout(0.25)), 146 | ('prelu5', nn.PReLU(256)), 147 | ])) 148 | 149 | self.conv6_1 = nn.Linear(256, 2) 150 | self.conv6_2 = nn.Linear(256, 4) 151 | self.conv6_3 = nn.Linear(256, 10) 152 | 153 | weights = np.load(ONET_PATH, allow_pickle=True)[()] 154 | for n, p in self.named_parameters(): 155 | p.data = torch.FloatTensor(weights[n]) 156 | 157 | def forward(self, x): 158 | """ 159 | Arguments: 160 | x: a float tensor with shape [batch_size, 3, h, w]. 161 | Returns: 162 | c: a float tensor with shape [batch_size, 10]. 163 | b: a float tensor with shape [batch_size, 4]. 164 | a: a float tensor with shape [batch_size, 2]. 165 | """ 166 | x = self.features(x) 167 | a = self.conv6_1(x) 168 | b = self.conv6_2(x) 169 | c = self.conv6_3(x) 170 | a = F.softmax(a, dim=-1) 171 | return c, b, a 172 | -------------------------------------------------------------------------------- /restyle_encoders/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /restyle_encoders/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MACderRu/HyperDomainNet/2986493c6de3aa5840528b398f33e5817e29f4f5/restyle_encoders/stylegan2/__init__.py -------------------------------------------------------------------------------- /restyle_encoders/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /restyle_encoders/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, 
ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | def forward(self, input): 81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 82 | 83 | 84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 86 | -------------------------------------------------------------------------------- /restyle_encoders/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /restyle_encoders/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? 
p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /restyle_encoders/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /restyle_encoders/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | 'upfirdn2d', 10 | sources=[ 11 | os.path.join(module_path, 'upfirdn2d.cpp'), 12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | up_x, up_y = up 23 | down_x, down_y = down 24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 25 | 26 | 
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 27 | 28 | grad_input = upfirdn2d_op.upfirdn2d( 29 | grad_output, 30 | grad_kernel, 31 | down_x, 32 | down_y, 33 | up_x, 34 | up_y, 35 | g_pad_x0, 36 | g_pad_x1, 37 | g_pad_y0, 38 | g_pad_y1, 39 | ) 40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 41 | 42 | ctx.save_for_backward(kernel) 43 | 44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 45 | 46 | ctx.up_x = up_x 47 | ctx.up_y = up_y 48 | ctx.down_x = down_x 49 | ctx.down_y = down_y 50 | ctx.pad_x0 = pad_x0 51 | ctx.pad_x1 = pad_x1 52 | ctx.pad_y0 = pad_y0 53 | ctx.pad_y1 = pad_y1 54 | ctx.in_size = in_size 55 | ctx.out_size = out_size 56 | 57 | return grad_input 58 | 59 | @staticmethod 60 | def backward(ctx, gradgrad_input): 61 | kernel, = ctx.saved_tensors 62 | 63 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 64 | 65 | gradgrad_out = upfirdn2d_op.upfirdn2d( 66 | gradgrad_input, 67 | kernel, 68 | ctx.up_x, 69 | ctx.up_y, 70 | ctx.down_x, 71 | ctx.down_y, 72 | ctx.pad_x0, 73 | ctx.pad_x1, 74 | ctx.pad_y0, 75 | ctx.pad_y1, 76 | ) 77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 78 | gradgrad_out = gradgrad_out.view( 79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 80 | ) 81 | 82 | return gradgrad_out, None, None, None, None, None, None, None, None 83 | 84 | 85 | class UpFirDn2d(Function): 86 | @staticmethod 87 | def forward(ctx, input, kernel, up, down, pad): 88 | up_x, up_y = up 89 | down_x, down_y = down 90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 91 | 92 | kernel_h, kernel_w = kernel.shape 93 | batch, channel, in_h, in_w = input.shape 94 | ctx.in_size = input.shape 95 | 96 | input = input.reshape(-1, in_h, in_w, 1) 97 | 98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 99 | 100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 102 | ctx.out_size = (out_h, out_w) 103 | 104 | ctx.up = (up_x, up_y) 105 | ctx.down = (down_x, down_y) 106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 107 | 108 | g_pad_x0 = kernel_w - pad_x0 - 1 109 | g_pad_y0 = kernel_h - pad_y0 - 1 110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 112 | 113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 114 | 115 | out = upfirdn2d_op.upfirdn2d( 116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 117 | ) 118 | # out = out.view(major, out_h, out_w, minor) 119 | out = out.view(-1, channel, out_h, out_w) 120 | 121 | return out 122 | 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | kernel, grad_kernel = ctx.saved_tensors 126 | 127 | grad_input = UpFirDn2dBackward.apply( 128 | grad_output, 129 | kernel, 130 | grad_kernel, 131 | ctx.up, 132 | ctx.down, 133 | ctx.pad, 134 | ctx.g_pad, 135 | ctx.in_size, 136 | ctx.out_size, 137 | ) 138 | 139 | return grad_input, None, None, None, None 140 | 141 | 142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 143 | out = UpFirDn2d.apply( 144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 145 | ) 146 | 147 | return out 148 | 149 | 150 | def upfirdn2d_native( 151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 152 | ): 153 | _, in_h, in_w, minor = input.shape 154 | kernel_h, kernel_w = kernel.shape 155 | 156 | out = input.view(-1, in_h, 1, in_w, 1, minor) 
157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 159 | 160 | out = F.pad( 161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 162 | ) 163 | out = out[ 164 | :, 165 | max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), 166 | max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), 167 | :, 168 | ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape( 172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 173 | ) 174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 175 | out = F.conv2d(out, w) 176 | out = out.reshape( 177 | -1, 178 | minor, 179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 181 | ) 182 | out = out.permute(0, 2, 3, 1) 183 | 184 | return out[:, ::down_y, ::down_x, :] 185 | -------------------------------------------------------------------------------- /text_domains/close_test_domains_2000.txt: -------------------------------------------------------------------------------- 1 | Graffiti - Photo 2 | Pencil painting - Photo 3 | 3D render - Photo 4 | Cubism graffiti - Photo 5 | Cubism pencil painting - Photo 6 | Cubism 3d render - Photo 7 | Pop Art graffiti - Photo 8 | Pop Art pencil painting - Photo 9 | Pop Art 3d render - Photo 10 | Mona Lisa graffiti - Photo 11 | Mona Lisa pencil painting - Photo 12 | Mona Lisa 3d render - Photo 13 | Starry Night graffiti - Photo 14 | Starry Night pencil painting - Photo 15 | Starry Night 3d render - Photo 16 | Old-timey graffiti - Photo 17 | Old-timey pencil painting - Photo 18 | Old-timey 3d render - Photo 19 | Edvard Munch graffiti - Photo 20 | Edvard Munch pencil painting - Photo 21 | Edvard Munch 3d render - Photo 22 | Ukiyo-e graffiti - Photo 23 | Ukiyo-e pencil painting - Photo 24 | Ukiyo-e 3d render - Photo 25 | Anime Painting - Photo 26 | Mona Lisa Painting - Photo 27 | 3D Render in the Style of Pixar - Photo 28 | Sketch - Photo 29 | A painting in Ukiyo-e style - Photo 30 | Werewolf - Photo 31 | Zombie - Photo -------------------------------------------------------------------------------- /text_domains/domain_list_20.txt: -------------------------------------------------------------------------------- 1 | Anime Painting - Photo 2 | Impressionism Painting - Photo 3 | Mona Lisa Painting - Photo 4 | 3D Render in the Style of Pixar - Photo 5 | Painting in the Style of Edvard Munch - Photo 6 | Cubism Painting - Photo 7 | Sketch - Photo 8 | Dali Painting - Photo 9 | Fernando Botero Painting - Photo 10 | A painting in Ukiyo-e style - Photo 11 | Tolkien Elf - Human 12 | Neanderthal - Human 13 | The Shrek - Human 14 | Zombie - Human 15 | The Hulk - Human 16 | The Thanos - Human 17 | Werewolf - Human 18 | Nicolas Cage - Human 19 | The Joker - Human 20 | Mark Zuckerberg - Human -------------------------------------------------------------------------------- /text_domains/domain_list_stress.txt: -------------------------------------------------------------------------------- 1 | Anime Painting - Photo 2 | Impressionism Painting - Photo 3 | Mona Lisa Painting - Photo 4 | 3D Render in the Style of Pixar - Photo 5 | Painting in the Style of Edvard Munch - Photo 6 | Cubism Painting - Photo 7 | Sketch - Photo 8 | Dali Painting - Photo 9 | Fernando Botero Painting - Photo 10 | A painting in Ukiyo-e style - Photo 11 | Modigliani Painting - Photo 12 | Pop Art - Photo 13 | A painting by Van Gogh - Photo 14 | A painting by Raphael - Photo 15 | Old-timey photograph - 
Photo 16 | Cartoon - Photo 17 | Documentary - Photo 18 | Natural painting - Photo 19 | Portrait - Photo 20 | Graffiti - Photo 21 | Minimalist - Photo 22 | Renaissance painting - Photo 23 | A painting in the style of Renaissance - Photo 24 | Japanese painting - Photo 25 | Manga - Photo 26 | A painting in the abstract style - Photo 27 | Surreal painting - Photo 28 | Vintage painting - Photo 29 | A painting in the style of Antiquity - Photo 30 | Statue - Human 31 | Tolkien Elf - Human 32 | Neanderthal - Human 33 | The Shrek - Human 34 | Zombie - Human 35 | The Hulk - Human 36 | The Thanos - Human 37 | Werewolf - Human 38 | Nicolas Cage - Human 39 | The Joker - Human 40 | Mark Zuckerberg - Human 41 | Donald Trump - Human 42 | Will Smith - Human 43 | Brad Pitt - Human 44 | Elon Musk - Human 45 | Cactus - Human 46 | White Walker - Human 47 | Vampire - Human 48 | Plastic Puppet - Human 49 | Witcher - Human -------------------------------------------------------------------------------- /text_domains/domain_list_style.txt: -------------------------------------------------------------------------------- 1 | Anime Painting - Photo 2 | Impressionism Painting - Photo 3 | Mona Lisa Painting - Photo 4 | 3D Render in the style of Pixar - Photo 5 | Painting in the Style of Edvard Munch - Photo 6 | Cubism Painting - Photo 7 | Sketch - Photo 8 | Dali Painting - Photo 9 | Fernando Botero Painting - Photo 10 | A painting in Ukiyo-e style - Photo -------------------------------------------------------------------------------- /text_domains/domain_list_test.txt: -------------------------------------------------------------------------------- 1 | Mona Lisa Painting - Photo 2 | 3D Render in the Style of Pixar - Photo -------------------------------------------------------------------------------- /text_domains/far_test_domains_2000.txt: -------------------------------------------------------------------------------- 1 | Caricature - Photo 2 | Mosaic - Photo 3 | Oil painting - Photo 4 | Pixar caricature - Photo 5 | Pixar mosaic - Photo 6 | Pixar oil painting - Photo 7 | Manga caricature - Photo 8 | Manga mosaic - Photo 9 | Manga oil painting - Photo 10 | Pablo Picasso caricature - Photo 11 | Pablo Picasso mosaic - Photo 12 | Pablo Picasso oil painting - Photo -------------------------------------------------------------------------------- /text_domains/mixed_content_train_domains.txt: -------------------------------------------------------------------------------- 1 | Tolkien Elf - Photo 2 | White Walker - Photo 3 | Gnom - Photo 4 | Goblin - Photo 5 | Banshee - Photo 6 | Demon - Photo 7 | The Shrek - Photo 8 | The Hulk - Photo 9 | The Thanos - Photo 10 | The Batman - Photo 11 | Witcher - Photo 12 | Vampire - Photo 13 | Homer Simpson - Photo 14 | Homo sapiens - Photo 15 | Emma Stone - Photo 16 | Hilary Clinton - Photo 17 | Rihanna - Photo 18 | Taylor Swift - Photo 19 | Beyonce - Photo 20 | Emma Watson - Photo 21 | Kim Kardashian - Photo 22 | Lady Gaga - Photo 23 | Chris Pratt - Photo 24 | Tom Hanks - Photo 25 | Tom Cruise - Photo 26 | Jack Nicholson - Photo 27 | Leonardo DiCaprio - Photo 28 | Johnny Depp - Photo 29 | Matt Daemon - Photo 30 | Robert Downey Jr. 
- Photo 31 | Will Smith - Photo 32 | Robet De Niro - Photo 33 | Morgan Freeman - Photo 34 | Keanu Reeves - Photo 35 | Brad Pitt - Photo 36 | Ryan Reynolds - Photo 37 | George Clooney - Photo 38 | Scarlett Johansson - Photo 39 | Nicole Kidman - Photo 40 | Bradley Cooper - Photo 41 | Matthew McConaughey - Photo 42 | Hugh Jackman - Photo 43 | Chris Evans - Photo 44 | Chris Hemsworth - Photo 45 | Daniel Radcliffe - Photo 46 | Orlando Bloom - Photo 47 | Freddie Mercury - Photo 48 | Eminem - Photo 49 | Elvis Presley - Photo 50 | Adele - Photo 51 | Ed Sheeran - Photo 52 | Katy Perry - Photo 53 | Donald Trump - Photo 54 | Nicolas Cage - Photo 55 | Mark Zuckerberg - Photo 56 | Elon Musk - Photo 57 | Statue - Photo 58 | Metal Statue - Photo 59 | Glass Statue - Photo 60 | Plastic Puppet - Photo 61 | Puppet - Photo 62 | Bust - Photo 63 | Woman - Photo 64 | Man - Photo -------------------------------------------------------------------------------- /text_domains/mixed_launch_content_domains.txt: -------------------------------------------------------------------------------- 1 | Emma Stone - Human 2 | Hilary Clinton - Human 3 | Rihanna - Human 4 | Taylor Swift - Human 5 | Beyonce - Human 6 | Emma Watson - Human 7 | Kim Kardashian - Human 8 | Lady Gaga - Human 9 | Chris Pratt - Human 10 | Tom Hanks - Human 11 | Tom Cruise - Human 12 | Jack Nicholson - Human 13 | Leonardo DiCaprio - Human 14 | Johnny Depp - Human 15 | Matt Damon - Human 16 | Robert Downey Jr. - Human 17 | Will Smith - Human 18 | Robet De Niro - Human 19 | Morgan Freeman - Human 20 | Keanu Reeves - Human 21 | Brad Pitt - Human 22 | Ryan Reynolds - Human 23 | George Clooney - Human 24 | Scarlett Johansson - Human 25 | Nicole Kidman - Human 26 | Bradley Cooper - Human 27 | Matthew McConaughey - Human 28 | Hugh Jackman - Human 29 | Chris Evans - Human 30 | Chris Hemsworth - Human 31 | Daniel Radcliffe - Human 32 | Orlando Bloom - Human 33 | Freddie Mercury - Human 34 | Eminem - Human 35 | Elvis Presley - Human 36 | Adele - Human 37 | Ed Sheeran - Human 38 | Katy Perry - Human 39 | Donald Trump - Human 40 | Nicolas Cage - Human 41 | Mark Zuckerberg - Human 42 | Elon Musk - Human 43 | Tolkien Elf - Human 44 | White Walker - Human 45 | Dwarf - Human 46 | Goblin - Human 47 | Banshee - Human 48 | Demon - Human 49 | The Shrek - Human 50 | The Hulk - Human 51 | The Thanos - Human 52 | The Batman - Human 53 | Witcher - Human 54 | Vampire - Human 55 | Homer Simpson - Human 56 | Homo Sapiens - Human 57 | Undead - Human 58 | Monster - Human 59 | Clown - Human 60 | Devil - Human 61 | Succubus - Human 62 | Ghost - Human 63 | Lycanthrope - Human 64 | Caveman - Human 65 | Primitive Man - Human 66 | Statue - Human 67 | Metal Statue - Human 68 | Stone Statue - Human 69 | Glass Statue - Human 70 | Plastic Puppet - Human 71 | Puppet - Human 72 | Stone Bust - Human 73 | Bust - Human 74 | Woman - Human 75 | Man - Human -------------------------------------------------------------------------------- /text_domains/mixed_style_train_domains.txt: -------------------------------------------------------------------------------- 1 | Portrait - Photo 2 | Image - Photo 3 | Photo - Photo 4 | Painting - Photo 5 | Graffiti - Photo 6 | Photograph - Photo 7 | Cartoon - Photo 8 | 3D render - Photo 9 | Drawing - Photo 10 | Graphics - Photo 11 | Mosaic - Photo 12 | Caricature - Photo 13 | Raphael - Photo 14 | Salvaror Dali - Photo 15 | Edvard Munch - Photo 16 | Modigliani - Photo 17 | Van Gogh - Photo 18 | Claude Monet - Photo 19 | Leonardo Da Vinci - Photo 20 | Pop Art - Photo 
21 | Impressionism - Photo 22 | Renaissance - Photo 23 | Abstract - Photo 24 | Vintage - Photo 25 | Antiquity - Photo 26 | Cubism - Photo 27 | Disney - Photo 28 | Russian - Photo 29 | Japanese - Photo 30 | Spanish - Photo 31 | Italian - Photo 32 | Dutch - Photo 33 | German - Photo 34 | Surreal - Photo 35 | WaltDisney - Photo 36 | DreamWorks - Photo 37 | Manga - Photo 38 | Modern - Photo 39 | Realism - Photo 40 | Starry Night - Photo 41 | Old-timey - Photo 42 | Pencil - Photo 43 | Gouache - Photo 44 | Acrylic - Photo 45 | Watercolor - Photo 46 | Oil - Photo 47 | Black - Photo 48 | White - Photo 49 | Blue - Photo 50 | Charcoal - Photo -------------------------------------------------------------------------------- /text_domains/mixed_train_domains.txt: -------------------------------------------------------------------------------- 1 | Portrait - Photo 2 | Image - Photo 3 | Photo - Photo 4 | Painting - Photo 5 | Graffiti - Photo 6 | Photograph - Photo 7 | Cartoon - Photo 8 | 3D render - Photo 9 | Drawing - Photo 10 | Graphics - Photo 11 | Mosaic - Photo 12 | Caricature - Photo 13 | Raphael - Photo 14 | Salvaror Dali - Photo 15 | Edvard Munch - Photo 16 | Modigliani - Photo 17 | Van Gogh - Photo 18 | Claude Monet - Photo 19 | Leonardo Da Vinci - Photo 20 | Pop Art - Photo 21 | Impressionism - Photo 22 | Renaissance - Photo 23 | Abstract - Photo 24 | Vintage - Photo 25 | Antiquity - Photo 26 | Cubism - Photo 27 | Disney - Photo 28 | Russian - Photo 29 | Japanese - Photo 30 | Spanish - Photo 31 | Italian - Photo 32 | Dutch - Photo 33 | German - Photo 34 | Surreal - Photo 35 | WaltDisney - Photo 36 | DreamWorks - Photo 37 | Manga - Photo 38 | Modern - Photo 39 | Realism - Photo 40 | Starry Night - Photo 41 | Old-timey - Photo 42 | Pencil - Photo 43 | Gouache - Photo 44 | Acrylic - Photo 45 | Watercolor - Photo 46 | Oil - Photo 47 | Black - Photo 48 | White - Photo 49 | Blue - Photo 50 | Charcoal - Photo 51 | Tolkien Elf - Human 52 | White Walker - Human 53 | Gnom - Human 54 | Goblin - Human 55 | Banshee - Human 56 | Demon - Human 57 | The Shrek - Human 58 | The Hulk - Human 59 | The Thanos - Human 60 | The Batman - Human 61 | Witcher - Human 62 | Vampire - Human 63 | Homer Simpson - Human 64 | Homo sapiens - Human 65 | Emma Stone - Human 66 | Hilary Clinton - Human 67 | Rihanna - Human 68 | Taylor Swift - Human 69 | Beyonce - Human 70 | Emma Watson - Human 71 | Kim Kardashian - Human 72 | Lady Gaga - Human 73 | Chris Pratt - Human 74 | Tom Hanks - Human 75 | Tom Cruise - Human 76 | Jack Nicholson - Human 77 | Leonardo DiCaprio - Human 78 | Johnny Depp - Human 79 | Matt Daemon - Human 80 | Robert Downey Jr. 
- Human 81 | Will Smith - Human 82 | Robet De Niro - Human 83 | Morgan Freeman - Human 84 | Keanu Reeves - Human 85 | Brad Pitt - Human 86 | Ryan Reynolds - Human 87 | George Clooney - Human 88 | Scarlett Johansson - Human 89 | Nicole Kidman - Human 90 | Bradley Cooper - Human 91 | Matthew McConaughey - Human 92 | Hugh Jackman - Human 93 | Chris Evans - Human 94 | Chris Hemsworth - Human 95 | Daniel Radcliffe - Human 96 | Orlando Bloom - Human 97 | Freddie Mercury - Human 98 | Eminem - Human 99 | Elvis Presley - Human 100 | Adele - Human 101 | Ed Sheeran - Human 102 | Katy Perry - Human 103 | Donald Trump - Human 104 | Nicolas Cage - Human 105 | Mark Zuckerberg - Human 106 | Elon Musk - Human 107 | Statue - Human 108 | Metal Statue - Human 109 | Glass Statue - Human 110 | Plastic Puppet - Human 111 | Puppet - Human 112 | Bust - Human 113 | Woman - Human 114 | Man - Human -------------------------------------------------------------------------------- /text_domains/train_synonyms_domains.txt: -------------------------------------------------------------------------------- 1 | Portrait - Photo 2 | Image - Photo 3 | Photo - Photo 4 | Painting - Photo 5 | Graffiti - Photo 6 | Pencil Painting - Photo 7 | Photograph - Photo 8 | Cartoon - Photo 9 | Stereo View - Photo 10 | Animation - Photo 11 | Drawing - Photo 12 | Dancers at the bar - Photo 13 | Una Familia - Photo 14 | Graphics - Photo 15 | Mosaic - Photo 16 | Caricature - Photo 17 | Raphael - Photo 18 | Salvaror Dali - Photo 19 | Edvard Munch - Photo 20 | Modigliani - Photo 21 | Van Gogh - Photo 22 | Claude Monet - Photo 23 | Leonardo Da Vinci - Photo 24 | Pop Art - Photo 25 | Impressionism - Photo 26 | Renaissance - Photo 27 | Abstract - Photo 28 | Vintage - Photo 29 | Antiquity - Photo 30 | Cubism - Photo 31 | Disney - Photo 32 | Russian - Photo 33 | Chineese - Photo 34 | Japanese - Photo 35 | Spanish - Photo 36 | Italian - Photo 37 | Dutch - Photo 38 | German - Photo 39 | Surreal - Photo 40 | WaltDisney - Photo 41 | DreamWorks - Photo 42 | Manga - Photo 43 | Komodo - Photo 44 | Modern - Photo 45 | Realism - Photo 46 | Starry Night - Photo 47 | Old-timey - Photo 48 | Pencil - Photo 49 | Gouache - Photo 50 | Acrylic - Photo 51 | Watercolor - Photo 52 | Oil - Photo 53 | Black - Photo 54 | White - Photo 55 | Blue - Photo 56 | Charcoal - Photo 57 | Tolkien Elf - Human 58 | White Walker - Human 59 | Gnom - Human 60 | Goblin - Human 61 | Banshee - Human 62 | Demon - Human 63 | The Shrek - Human 64 | The Hulk - Human 65 | The Thanos - Human 66 | The Batman - Human 67 | Witcher - Human 68 | Vampire - Human 69 | Homer Simpson - Human 70 | Homo Sapiens - Human 71 | Emma Stone - Human 72 | Hilary Clinton - Human 73 | Rihanna - Human 74 | Taylor Swift - Human 75 | Beyonce - Human 76 | Emma Watson - Human 77 | Kim Kardashian - Human 78 | Lady Gaga - Human 79 | Chris Pratt - Human 80 | Tom Hanks - Human 81 | Tom Cruise - Human 82 | Jack Nicholson - Human 83 | Leonardo DiCaprio - Human 84 | Johnny Depp - Human 85 | Matt Daemon - Human 86 | Robert Downey Jr. 
- Human 87 | Will Smith - Human 88 | Robet De Niro - Human 89 | Morgan Freeman - Human 90 | Keanu Reeves - Human 91 | Brad Pitt - Human 92 | Ryan Reynolds - Human 93 | George Clooney - Human 94 | Scarlett Johansson - Human 95 | Nicole Kidman - Human 96 | Bradley Cooper - Human 97 | Matthew McConaughey - Human 98 | Hugh Jackman - Human 99 | Chris Evans - Human 100 | Chris Hemsworth - Human 101 | Daniel Radcliffe - Human 102 | Orlando Bloom - Human 103 | Freddie Mercury - Human 104 | Eminem - Human 105 | Elvis Presley - Human 106 | Adele - Human 107 | Ed Sheeran - Human 108 | Katy Perry - Human 109 | Donald Trump - Human 110 | Nicolas Cage - Human 111 | Mark Zuckerberg - Human 112 | Elon Musk - Human 113 | Statue - Human 114 | Metal Statue - Human 115 | Glass Statue - Human 116 | Plastic Puppet - Human 117 | Puppet - Human 118 | Bust - Human 119 | Woman - Human 120 | Man - Human 121 | Undead - Human 122 | Monster - Human 123 | Clown - Human 124 | Devil - Human 125 | Succubus - Human 126 | Ghost - Human 127 | Lycanthrope - Human 128 | Caveman - Human 129 | Primitive Man - Human --------------------------------------------------------------------------------
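Each file in text_domains/ stores one target-source pair per line in the form "Target description - Source description" (for example, "Anime Painting - Photo" or "Zombie - Human"). A minimal parsing sketch, not part of the repository: read_domain_pairs is a hypothetical helper, and it assumes the last " - " on a line is the separator, which keeps hyphenated targets such as "Old-timey - Photo" intact.

from pathlib import Path
from typing import List, Tuple

def read_domain_pairs(path: str) -> List[Tuple[str, str]]:
    """Parse a text_domains list into (target, source) description pairs."""
    pairs = []
    for raw_line in Path(path).read_text().splitlines():
        line = raw_line.strip()
        if not line:
            continue
        # Split on the last " - " so hyphens inside the target are preserved.
        target, _, source = line.rpartition(' - ')
        pairs.append((target.strip(), source.strip()))
    return pairs

# Example: read_domain_pairs('text_domains/domain_list_test.txt')
# -> [('Mona Lisa Painting', 'Photo'), ('3D Render in the Style of Pixar', 'Photo')]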