├── .gitignore ├── DissimilarDomains ├── .gitignore ├── BASE_README.md ├── Dockerfile ├── LICENSE.txt ├── README.md ├── calc_metrics.py ├── dataset_tool.py ├── dnnlib │ ├── __init__.py │ └── util.py ├── docker_run.sh ├── docs │ ├── dataset-tool-help.txt │ ├── license.html │ ├── stylegan2-ada-teaser-1024x252.png │ ├── stylegan2-ada-training-curves.png │ └── train-help.txt ├── editing │ └── styleflow │ │ ├── README.md │ │ ├── cnf.py │ │ ├── diffeq_layers.py │ │ ├── editor.py │ │ ├── flow.py │ │ ├── normalization.py │ │ ├── odefunc.py │ │ └── utils.py ├── examples │ ├── General Results ICCV.ipynb │ ├── I2I ICCV.ipynb │ ├── Quick Start ICCV.ipynb │ ├── Semantic Editing ICCV.ipynb │ ├── inverted_samples │ │ ├── 41_0 │ │ │ ├── 0_proj.png │ │ │ ├── 0_projected_z.npz │ │ │ ├── 0_target.png │ │ │ └── projected_z.npz │ │ └── pixabay_dog_003552 │ │ │ ├── 0_proj.png │ │ │ ├── 0_projected_z.npz │ │ │ ├── 0_target.png │ │ │ └── projected_z.npz │ └── nb_utils.py ├── generate.py ├── legacy.py ├── metrics │ ├── __init__.py │ ├── frechet_inception_distance.py │ ├── inception_score.py │ ├── kernel_inception_distance.py │ ├── metric_main.py │ ├── metric_utils.py │ ├── perceptual_path_length.py │ └── precision_recall.py ├── projector.py ├── samples │ ├── 0_0.jpg │ ├── 31_0.jpg │ ├── 41_0.jpg │ ├── 46_0.jpg │ └── pixabay_dog_003552.jpg ├── style_mixing.py ├── torch_utils │ ├── __init__.py │ ├── custom_ops.py │ ├── misc.py │ ├── ops │ │ ├── __init__.py │ │ ├── bias_act.cpp │ │ ├── bias_act.cu │ │ ├── bias_act.h │ │ ├── bias_act.py │ │ ├── conv2d_gradfix.py │ │ ├── conv2d_resample.py │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.cu │ │ ├── upfirdn2d.h │ │ └── upfirdn2d.py │ ├── persistence.py │ └── training_stats.py ├── train.py └── training │ ├── __init__.py │ ├── augment.py │ ├── dataset.py │ ├── loss.py │ ├── networks.py │ └── training_loop.py ├── README.md ├── SimilarDomains ├── README.md ├── configs │ ├── im2im_difa.yaml │ ├── im2im_difa_for_low_memory.yaml │ ├── im2im_difa_hdn.yaml │ ├── im2im_difa_sdelta.yaml │ ├── im2im_jojo.yaml │ ├── im2im_jojo_for_low_memory.yaml │ ├── im2im_jojo_sdelta.yaml │ ├── im2im_mtg.yaml │ ├── im2im_mtg_for_low_memory.yaml │ ├── im2im_mtg_sdelta.yaml │ ├── td_nada_cars.yaml │ ├── td_nada_ffhq.yaml │ └── td_nada_ffhq_sdelta.yaml ├── core │ ├── __init__.py │ ├── dataset.py │ ├── evaluation.py │ ├── loss.py │ ├── lpips │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── dist_model.py │ │ ├── networks_basic.py │ │ ├── pretrained_networks.py │ │ └── weights │ │ │ ├── v0.0 │ │ │ ├── alex.pth │ │ │ ├── squeeze.pth │ │ │ └── vgg.pth │ │ │ └── v0.1 │ │ │ ├── alex.pth │ │ │ ├── squeeze.pth │ │ │ └── vgg.pth │ ├── mappers.py │ ├── parametrizations.py │ ├── sparse_models.py │ ├── style_embed_options.py │ ├── stylegan_patches.py │ ├── uda_models.py │ └── utils │ │ ├── .ipynb_checkpoints │ │ └── example_utils-checkpoint.py │ │ ├── II2S.py │ │ ├── __init__.py │ │ ├── arguments.py │ │ ├── class_registry.py │ │ ├── common.py │ │ ├── example_utils.py │ │ ├── fid.py │ │ ├── gif.py │ │ ├── image_utils.py │ │ ├── karras_to_rosinality.py │ │ ├── loggers.py │ │ ├── loss_utils.py │ │ ├── math_utils.py │ │ ├── notebook.py │ │ ├── reading_weights.py │ │ ├── text_templates.py │ │ └── train_log.py ├── download.py ├── editing │ ├── __init__.py │ ├── interfacegan_directions │ │ ├── age.pt │ │ ├── gender.pt │ │ ├── rotation.pt │ │ └── smile.pt │ ├── latent_editor_wrapper.py │ └── styleflow │ │ ├── cnf.py │ │ ├── diffeq_layers.py │ │ ├── editor.py │ │ ├── flow.py │ │ ├── 
normalization.py │ │ ├── odefunc.py │ │ └── utils.py ├── examples │ ├── adaptation_in_finetuned_gan.ipynb │ ├── combined_morphing.ipynb │ ├── draw_util.py │ ├── editing.ipynb │ ├── multiple_morphing.ipynb │ ├── photos │ │ └── elon_musk.jpeg │ └── pruned_forward.ipynb ├── gan_models │ ├── BigGAN │ │ ├── BigGAN.py │ │ ├── __init__.py │ │ ├── generator_config.json │ │ ├── layers.py │ │ ├── sync_batchnorm │ │ │ ├── __init__.py │ │ │ ├── batchnorm.py │ │ │ ├── batchnorm_reimpl.py │ │ │ ├── comm.py │ │ │ ├── replicate.py │ │ │ └── unittest.py │ │ └── utils.py │ ├── ProgGAN │ │ ├── __init__.py │ │ └── model.py │ ├── SNGAN │ │ ├── __init__.py │ │ ├── distribution.py │ │ ├── load.py │ │ └── sn_gen_resnet.py │ ├── StyleGAN2 │ │ ├── convert_weight.py │ │ ├── model.py │ │ ├── nvidia.py │ │ ├── nvidia_offsets.py │ │ ├── offsets_model.py │ │ ├── op │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ ├── fused_act_torch_native.py │ │ │ ├── fused_bias_act.cpp │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.cpp │ │ │ ├── upfirdn2d.py │ │ │ ├── upfirdn2d_kernel.cu │ │ │ └── upfirdn2d_torch_native.py │ │ └── prop_convert.py │ ├── __init__.py │ ├── gan_load.py │ ├── gan_with_shift.py │ └── imagenet_classes.json ├── main.py ├── requirements.txt ├── restyle_encoders │ ├── .ipynb_checkpoints │ │ └── psp-checkpoint.py │ ├── __init__.py │ ├── download.py │ ├── e4e.py │ ├── e4e_modules │ │ ├── __init__.py │ │ ├── discriminator.py │ │ └── latent_codes_pool.py │ ├── e4e_restyle.py │ ├── encoders │ │ ├── __init__.py │ │ ├── fpn_encoders.py │ │ ├── helpers.py │ │ ├── map2style.py │ │ ├── model_irse.py │ │ ├── psp_encoders.py │ │ ├── restyle_e4e_encoders.py │ │ └── restyle_psp_encoders.py │ ├── mtcnn │ │ ├── __init__.py │ │ ├── mtcnn.py │ │ └── mtcnn_pytorch │ │ │ ├── __init__.py │ │ │ └── src │ │ │ ├── __init__.py │ │ │ ├── align_trans.py │ │ │ ├── box_utils.py │ │ │ ├── detector.py │ │ │ ├── first_stage.py │ │ │ ├── get_nets.py │ │ │ ├── matlab_cp2tform.py │ │ │ ├── visualization_utils.py │ │ │ └── weights │ │ │ ├── onet.npy │ │ │ ├── pnet.npy │ │ │ └── rnet.npy │ └── psp.py └── trainers.py └── img ├── Figure-FewShot.png ├── diagram.png ├── few_shot_domains.png ├── one_shot_domains.png ├── style_domain_transfer.png └── titan_armin_joker_pixar.png /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | **/.idea/* 3 | __pycache__/ 4 | wandb/ 5 | pretrained/ 6 | .DS_Store 7 | -------------------------------------------------------------------------------- /DissimilarDomains/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .cache/ 3 | run*.sh 4 | calc*.sh 5 | project*.sh 6 | examples/afhq -------------------------------------------------------------------------------- /DissimilarDomains/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
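# Usage sketch (illustrative; only the default image tag below is taken from
# docker_run.sh in this repository): docker_run.sh expects this image to be
# tagged `sg2ada:latest`, so a typical build-and-run flow would look roughly like
#
#   docker build --tag sg2ada:latest .
#   ./docker_run.sh python generate.py --help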
8 | 9 | FROM nvcr.io/nvidia/pytorch:20.12-py3 10 | 11 | ENV PYTHONDONTWRITEBYTECODE 1 12 | ENV PYTHONUNBUFFERED 1 13 | 14 | RUN pip install imageio-ffmpeg==0.4.3 pyspng==0.1.0 15 | 16 | WORKDIR /workspace 17 | 18 | # Unset TORCH_CUDA_ARCH_LIST and exec. This makes pytorch run-time 19 | # extension builds significantly faster as we only compile for the 20 | # currently active GPU configuration. 21 | RUN (printf '#!/bin/bash\nunset TORCH_CUDA_ARCH_LIST\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 22 | ENTRYPOINT ["/entry.sh"] 23 | -------------------------------------------------------------------------------- /DissimilarDomains/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. 
As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= 98 | -------------------------------------------------------------------------------- /DissimilarDomains/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /DissimilarDomains/docker_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
10 | 11 | set -e 12 | 13 | # Wrapper script for setting up `docker run` to properly 14 | # cache downloaded files, custom extension builds and 15 | # mount the source directory into the container and make it 16 | # run as non-root user. 17 | # 18 | # Use it like: 19 | # 20 | # ./docker_run.sh python generate.py --help 21 | # 22 | # To override the default `stylegan2ada:latest` image, run: 23 | # 24 | # IMAGE=my_image:v1.0 ./docker_run.sh python generate.py --help 25 | # 26 | 27 | rest=$@ 28 | 29 | IMAGE="${IMAGE:-sg2ada:latest}" 30 | 31 | CONTAINER_ID=$(docker inspect --format="{{.Id}}" ${IMAGE} 2> /dev/null) 32 | if [[ "${CONTAINER_ID}" ]]; then 33 | docker run --shm-size=2g --gpus all -it --rm -v `pwd`:/scratch --user $(id -u):$(id -g) \ 34 | --workdir=/scratch -e HOME=/scratch $IMAGE $@ 35 | else 36 | echo "Unknown container image: ${IMAGE}" 37 | exit 1 38 | fi 39 | -------------------------------------------------------------------------------- /DissimilarDomains/docs/dataset-tool-help.txt: -------------------------------------------------------------------------------- 1 | Usage: dataset_tool.py [OPTIONS] 2 | 3 | Convert an image dataset into a dataset archive usable with StyleGAN2 ADA 4 | PyTorch. 5 | 6 | The input dataset format is guessed from the --source argument: 7 | 8 | --source *_lmdb/ - Load LSUN dataset 9 | --source cifar-10-python.tar.gz - Load CIFAR-10 dataset 10 | --source path/ - Recursively load all images from path/ 11 | --source dataset.zip - Recursively load all images from dataset.zip 12 | 13 | The output dataset format can be either an image folder or a zip archive. 14 | Specifying the output format and path: 15 | 16 | --dest /path/to/dir - Save output files under /path/to/dir 17 | --dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive 18 | 19 | Images within the dataset archive will be stored as uncompressed PNG. 20 | 21 | Image scale/crop and resolution requirements: 22 | 23 | Output images must be square-shaped and they must all have the same power- 24 | of-two dimensions. 25 | 26 | To scale arbitrary input image size to a specific width and height, use 27 | the --width and --height options. Output resolution will be either the 28 | original input resolution (if --width/--height was not specified) or the 29 | one specified with --width/height. 30 | 31 | Use the --transform=center-crop or --transform=center-crop-wide options to 32 | apply a center crop transform on the input image. These options should be 33 | used with the --width and --height options. For example: 34 | 35 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \ 36 | --transform=center-crop-wide --width 512 --height=384 37 | 38 | Options: 39 | --source PATH Directory or archive name for input dataset 40 | [required] 41 | --dest PATH Output directory or archive name for output 42 | dataset [required] 43 | --max-images INTEGER Output only up to `max-images` images 44 | --resize-filter [box|lanczos] Filter to use when resizing images for 45 | output resolution [default: lanczos] 46 | --transform [center-crop|center-crop-wide] 47 | Input crop/resize mode 48 | --width INTEGER Output width 49 | --height INTEGER Output height 50 | --help Show this message and exit. 
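Example (illustrative only; the paths below are hypothetical and only options
documented above are used). Converting a folder of images into a zip archive of
256x256 center-cropped PNGs could look like:

    python dataset_tool.py --source=~/my-image-folder --dest=~/datasets/mydata.zip \
        --transform=center-crop --width=256 --height=256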
51 | -------------------------------------------------------------------------------- /DissimilarDomains/docs/stylegan2-ada-teaser-1024x252.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/docs/stylegan2-ada-teaser-1024x252.png -------------------------------------------------------------------------------- /DissimilarDomains/docs/stylegan2-ada-training-curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/docs/stylegan2-ada-training-curves.png -------------------------------------------------------------------------------- /DissimilarDomains/docs/train-help.txt: -------------------------------------------------------------------------------- 1 | Usage: train.py [OPTIONS] 2 | 3 | Train a GAN using the techniques described in the paper "Training 4 | Generative Adversarial Networks with Limited Data". 5 | 6 | Examples: 7 | 8 | # Train with custom images using 1 GPU. 9 | python train.py --outdir=~/training-runs --data=~/my-image-folder 10 | 11 | # Train class-conditional CIFAR-10 using 2 GPUs. 12 | python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \ 13 | --gpus=2 --cfg=cifar --cond=1 14 | 15 | # Transfer learn MetFaces from FFHQ using 4 GPUs. 16 | python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \ 17 | --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10 18 | 19 | # Reproduce original StyleGAN2 config F. 20 | python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \ 21 | --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug 22 | 23 | Base configs (--cfg): 24 | auto Automatically select reasonable defaults based on resolution 25 | and GPU count. Good starting point for new datasets. 26 | stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024. 27 | paper256 Reproduce results for FFHQ and LSUN Cat at 256x256. 28 | paper512 Reproduce results for BreCaHAD and AFHQ at 512x512. 29 | paper1024 Reproduce results for MetFaces at 1024x1024. 30 | cifar Reproduce results for CIFAR-10 at 32x32. 31 | 32 | Transfer learning source networks (--resume): 33 | ffhq256 FFHQ trained at 256x256 resolution. 34 | ffhq512 FFHQ trained at 512x512 resolution. 35 | ffhq1024 FFHQ trained at 1024x1024 resolution. 36 | celebahq256 CelebA-HQ trained at 256x256 resolution. 37 | lsundog256 LSUN Dog trained at 256x256 resolution. 38 | Custom network pickle. 
39 | 40 | Options: 41 | --outdir DIR Where to save the results [required] 42 | --gpus INT Number of GPUs to use [default: 1] 43 | --snap INT Snapshot interval [default: 50 ticks] 44 | --metrics LIST Comma-separated list or "none" [default: 45 | fid50k_full] 46 | --seed INT Random seed [default: 0] 47 | -n, --dry-run Print training options and exit 48 | --data PATH Training data (directory or zip) [required] 49 | --cond BOOL Train conditional model based on dataset 50 | labels [default: false] 51 | --subset INT Train with only N images [default: all] 52 | --mirror BOOL Enable dataset x-flips [default: false] 53 | --cfg [auto|stylegan2|paper256|paper512|paper1024|cifar] 54 | Base config [default: auto] 55 | --gamma FLOAT Override R1 gamma 56 | --kimg INT Override training duration 57 | --batch INT Override batch size 58 | --aug [noaug|ada|fixed] Augmentation mode [default: ada] 59 | --p FLOAT Augmentation probability for --aug=fixed 60 | --target FLOAT ADA target value for --aug=ada 61 | --augpipe [blit|geom|color|filter|noise|cutout|bg|bgc|bgcf|bgcfn|bgcfnc] 62 | Augmentation pipeline [default: bgc] 63 | --resume PKL Resume training [default: noresume] 64 | --freezed INT Freeze-D [default: 0 layers] 65 | --fp32 BOOL Disable mixed-precision training 66 | --nhwc BOOL Use NHWC memory format with FP16 67 | --nobench BOOL Disable cuDNN benchmarking 68 | --allow-tf32 BOOL Allow PyTorch to use TF32 internally 69 | --workers INT Override number of DataLoader workers 70 | --help Show this message and exit. 71 | -------------------------------------------------------------------------------- /DissimilarDomains/editing/styleflow/cnf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchdiffeq import odeint_adjoint 4 | from torchdiffeq import odeint as odeint_normal 5 | 6 | __all__ = ["CNF", "SequentialFlow"] 7 | 8 | 9 | class SequentialFlow(nn.Module): 10 | """A generalized nn.Sequential container for normalizing flows.""" 11 | 12 | def __init__(self, layer_list): 13 | super(SequentialFlow, self).__init__() 14 | self.chain = nn.ModuleList(layer_list) 15 | 16 | def forward(self, x, context, logpx=None, reverse=False, inds=None, integration_times=None): 17 | if inds is None: 18 | if reverse: 19 | inds = range(len(self.chain) - 1, -1, -1) 20 | else: 21 | inds = range(len(self.chain)) 22 | 23 | if logpx is None: 24 | for i in inds: 25 | # print(x.shape) 26 | x = self.chain[i](x, context, logpx, integration_times, reverse) 27 | return x 28 | else: 29 | for i in inds: 30 | x, logpx = self.chain[i](x, context, logpx, integration_times, reverse) 31 | return x, logpx 32 | 33 | 34 | class CNF(nn.Module): 35 | def __init__(self, odefunc, conditional=True, T=1.0, train_T=False, regularization_fns=None, 36 | solver='dopri5', atol=1e-5, rtol=1e-5, use_adjoint=True): 37 | super(CNF, self).__init__() 38 | self.train_T = train_T 39 | self.T = T 40 | if train_T: 41 | self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T)))) 42 | print("Training T :", self.T) 43 | 44 | if regularization_fns is not None and len(regularization_fns) > 0: 45 | raise NotImplementedError("Regularization not supported") 46 | self.use_adjoint = use_adjoint 47 | self.odefunc = odefunc 48 | self.solver = solver 49 | self.atol = atol 50 | self.rtol = rtol 51 | self.test_solver = solver 52 | self.test_atol = atol 53 | self.test_rtol = rtol 54 | self.solver_options = {} 55 | self.conditional = conditional 56 | 57 | def forward(self, x, 
context=None, logpx=None, integration_times=None, reverse=False): 58 | if logpx is None: 59 | _logpx = torch.zeros(*x.shape[:-1], 1).to(x) 60 | else: 61 | _logpx = logpx 62 | 63 | if self.conditional: 64 | assert context is not None 65 | states = (x, _logpx, context) 66 | atol = [self.atol] * 3 67 | rtol = [self.rtol] * 3 68 | else: 69 | states = (x, _logpx) 70 | atol = [self.atol] * 2 71 | rtol = [self.rtol] * 2 72 | 73 | if integration_times is None: 74 | if self.train_T: 75 | integration_times = torch.stack( 76 | [torch.tensor(0.0).to(x), self.sqrt_end_time * self.sqrt_end_time] 77 | ).to(x) 78 | # print("integration times:", integration_times) 79 | else: 80 | integration_times = torch.tensor([0., self.T], requires_grad=False).to(x) 81 | 82 | if reverse: 83 | integration_times = _flip(integration_times, 0) 84 | 85 | # Refresh the odefunc statistics. 86 | self.odefunc.before_odeint() 87 | odeint = odeint_adjoint if self.use_adjoint else odeint_normal 88 | if self.training: 89 | state_t = odeint( 90 | self.odefunc, 91 | states, 92 | integration_times.to(x), 93 | atol=atol, 94 | rtol=rtol, 95 | method=self.solver, 96 | options=self.solver_options, 97 | ) 98 | else: 99 | state_t = odeint( 100 | self.odefunc, 101 | states, 102 | integration_times.to(x), 103 | atol=self.test_atol, 104 | rtol=self.test_rtol, 105 | method=self.test_solver, 106 | ) 107 | 108 | if len(integration_times) == 2: 109 | 110 | state_t = tuple(s[1] for s in state_t) 111 | 112 | 113 | 114 | z_t, logpz_t = state_t[:2] 115 | 116 | if logpx is not None: 117 | return z_t, logpz_t 118 | else: 119 | return z_t 120 | 121 | def num_evals(self): 122 | return self.odefunc._num_evals.item() 123 | 124 | 125 | def _flip(x, dim): 126 | indices = [slice(None)] * x.dim() 127 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) 128 | return x[tuple(indices)] 129 | -------------------------------------------------------------------------------- /DissimilarDomains/editing/styleflow/diffeq_layers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def weights_init(m): 7 | classname = m.__class__.__name__ 8 | if classname.find('Linear') != -1 or classname.find('Conv') != -1: 9 | nn.init.constant_(m.weight, 0) 10 | nn.init.normal_(m.bias, 0, 0.01) 11 | 12 | 13 | class IgnoreLinear(nn.Module): 14 | def __init__(self, dim_in, dim_out, dim_c): 15 | super(IgnoreLinear, self).__init__() 16 | self._layer = nn.Linear(dim_in, dim_out) 17 | 18 | def forward(self, context, x): 19 | return self._layer(x) 20 | 21 | 22 | class ConcatLinear(nn.Module): 23 | def __init__(self, dim_in, dim_out, dim_c): 24 | super(ConcatLinear, self).__init__() 25 | self._layer = nn.Linear(dim_in + 1 + dim_c, dim_out) 26 | 27 | def forward(self, context, x, c): 28 | if x.dim() == 3: 29 | context = context.unsqueeze(1).expand(-1, x.size(1), -1) 30 | x_context = torch.cat((x, context), dim=2) 31 | return self._layer(x_context) 32 | 33 | 34 | class ConcatLinear_v2(nn.Module): 35 | def __init__(self, dim_in, dim_out, dim_c): 36 | super(ConcatLinear_v2, self).__init__() 37 | self._layer = nn.Linear(dim_in, dim_out) 38 | self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False) 39 | 40 | def forward(self, context, x): 41 | bias = self._hyper_bias(context) 42 | if x.dim() == 3: 43 | bias = bias.unsqueeze(1) 44 | return self._layer(x) + bias 45 | 46 | 47 | class SquashLinear(nn.Module): 48 | def __init__(self, dim_in, dim_out, dim_c): 49 | 
super(SquashLinear, self).__init__() 50 | self._layer = nn.Linear(dim_in, dim_out) 51 | self._hyper = nn.Linear(1 + dim_c, dim_out) 52 | 53 | def forward(self, context, x): 54 | gate = torch.sigmoid(self._hyper(context)) 55 | if x.dim() == 3: 56 | gate = gate.unsqueeze(1) 57 | return self._layer(x) * gate 58 | 59 | 60 | class ScaleLinear(nn.Module): 61 | def __init__(self, dim_in, dim_out, dim_c): 62 | super(ScaleLinear, self).__init__() 63 | self._layer = nn.Linear(dim_in, dim_out) 64 | self._hyper = nn.Linear(1 + dim_c, dim_out) 65 | 66 | def forward(self, context, x): 67 | gate = self._hyper(context) 68 | if x.dim() == 3: 69 | gate = gate.unsqueeze(1) 70 | return self._layer(x) * gate 71 | 72 | 73 | class ConcatSquashLinear(nn.Module): 74 | def __init__(self, dim_in, dim_out, dim_c): 75 | super(ConcatSquashLinear, self).__init__() 76 | self._layer = nn.Linear(dim_in, dim_out) 77 | self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False) 78 | self._hyper_gate = nn.Linear(1 + dim_c, dim_out) 79 | 80 | def forward(self, context, x): 81 | gate = torch.sigmoid(self._hyper_gate(context)) 82 | bias = self._hyper_bias(context) 83 | if x.dim() == 3: 84 | gate = gate.unsqueeze(1) 85 | bias = bias.unsqueeze(1) 86 | ret = self._layer(x) * gate + bias 87 | return ret 88 | 89 | 90 | class ConcatScaleLinear(nn.Module): 91 | def __init__(self, dim_in, dim_out, dim_c): 92 | super(ConcatScaleLinear, self).__init__() 93 | self._layer = nn.Linear(dim_in, dim_out) 94 | self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False) 95 | self._hyper_gate = nn.Linear(1 + dim_c, dim_out) 96 | 97 | def forward(self, context, x): 98 | gate = self._hyper_gate(context) 99 | bias = self._hyper_bias(context) 100 | if x.dim() == 3: 101 | gate = gate.unsqueeze(1) 102 | bias = bias.unsqueeze(1) 103 | ret = self._layer(x) * gate + bias 104 | return ret 105 | -------------------------------------------------------------------------------- /DissimilarDomains/editing/styleflow/flow.py: -------------------------------------------------------------------------------- 1 | from .odefunc import ODEfunc, ODEnet 2 | from .normalization import MovingBatchNorm1d 3 | from .cnf import CNF, SequentialFlow 4 | 5 | 6 | def count_nfe(model): 7 | class AccNumEvals(object): 8 | 9 | def __init__(self): 10 | self.num_evals = 0 11 | 12 | def __call__(self, module): 13 | if isinstance(module, CNF): 14 | self.num_evals += module.num_evals() 15 | 16 | accumulator = AccNumEvals() 17 | model.apply(accumulator) 18 | return accumulator.num_evals 19 | 20 | 21 | def count_parameters(model): 22 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 23 | 24 | 25 | def count_total_time(model): 26 | class Accumulator(object): 27 | 28 | def __init__(self): 29 | self.total_time = 0 30 | 31 | def __call__(self, module): 32 | if isinstance(module, CNF): 33 | self.total_time = self.total_time + module.sqrt_end_time * module.sqrt_end_time 34 | 35 | accumulator = Accumulator() 36 | model.apply(accumulator) 37 | return accumulator.total_time 38 | 39 | 40 | def build_model( input_dim, hidden_dims, context_dim, num_blocks, conditional): 41 | def build_cnf(): 42 | diffeq = ODEnet( 43 | hidden_dims=hidden_dims, 44 | input_shape=(input_dim,), 45 | context_dim=context_dim, 46 | layer_type='concatsquash', 47 | nonlinearity='tanh', 48 | ) 49 | odefunc = ODEfunc( 50 | diffeq=diffeq, 51 | ) 52 | cnf = CNF( 53 | odefunc=odefunc, 54 | T=1.0, 55 | train_T=True, 56 | conditional=conditional, 57 | solver='dopri5', 58 | use_adjoint=True, 59 | atol=1e-5, 60 | 
rtol=1e-5, 61 | ) 62 | return cnf 63 | 64 | chain = [build_cnf() for _ in range(num_blocks)] 65 | bn_layers = [MovingBatchNorm1d(input_dim, bn_lag=0, sync=False) 66 | for _ in range(num_blocks)] 67 | bn_chain = [MovingBatchNorm1d(input_dim, bn_lag=0, sync=False)] 68 | for a, b in zip(chain, bn_layers): 69 | bn_chain.append(a) 70 | bn_chain.append(b) 71 | chain = bn_chain 72 | model = SequentialFlow(chain) 73 | 74 | return model 75 | 76 | 77 | def cnf(input_dim,dims,zdim,num_blocks): 78 | dims = tuple(map(int, dims.split("-"))) 79 | model = build_model(input_dim, dims, zdim, num_blocks, True).cuda() 80 | print("Number of trainable parameters of Point CNF: {}".format(count_parameters(model))) 81 | return model 82 | 83 | 84 | -------------------------------------------------------------------------------- /DissimilarDomains/examples/inverted_samples/41_0/0_proj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/examples/inverted_samples/41_0/0_proj.png -------------------------------------------------------------------------------- /DissimilarDomains/examples/inverted_samples/41_0/0_projected_z.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/examples/inverted_samples/41_0/0_projected_z.npz -------------------------------------------------------------------------------- /DissimilarDomains/examples/inverted_samples/41_0/0_target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/examples/inverted_samples/41_0/0_target.png -------------------------------------------------------------------------------- /DissimilarDomains/examples/inverted_samples/41_0/projected_z.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/examples/inverted_samples/41_0/projected_z.npz -------------------------------------------------------------------------------- /DissimilarDomains/examples/inverted_samples/pixabay_dog_003552/0_proj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/examples/inverted_samples/pixabay_dog_003552/0_proj.png -------------------------------------------------------------------------------- /DissimilarDomains/examples/inverted_samples/pixabay_dog_003552/0_projected_z.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/examples/inverted_samples/pixabay_dog_003552/0_projected_z.npz -------------------------------------------------------------------------------- /DissimilarDomains/examples/inverted_samples/pixabay_dog_003552/0_target.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/examples/inverted_samples/pixabay_dog_003552/0_target.png -------------------------------------------------------------------------------- /DissimilarDomains/examples/inverted_samples/pixabay_dog_003552/projected_z.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/examples/inverted_samples/pixabay_dog_003552/projected_z.npz -------------------------------------------------------------------------------- /DissimilarDomains/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /DissimilarDomains/metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . import metric_utils 17 | 18 | 19 | # ---------------------------------------------------------------------------- 20 | 21 | def compute_fid(opts, max_real, num_gen): 22 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 23 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 24 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
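    # The feature statistics gathered below plug into the standard FID formula of
    # Heusel et al.:
    #     FID = ||mu_real - mu_gen||^2 + Tr(Sigma_real + Sigma_gen - 2 * (Sigma_real @ Sigma_gen)^(1/2))
    # where (mu, Sigma) are the mean and covariance of the Inception features of the
    # real dataset and of the generated samples, respectively.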
25 | 26 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 27 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 28 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 29 | 30 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 31 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 32 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 33 | 34 | if opts.rank != 0: 35 | return float('nan') 36 | 37 | m = np.square(mu_gen - mu_real).sum() 38 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 39 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 40 | return float(fid) 41 | 42 | # ---------------------------------------------------------------------------- 43 | -------------------------------------------------------------------------------- /DissimilarDomains/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /DissimilarDomains/metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /DissimilarDomains/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . 
import metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch in manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /DissimilarDomains/samples/0_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/samples/0_0.jpg -------------------------------------------------------------------------------- /DissimilarDomains/samples/31_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/samples/31_0.jpg 
-------------------------------------------------------------------------------- /DissimilarDomains/samples/41_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/samples/41_0.jpg -------------------------------------------------------------------------------- /DissimilarDomains/samples/46_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/samples/46_0.jpg -------------------------------------------------------------------------------- /DissimilarDomains/samples/pixabay_dog_003552.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/DissimilarDomains/samples/pixabay_dog_003552.jpg -------------------------------------------------------------------------------- /DissimilarDomains/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /DissimilarDomains/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /DissimilarDomains/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
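// Note on the op below: it implements a fused "add bias, then apply an element-wise
// activation" operation. The host-side wrapper only validates its inputs, fills in
// bias_act_kernel_params, and launches a single CUDA kernel; the optional xref/yref/dy
// tensors together with the `grad` flag select the gradient passes that reuse the
// same kernel.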
8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel.
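// (The launch configuration below uses blocks of 4 warps = 128 threads, and each
// thread handles up to p.loopX elements, so gridSize is chosen to cover all
// p.sizeX elements of the flattened input tensor.)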
84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /DissimilarDomains/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /DissimilarDomains/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
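# Note: fma(a, b, c) below computes a * b + c via torch.addcmul() in the forward
# pass; the custom autograd Function exists so that the backward pass can return
# broadcasting-aware gradients directly (see _unbroadcast) instead of relying on
# addcmul's default autograd path.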
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /DissimilarDomains/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | from pkg_resources import parse_version 16 | 17 | import torch 18 | 19 | # pylint: disable=redefined-builtin 20 | # pylint: disable=arguments-differ 21 | # pylint: disable=protected-access 22 | 23 | # ---------------------------------------------------------------------------- 24 | 25 | enabled = False # Enable the custom op by setting this to true. 
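# Note: grid_sample() below is a drop-in wrapper. It routes through the custom
# autograd Functions only when `enabled` is True and the version check in
# _should_use_custom_op() passes; otherwise it falls back to the plain
# torch.nn.functional.grid_sample() call with the same fixed arguments.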
26 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 27 | 28 | 29 | # ---------------------------------------------------------------------------- 30 | 31 | def grid_sample(input, grid): 32 | if _should_use_custom_op(): 33 | return _GridSample2dForward.apply(input, grid) 34 | return torch.nn.functional.grid_sample( 35 | input=input, grid=grid, mode='bilinear', 36 | padding_mode='zeros', align_corners=False 37 | ) 38 | 39 | 40 | # ---------------------------------------------------------------------------- 41 | 42 | def _should_use_custom_op(): 43 | if not enabled: 44 | return False 45 | if parse_version(torch.__version__) >= parse_version('1.7.0'): 46 | return True 47 | warnings.warn( 48 | f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().' 49 | ) 50 | return False 51 | 52 | 53 | # ---------------------------------------------------------------------------- 54 | 55 | class _GridSample2dForward(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, input, grid): 58 | assert input.ndim == 4 59 | assert grid.ndim == 4 60 | output = torch.nn.functional.grid_sample( 61 | input=input, grid=grid, mode='bilinear', 62 | padding_mode='zeros', align_corners=False 63 | ) 64 | ctx.save_for_backward(input, grid) 65 | return output 66 | 67 | @staticmethod 68 | def backward(ctx, grad_output): 69 | input, grid = ctx.saved_tensors 70 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 71 | return grad_input, grad_grid 72 | 73 | 74 | # ---------------------------------------------------------------------------- 75 | 76 | class _GridSample2dBackward(torch.autograd.Function): 77 | @staticmethod 78 | def forward(ctx, grad_output, input, grid): 79 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 80 | if _use_pytorch_1_11_api: 81 | op = op[0] 82 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) 83 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) 84 | else: 85 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 86 | ctx.save_for_backward(grid) 87 | return grad_input, grad_grid 88 | 89 | @staticmethod 90 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 91 | _ = grad2_grad_grid # unused 92 | grid, = ctx.saved_tensors 93 | grad2_grad_output = None 94 | grad2_input = None 95 | grad2_grid = None 96 | 97 | if ctx.needs_input_grad[0]: 98 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 99 | 100 | assert not ctx.needs_input_grad[2] 101 | return grad2_grad_output, grad2_input, grad2_grid 102 | 103 | # ---------------------------------------------------------------------------- 104 | -------------------------------------------------------------------------------- /DissimilarDomains/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
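// Annotation (added; not part of the original NVIDIA file): this extension implements the
// classic "upsample, FIR filter, downsample" primitive on 2D images. Conceptually the op
// zero-stuffs the input by (upx, upy), pads by (padx0/padx1, pady0/pady1), convolves with the
// 2D filter f (optionally flipped and scaled by gain), and keeps every (downx, downy)-th
// sample, which is why the output size computed below is
//   outW = (inW * upx + padx0 + padx1 - filterW + downx) / downx   (and analogously for outH).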
8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr<float>(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size.
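// Annotation (added): the launch configuration below picks between a generic "large" path
// (spec.tileOutW < 0, 4x32-thread blocks) and a tiled "small" path (256-thread blocks covering
// tileOutW x tileOutH output pixels); in both cases batch/channel work is spread across the
// launchMajor / launchMinor grid dimensions and p.loopX output tiles are handled per block.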
72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /DissimilarDomains/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /DissimilarDomains/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_difa.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: test_run_difa 8 | seed: 0 9 | root: . 10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_difa 14 | training: 15 | iter_num: 301 16 | batch_size: 2 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: original 21 | source_class: Real Person 22 | target_class: ./image_domains/disney.png 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.0 27 | clip_layer: 3 28 | model_type: 'difa' 29 | emb: 30 | type: 'mean' 31 | num: 16 32 | online_truncation: 0.8 33 | src_emb_dir: 'clip_means' 34 | inversion: 35 | method: e4e 36 | align_style: False 37 | model_path: pretrained/e4e_ffhq_encode.pt 38 | latents_root: 'latents_inversion_training' 39 | optimization_setup: 40 | visual_encoders: 41 | - ViT-B/32 42 | - ViT-B/16 43 | loss_funcs: 44 | - direction 45 | - difa_w 46 | - difa_local 47 | loss_coefs: 48 | - 1.0 49 | - 6.0 50 | - 1.0 51 | g_reg_every: 4 52 | optimizer: 53 | weight_decay: 0.0 54 | lr: 0.002 55 | betas: 56 | - 0.0 57 | - 0.999 58 | scheduler: 59 | n_steps: 20 60 | start_lr: 0.001 61 | generator_args: 62 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 63 | evaluation: 64 | is_on: false 65 | vision_models: 66 | - ViT-B/32 67 | - ViT-B/16 68 | step: 200 69 | data_size: 500 70 | batch_size: 24 71 | fid: true 72 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 73 | logging: 74 | log_every: 10 75 | log_images: 20 76 | latents_to_edit: [] 77 | truncation: 0.7 78 | num_grid_outputs: 1 79 | checkpointing: 80 | is_on: false 81 | start_from: false 82 | step_backup: 500 83 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_difa_for_low_memory.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: test_run_difa 8 | seed: 0 9 | root: . 
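# Annotation (added; not part of the original config): this file is the reduced-memory variant
# of im2im_difa.yaml above; the training batch_size drops from 2 to 1, a `low_memory: true`
# flag appears further down, and the target image differs, while the overall
# exp / training / inversion / optimization_setup / evaluation / logging / checkpointing layout
# is the same. Because these configs are read through OmegaConf (see core/utils/arguments.py
# later in this listing), individual keys can presumably also be overridden from the command
# line with dotlist-style arguments such as `exp.name=my_run training.device=cuda:1`
# (an assumption based on the OmegaConf.from_cli usage there).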
10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_difa 14 | training: 15 | iter_num: 301 16 | batch_size: 1 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: original 21 | source_class: Real Person 22 | target_class: ./image_domains/mermaid.png 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.0 27 | clip_layer: 3 28 | model_type: 'difa' 29 | emb: 30 | type: 'mean' 31 | num: 16 32 | online_truncation: 0.8 33 | src_emb_dir: 'clip_means' 34 | inversion: 35 | method: e4e 36 | align_style: False 37 | model_path: pretrained/e4e_ffhq_encode.pt 38 | latents_root: 'latents_inversion_training' 39 | optimization_setup: 40 | visual_encoders: 41 | - ViT-B/32 42 | - ViT-B/16 43 | loss_funcs: 44 | - direction 45 | - difa_w 46 | - difa_local 47 | loss_coefs: 48 | - 1.0 49 | - 6.0 50 | - 1.0 51 | g_reg_every: 4 52 | optimizer: 53 | weight_decay: 0.0 54 | lr: 0.002 55 | betas: 56 | - 0.0 57 | - 0.999 58 | scheduler: 59 | n_steps: 20 60 | start_lr: 0.001 61 | generator_args: 62 | checkpoint_path: pretrained/stylegan2-ffhq-config-f.pt 63 | evaluation: 64 | is_on: false 65 | vision_models: 66 | - ViT-B/32 67 | - ViT-B/16 68 | step: 200 69 | data_size: 500 70 | batch_size: 24 71 | fid: true 72 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 73 | logging: 74 | log_every: 10 75 | log_images: 20 76 | latents_to_edit: [] 77 | truncation: 0.7 78 | num_grid_outputs: 1 79 | low_memory: true 80 | checkpointing: 81 | is_on: false 82 | start_from: false 83 | step_backup: 500 84 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_difa_hdn.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: cin_adan 8 | seed: 0 9 | root: . 
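# Annotation (added; not part of the original config): relative to im2im_difa.yaml this variant
# switches `patch_key` from `original` to `cin_mult`, targets ./image_domains/adan.png, and uses
# a larger learning rate (0.02 with betas 0.9/0.999); the parametrization selected by `cin_mult`
# is presumably defined in the core/ patch modules (e.g. core/parametrizations.py or
# core/stylegan_patches.py), but that mapping is an assumption not stated in the config.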
10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_difa 14 | training: 15 | iter_num: 301 16 | batch_size: 2 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: cin_mult 21 | source_class: Real Person 22 | target_class: ./image_domains/adan.png 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.0 27 | clip_layer: 3 28 | model_type: 'difa' 29 | emb: 30 | type: 'mean' 31 | num: 16 32 | online_truncation: 0.8 33 | src_emb_dir: 'clip_means' 34 | inversion: 35 | method: e4e 36 | align_style: False 37 | model_path: pretrained/e4e_ffhq_encode.pt 38 | latents_root: 'latents_inversion_training' 39 | optimization_setup: 40 | visual_encoders: 41 | - ViT-B/32 42 | - ViT-B/16 43 | loss_funcs: 44 | - direction 45 | - difa_w 46 | - difa_local 47 | loss_coefs: 48 | - 1.0 49 | - 6.0 50 | - 1.0 51 | g_reg_every: 4 52 | optimizer: 53 | weight_decay: 0.0 54 | lr: 0.02 55 | betas: 56 | - 0.9 57 | - 0.999 58 | scheduler: 59 | n_steps: 20 60 | start_lr: 0.001 61 | generator_args: 62 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 63 | evaluation: 64 | is_on: false 65 | vision_models: 66 | - ViT-B/32 67 | - ViT-B/16 68 | step: 200 69 | data_size: 500 70 | batch_size: 24 71 | fid: true 72 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 73 | logging: 74 | log_every: 10 75 | log_images: 20 76 | latents_to_edit: [] 77 | truncation: 0.7 78 | num_grid_outputs: 1 79 | checkpointing: 80 | is_on: false 81 | start_from: false 82 | step_backup: 500 83 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_difa_sdelta.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: test_run_difa 8 | seed: 0 9 | root: . 
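# Annotation (added; not part of the original config): the distinguishing setting here is
# `patch_key: s_delta` below, which appears to train the additive StyleSpace offsets implemented
# in core/sparse_models.py (a single 6048-dimensional delta over the StyleGAN2 conv style inputs)
# rather than the full generator, paired with a much larger learning rate (0.08, betas 0.9/0.999)
# than the `patch_key: original` runs (0.002, betas 0.0/0.999). The s_delta -> SparsedModel link
# is an assumption inferred from the code in this listing, not stated in the config itself.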
10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_difa 14 | training: 15 | iter_num: 301 16 | batch_size: 2 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: s_delta 21 | source_class: Real Person 22 | target_class: ./image_domains/adan.png 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.0 27 | clip_layer: 3 28 | model_type: 'difa' 29 | emb: 30 | type: 'mean' 31 | num: 16 32 | online_truncation: 0.8 33 | src_emb_dir: 'clip_means' 34 | inversion: 35 | method: e4e 36 | align_style: False 37 | model_path: pretrained/e4e_ffhq_encode.pt 38 | latents_root: 'latents_inversion_training' 39 | optimization_setup: 40 | visual_encoders: 41 | - ViT-B/32 42 | - ViT-B/16 43 | loss_funcs: 44 | - direction 45 | - difa_w 46 | - difa_local 47 | loss_coefs: 48 | - 1.0 49 | - 6.0 50 | - 1.0 51 | g_reg_every: 4 52 | optimizer: 53 | weight_decay: 0.0 54 | lr: 0.08 55 | betas: 56 | - 0.9 57 | - 0.999 58 | scheduler: 59 | n_steps: 30 60 | start_lr: 0.001 61 | generator_args: 62 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 63 | evaluation: 64 | is_on: false 65 | vision_models: 66 | - ViT-B/32 67 | - ViT-B/16 68 | step: 200 69 | data_size: 500 70 | batch_size: 24 71 | fid: true 72 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 73 | logging: 74 | log_every: 10 75 | log_images: 20 76 | latents_to_edit: [] 77 | truncation: 0.7 78 | num_grid_outputs: 1 79 | checkpointing: 80 | is_on: false 81 | start_from: false 82 | step_backup: 500 83 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_jojo.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: test_run_jojo 8 | seed: 0 9 | root: . 10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_JoJo 14 | training: 15 | iter_num: 250 16 | batch_size: 4 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: original 21 | source_class: Real Person 22 | target_class: ./image_domains/jojo.png 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | alpha: 1. 
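  # Annotation (added; not part of the original config): the im2im_JoJo trainer follows a
  # JoJoGAN-style one-shot recipe; `mixing_noise` is presumably the style-mixing probability when
  # sampling latents, `alpha` the strength of blending toward the reference latents, and
  # `preserve_color` / `mix_stylespace` (below) control how the target style is mixed in. The only
  # training loss listed further down is `disc_feat_matching`. These interpretations are
  # assumptions based on the JoJoGAN procedure and are not documented in the config.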
28 | mix_stylespace: True 29 | preserve_color: False 30 | inversion: 31 | method: e4e 32 | align_style: False 33 | model_path: pretrained/e4e_ffhq_encode.pt 34 | latents_root: 'latents_inversion_training' 35 | optimization_setup: 36 | visual_encoders: 37 | - ViT-B/32 38 | - ViT-B/16 39 | loss_funcs: 40 | - disc_feat_matching 41 | loss_coefs: 42 | - 1.0 43 | g_reg_every: 4 44 | optimizer: 45 | weight_decay: 0.0 46 | lr: 0.002 47 | betas: 48 | - 0.0 49 | - 0.999 50 | scheduler: 51 | n_steps: 20 52 | start_lr: 0.001 53 | generator_args: 54 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 55 | evaluation: 56 | is_on: false 57 | vision_models: 58 | - ViT-B/32 59 | - ViT-B/16 60 | step: 200 61 | data_size: 500 62 | batch_size: 24 63 | fid: true 64 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 65 | logging: 66 | log_every: 10 67 | log_images: 20 68 | latents_to_edit: [] 69 | truncation: 0.7 70 | num_grid_outputs: 1 71 | checkpointing: 72 | is_on: false 73 | start_from: false 74 | step_backup: 500 75 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_jojo_for_low_memory.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: test_run_jojo 8 | seed: 0 9 | root: . 10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_JoJo 14 | training: 15 | iter_num: 251 16 | batch_size: 2 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: original 21 | source_class: Real Person 22 | target_class: ./image_domains/mermaid.png 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | alpha: 1. 28 | mix_stylespace: True 29 | preserve_color: False 30 | inversion: 31 | method: e4e 32 | align_style: False 33 | model_path: pretrained/e4e_ffhq_encode.pt 34 | latents_root: 'latents_inversion_training' 35 | optimization_setup: 36 | visual_encoders: 37 | - ViT-B/32 38 | - ViT-B/16 39 | loss_funcs: 40 | - disc_feat_matching 41 | loss_coefs: 42 | - 1.0 43 | g_reg_every: 4 44 | optimizer: 45 | weight_decay: 0.0 46 | lr: 0.002 47 | betas: 48 | - 0.0 49 | - 0.999 50 | scheduler: 51 | n_steps: 20 52 | start_lr: 0.001 53 | generator_args: 54 | checkpoint_path: pretrained/stylegan2-ffhq-config-f.pt 55 | evaluation: 56 | is_on: false 57 | vision_models: 58 | - ViT-B/32 59 | - ViT-B/16 60 | step: 200 61 | data_size: 500 62 | batch_size: 24 63 | fid: true 64 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 65 | logging: 66 | log_every: 10 67 | log_images: 20 68 | latents_to_edit: [] 69 | truncation: 0.7 70 | num_grid_outputs: 1 71 | low_memory: true 72 | checkpointing: 73 | is_on: false 74 | start_from: false 75 | step_backup: 500 76 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_jojo_sdelta.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: test_run_jojo 8 | seed: 0 9 | root: . 
10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_JoJo 14 | training: 15 | iter_num: 250 16 | batch_size: 4 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: s_delta 21 | source_class: Real Person 22 | target_class: ./image_domains/jojo.png 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | alpha: 1. 28 | mix_stylespace: True 29 | preserve_color: False 30 | inversion: 31 | method: e4e 32 | align_style: False 33 | model_path: pretrained/e4e_ffhq_encode.pt 34 | latents_root: 'latents_inversion_training' 35 | optimization_setup: 36 | visual_encoders: 37 | - ViT-B/32 38 | - ViT-B/16 39 | loss_funcs: 40 | - disc_feat_matching 41 | loss_coefs: 42 | - 1.0 43 | g_reg_every: 4 44 | optimizer: 45 | weight_decay: 0.0 46 | lr: 0.05 47 | betas: 48 | - 0.9 49 | - 0.999 50 | scheduler: 51 | n_steps: 20 52 | start_lr: 0.001 53 | generator_args: 54 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 55 | evaluation: 56 | is_on: false 57 | vision_models: 58 | - ViT-B/32 59 | - ViT-B/16 60 | step: 200 61 | data_size: 500 62 | batch_size: 24 63 | fid: true 64 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 65 | logging: 66 | log_every: 10 67 | log_images: 20 68 | latents_to_edit: [] 69 | truncation: 0.7 70 | num_grid_outputs: 1 71 | checkpointing: 72 | is_on: false 73 | start_from: false 74 | step_backup: 500 75 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_mtg.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: test_run_mtg 8 | seed: 0 9 | root: . 
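# Annotation (added; not part of the original config): the mtg configs use the generic
# `im2im_single` trainer and, unlike the difa/jojo files above, combine the CLIP `direction`
# loss with `clip_within`, `clip_ref` and explicit reconstruction terms (`l2_rec`, `lpips_rec`),
# as listed in optimization_setup below; "mtg" most likely refers to a Mind-the-Gap-style
# objective, but that expansion is an assumption.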
10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_single 14 | training: 15 | iter_num: 250 16 | batch_size: 4 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: original 21 | source_class: Real Person 22 | target_class: ./image_domains/maryna_tymoshenko_.jpg 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | inversion: 28 | method: e4e 29 | align_style: False 30 | model_path: pretrained/e4e_ffhq_encode.pt 31 | latents_root: 'latents_inversion_training' 32 | optimization_setup: 33 | visual_encoders: 34 | - ViT-B/32 35 | - ViT-B/16 36 | loss_funcs: 37 | - direction 38 | - clip_within 39 | - clip_ref 40 | - l2_rec 41 | - lpips_rec 42 | loss_coefs: 43 | - 1.0 44 | - 0.5 45 | - 30.0 46 | - 10.0 47 | - 10.0 48 | g_reg_every: 4 49 | optimizer: 50 | weight_decay: 0.0 51 | lr: 0.002 52 | betas: 53 | - 0.0 54 | - 0.999 55 | scheduler: 56 | n_steps: 25 57 | start_lr: 0.001 58 | generator_args: 59 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 60 | evaluation: 61 | is_on: false 62 | metrics: 63 | - ('fid', 100, "folder/images") 64 | vision_models: 65 | - ViT-B/32 66 | - ViT-B/16 67 | step: 200 68 | data_size: 500 69 | batch_size: 24 70 | fid: true 71 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 72 | logging: 73 | log_every: 10 74 | log_images: 20 75 | latents_to_edit: [] 76 | truncation: 0.7 77 | num_grid_outputs: 1 78 | checkpointing: 79 | is_on: false 80 | start_from: false 81 | step_backup: 500 82 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_mtg_for_low_memory.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: test_run_mtg 8 | seed: 0 9 | root: . 
10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_single 14 | training: 15 | iter_num: 251 16 | batch_size: 2 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: original 21 | source_class: Real Person 22 | target_class: ./image_domains/mermaid.png 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | inversion: 28 | method: e4e 29 | align_style: False 30 | model_path: pretrained/e4e_ffhq_encode.pt 31 | latents_root: 'latents_inversion_training' 32 | optimization_setup: 33 | visual_encoders: 34 | - ViT-B/32 35 | - ViT-B/16 36 | loss_funcs: 37 | - direction 38 | - clip_within 39 | - clip_ref 40 | - l2_rec 41 | - lpips_rec 42 | loss_coefs: 43 | - 1.0 44 | - 0.5 45 | - 30.0 46 | - 10.0 47 | - 10.0 48 | g_reg_every: 4 49 | optimizer: 50 | weight_decay: 0.0 51 | lr: 0.002 52 | betas: 53 | - 0.0 54 | - 0.999 55 | scheduler: 56 | n_steps: 25 57 | start_lr: 0.001 58 | generator_args: 59 | checkpoint_path: pretrained/stylegan2-ffhq-config-f.pt 60 | evaluation: 61 | is_on: false 62 | metrics: 63 | - ('fid', 100, "folder/images") 64 | vision_models: 65 | - ViT-B/32 66 | - ViT-B/16 67 | step: 200 68 | data_size: 500 69 | batch_size: 24 70 | fid: true 71 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 72 | logging: 73 | log_every: 10 74 | log_images: 20 75 | latents_to_edit: [] 76 | truncation: 0.7 77 | num_grid_outputs: 1 78 | low_memory: true 79 | checkpointing: 80 | is_on: false 81 | start_from: false 82 | step_backup: 500 83 | -------------------------------------------------------------------------------- /SimilarDomains/configs/im2im_mtg_sdelta.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: im2im_single.yaml 4 | project: test_project 5 | tags: 6 | - test 7 | name: test_run_mtg 8 | seed: 0 9 | root: . 
10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: im2im_single 14 | training: 15 | iter_num: 250 16 | batch_size: 4 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: s_delta 21 | source_class: Real Person 22 | target_class: ./image_domains/maryna_tymoshenko_.jpg.png 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | inversion: 28 | method: e4e 29 | align_style: False 30 | model_path: pretrained/e4e_ffhq_encode.pt 31 | latents_root: 'latents_inversion_training' 32 | optimization_setup: 33 | visual_encoders: 34 | - ViT-B/32 35 | - ViT-B/16 36 | loss_funcs: 37 | - direction 38 | - clip_within 39 | - clip_ref 40 | - l2_rec 41 | - lpips_rec 42 | loss_coefs: 43 | - 1.0 44 | - 0.5 45 | - 30.0 46 | - 10.0 47 | - 10.0 48 | g_reg_every: 4 49 | optimizer: 50 | weight_decay: 0.0 51 | lr: 0.05 52 | betas: 53 | - 0.9 54 | - 0.999 55 | scheduler: 56 | n_steps: 25 57 | start_lr: 0.001 58 | generator_args: 59 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 60 | evaluation: 61 | is_on: false 62 | metrics: 63 | - ('fid', 100, "folder/images") 64 | vision_models: 65 | - ViT-B/32 66 | - ViT-B/16 67 | step: 200 68 | data_size: 500 69 | batch_size: 24 70 | fid: true 71 | fid_ref: ../few-shot-gan-adaptation/sketches_all_resized/ 72 | logging: 73 | log_every: 10 74 | log_images: 20 75 | latents_to_edit: [] 76 | truncation: 0.7 77 | num_grid_outputs: 1 78 | checkpointing: 79 | is_on: false 80 | start_from: false 81 | step_backup: 500 82 | -------------------------------------------------------------------------------- /SimilarDomains/configs/td_nada_cars.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: td_single_car.yaml 4 | project: test_project 5 | tags: 6 | - None 7 | name: test_run_cars 8 | seed: 0 9 | root: . 10 | notes: empty notes 11 | logging: true 12 | step_save: 100 13 | trainer: td_single 14 | training: 15 | iter_num: 400 16 | batch_size: 8 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: original 21 | source_class: Car 22 | target_class: Golden Car 23 | auto_layer_k: 16 24 | auto_layer_iters: 0 25 | auto_layer_batch: 8 26 | mixing_noise: 0.9 27 | optimization_setup: 28 | visual_encoders: 29 | - ViT-B/32 30 | - ViT-B/16 31 | loss_funcs: 32 | - direction 33 | - indomain 34 | loss_coefs: 35 | - 1.0 36 | - 1.5 37 | g_reg_every: 4 38 | optimizer: 39 | weight_decay: 0.0 40 | lr: 0.002 41 | betas: 42 | - 0.0 43 | - 0.999 44 | generator_args: 45 | checkpoint_path: pretrained/StyleGAN2/stylegan2-car-config-f.pt 46 | img_size: 512 47 | evaluation: 48 | is_on: false 49 | vision_models: 50 | - ViT-B/32 51 | - ViT-L/14 52 | step: 100000 53 | data_size: 500 54 | batch_size: 24 55 | logging: 56 | log_every: 10 57 | log_images: 20 58 | latents_to_edit: [] 59 | truncation: 0.5 60 | num_grid_outputs: 1 61 | checkpointing: 62 | is_on: false 63 | start_from: false 64 | step_backup: 100000 -------------------------------------------------------------------------------- /SimilarDomains/configs/td_nada_ffhq.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: td_single_ffhq.yaml 4 | project: test_project 5 | tags: 6 | - None 7 | name: test_run_nada 8 | seed: 12 9 | root: . 
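# Annotation (added; not part of the original config): the td_nada_* configs are text-driven in
# the StyleGAN-NADA style rather than image-driven: `source_class` / `target_class` below are
# plain text prompts ("Human" -> "The Joker" here, "Car" -> "Golden Car" in td_nada_cars.yaml
# above), the objective is the CLIP `direction` loss plus an `indomain` regularizer, and there is
# no inversion section because no reference image needs to be embedded.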
10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: td_single 14 | training: 15 | iter_num: 250 16 | batch_size: 4 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: original 21 | source_class: Human 22 | target_class: The Joker 23 | no_coarse: True 24 | auto_layer_k: 16 25 | auto_layer_iters: 0 26 | auto_layer_batch: 8 27 | mixing_noise: 0.9 28 | optimization_setup: 29 | visual_encoders: 30 | - ViT-B/32 31 | - ViT-B/16 32 | loss_funcs: 33 | - direction 34 | - indomain 35 | loss_coefs: 36 | - 1.0 37 | - 0.25 38 | g_reg_every: 4 39 | optimizer: 40 | weight_decay: 0.0 41 | lr: 0.002 42 | betas: 43 | - 0.0 44 | - 0.999 45 | generator_args: 46 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 47 | evaluation: 48 | is_on: false 49 | vision_models: 50 | - ViT-L/14 51 | step: 20 52 | data_size: 500 53 | batch_size: 12 54 | logging: 55 | log_every: 10 56 | log_images: 20 57 | latents_to_edit: [] 58 | truncation: 0.7 59 | num_grid_outputs: 1 60 | checkpointing: 61 | is_on: false 62 | start_from: false 63 | step_backup: 100000 -------------------------------------------------------------------------------- /SimilarDomains/configs/td_nada_ffhq_sdelta.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | config_dir: configs 3 | config: td_single_ffhq.yaml 4 | project: test_project 5 | tags: 6 | - None 7 | name: test_run_nada 8 | seed: 12 9 | root: . 10 | notes: empty notes 11 | logging: true 12 | step_save: 25 13 | trainer: td_single 14 | training: 15 | iter_num: 250 16 | batch_size: 4 17 | device: cuda:0 18 | generator: stylegan2 19 | phase: all 20 | patch_key: s_delta 21 | source_class: Human 22 | target_class: The Joker 23 | no_coarse: True 24 | auto_layer_k: 16 25 | auto_layer_iters: 0 26 | auto_layer_batch: 8 27 | mixing_noise: 0.9 28 | optimization_setup: 29 | visual_encoders: 30 | - ViT-B/32 31 | - ViT-B/16 32 | loss_funcs: 33 | - direction 34 | - indomain 35 | loss_coefs: 36 | - 1.0 37 | - 0.25 38 | g_reg_every: 4 39 | optimizer: 40 | weight_decay: 0.0 41 | lr: 0.05 42 | betas: 43 | - 0.9 44 | - 0.999 45 | generator_args: 46 | checkpoint_path: pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt 47 | evaluation: 48 | is_on: false 49 | vision_models: 50 | - ViT-L/14 51 | step: 20 52 | data_size: 500 53 | batch_size: 12 54 | logging: 55 | log_every: 10 56 | log_images: 20 57 | latents_to_edit: [] 58 | truncation: 0.7 59 | num_grid_outputs: 1 60 | checkpointing: 61 | is_on: false 62 | start_from: false 63 | step_backup: 100000 -------------------------------------------------------------------------------- /SimilarDomains/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/core/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/core/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import dlib 4 | import PIL 5 | 6 | from PIL import Image 7 | from pathlib import Path 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms as transforms 10 | 11 | from core.utils.common import align_face 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return 
any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | for root, _, fnames in sorted(os.walk(dir)): 27 | for fname in fnames: 28 | if is_image_file(fname): 29 | path = os.path.join(root, fname) 30 | images.append(path) 31 | return images 32 | 33 | 34 | class ImagesDataset(Dataset): 35 | def __init__(self, opts, image_path=None, align_input=False): 36 | if type(image_path) == list: 37 | self.image_paths = image_path 38 | elif os.path.isdir(image_path): 39 | self.image_paths = sorted(make_dataset(image_path)) 40 | elif os.path.isfile(image_path): 41 | self.image_paths = [image_path] 42 | else: 43 | raise ValueError(f"Incorrect 'image_path' argument in ImagesDataset, {image_path}") 44 | 45 | self.image_transform = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 48 | ]) 49 | 50 | self.opts = opts 51 | self.align_input = align_input 52 | 53 | if self.align_input: 54 | weight_path = str(Path(__file__).parent.parent / 'pretrained/shape_predictor_68_face_landmarks.dat') 55 | self.predictor = dlib.shape_predictor(weight_path) 56 | 57 | def __len__(self): 58 | return len(self.image_paths) 59 | 60 | def __getitem__(self, index): 61 | im_path = Path(self.image_paths[index]) 62 | 63 | if self.align_input: 64 | im_H = align_face(str(im_path), self.predictor, output_size=self.opts.size) 65 | else: 66 | im_H = Image.open(str(im_path)).convert('RGB') 67 | im_H = im_H.resize((self.opts.size, self.opts.size)) 68 | 69 | im_L = im_H.resize((256, 256), PIL.Image.LANCZOS) 70 | 71 | return { 72 | "image_high_res": im_H, 73 | "image_low_res": im_L, 74 | "image_high_res_torch": self.image_transform(im_H), 75 | "image_low_res_torch": self.image_transform(im_L), 76 | "image_name": im_path.stem 77 | } 78 | -------------------------------------------------------------------------------- /SimilarDomains/core/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from pdb import set_trace as st 6 | from IPython import embed 7 | 8 | class BaseModel(): 9 | def __init__(self): 10 | pass; 11 | 12 | def name(self): 13 | return 'BaseModel' 14 | 15 | def initialize(self, use_gpu=True, gpu_ids=[0]): 16 | self.use_gpu = use_gpu 17 | self.gpu_ids = gpu_ids 18 | 19 | def forward(self): 20 | pass 21 | 22 | def get_image_paths(self): 23 | pass 24 | 25 | def optimize_parameters(self): 26 | pass 27 | 28 | def get_current_visuals(self): 29 | return self.input 30 | 31 | def get_current_errors(self): 32 | return {} 33 | 34 | def save(self, label): 35 | pass 36 | 37 | # helper saving function that can be used by subclasses 38 | def save_network(self, network, path, network_label, epoch_label): 39 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 40 | save_path = os.path.join(path, save_filename) 41 | torch.save(network.state_dict(), save_path) 42 | 43 | # helper loading function that can be used by subclasses 44 | def load_network(self, network, network_label, epoch_label): 45 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 46 | save_path = os.path.join(self.save_dir, save_filename) 47 | print('Loading network from %s'%save_path) 48 | network.load_state_dict(torch.load(save_path)) 49 | 50 | def update_learning_rate(): 51 | pass 52 | 53 | def get_image_paths(self): 54 | 
return self.image_paths 55 | 56 | def save_done(self, flag=False): 57 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 58 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 59 | -------------------------------------------------------------------------------- /SimilarDomains/core/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/core/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /SimilarDomains/core/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/core/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /SimilarDomains/core/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/core/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /SimilarDomains/core/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/core/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /SimilarDomains/core/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/core/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /SimilarDomains/core/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/core/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /SimilarDomains/core/sparse_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.autograd import Function 5 | from collections import OrderedDict 6 | from core.utils.common import get_stylegan_conv_dimensions 7 | 8 | 9 | def cat_stylespace(style_space): 10 | return torch.cat(style_space, dim=1) 11 | 12 | 13 | def split_stylespace(style_space, img_size=1024): 14 | prev = 0 15 | result = [] 16 | for cin, _ in get_stylegan_conv_dimensions(img_size): 17 | result.append(style_space[:, prev:prev + cin]) 18 | prev = prev + cin 19 | 20 | return result 21 | 22 | 23 | def to_dict_input(deltas, img_size=1024): 24 | dict_input = OrderedDict() 25 | start_idx = 0 26 | 27 | for idx, (in_d, out_d) in enumerate(get_stylegan_conv_dimensions(img_size)): 28 | dict_input[f'conv_{idx}'] = { 29 | 'in': deltas[:, start_idx: start_idx + in_d] 30 | } 31 | start_idx += in_d 32 | 33 | return dict_input 34 | 35 | 36 | def to_tensor(dict_input): 37 | return torch.cat([v['in'] for k, v in 
dict_input.items()], dim=1) 38 | 39 | 40 | def ckpt_to_tensor(ckpt): 41 | state_dict = ckpt['state_dict'] 42 | n = len(state_dict) 43 | return torch.cat([state_dict[f'heads.conv_{i}.params_in'] for i in range(n)], dim=1) 44 | 45 | 46 | class SparsedModel(nn.Module): 47 | def __init__(self, device, ckpt=None): 48 | super().__init__() 49 | 50 | self.device = device 51 | self.convid_to_st = dict([ 52 | (0, 0), (1, 2), (2, 3), (3, 5), 53 | (4, 6), (5, 8), (6, 9), (7, 11), 54 | (8, 12), (9, 14), (10, 15), (11, 17), 55 | (12, 18), (13, 20), (14, 21), (15, 23), 56 | (16, 24) 57 | ]) 58 | 59 | self.s_to_conv_id = {v:k for k, v in self.convid_to_st.items()} 60 | self.input_keys = sorted(self.s_to_conv_id.keys()) 61 | 62 | self.deltas = nn.Parameter(torch.zeros(1, 6048)) 63 | self.register_buffer('grad_mask', torch.ones(1, 6048)) 64 | 65 | if ckpt is not None: 66 | self._deltas_from_ckpt(ckpt) 67 | 68 | def forward(self, style_space): 69 | st = torch.cat([style_space[i] for i in self.input_keys], dim=1) 70 | st_shifted = st + self.deltas * self.grad_mask 71 | 72 | splited_st = split_stylespace(st_shifted) 73 | answer = [ 74 | splited_st[self.s_to_conv_id[i]] if i in self.input_keys else style_space[i] for i in range(len(style_space)) 75 | ] 76 | 77 | return answer, to_dict_input(self.deltas) 78 | 79 | def offsets(self): 80 | return to_dict_input(self.deltas) 81 | 82 | def pruned_offsets(self, perc): 83 | deltas_pruned = torch.clone(self.deltas.data) 84 | top = torch.abs(deltas_pruned.squeeze()).argsort() # top to lower 85 | chosen_idxes = int(6048 * perc) 86 | deltas_pruned[:, top[:chosen_idxes]] = 0. 87 | return to_dict_input(deltas_pruned) 88 | 89 | def _deltas_from_ckpt(self, ckpt): 90 | self.deltas = nn.Parameter(ckpt_to_tensor(ckpt).to(self.device)) 91 | return self 92 | 93 | def pruned_forward(self, style_space, perc): 94 | deltas_pruned = torch.clone(self.deltas.data) 95 | top = torch.abs(deltas_pruned.squeeze()).argsort().flipud() 96 | chosen_idxes = int(6048 * perc) 97 | deltas_pruned = deltas_pruned[:, top[-chosen_idxes:]] = 0. 
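        # Annotation (added; not part of the original file): the chained assignment above rebinds
        # `deltas_pruned` to the float 0.0 instead of zeroing a slice, and `self.fn` used below is
        # never defined on this class, so `pruned_forward` appears to be unused/stale code;
        # `pruned_offsets` above implements the working pruning path (it zeroes the `perc`
        # fraction of entries with the smallest |delta| and keeps the rest).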
98 | st = cat_stylespace(style_space) 99 | st_shifted = self.fn(st, deltas_pruned, torch.ones(1, 6048)) 100 | return split_stylespace(st_shifted) 101 | -------------------------------------------------------------------------------- /SimilarDomains/core/style_embed_options.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | ################# II2S options for style image embedding Note: p_norm_lambda = 1e-2 not 1e-3 4 | opts = Namespace() 5 | 6 | # StyleGAN2 setting 7 | opts.size = 1024 8 | opts.ckpt = "pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt" 9 | opts.channel_multiplier = 2 10 | opts.latent = 512 11 | opts.n_mlp = 8 12 | 13 | # loss options 14 | opts.percept_lambda = 1.0 15 | opts.l2_lambda = 1.0 16 | opts.p_norm_lambda = 1e-2 17 | 18 | # arguments 19 | opts.device = 'cuda' 20 | opts.seed = 2 21 | opts.tile_latent = False 22 | opts.opt_name = 'adam' 23 | opts.learning_rate = 0.01 24 | opts.lr_schedule = 'fixed' 25 | opts.steps = 1000 26 | opts.save_intermediate = False 27 | opts.save_interval = 300 28 | opts.verbose = True 29 | 30 | II2S_s_opts = opts -------------------------------------------------------------------------------- /SimilarDomains/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/core/utils/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/core/utils/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from omegaconf import OmegaConf 4 | from core.uda_models import uda_models 5 | from core.utils.class_registry import ClassRegistry 6 | 7 | 8 | args = ClassRegistry() 9 | generator_args = uda_models.make_dataclass_from_args() 10 | args.add_to_registry("generator_args")(generator_args) 11 | additional_arguments = args.make_dataclass_from_classes() 12 | 13 | DEFAULT_CONFIG_DIR = 'configs' 14 | 15 | 16 | def get_generator_args(generator_name, base_args, conf_args): 17 | return OmegaConf.create( 18 | {generator_name: OmegaConf.merge(base_args, conf_args)} 19 | ) 20 | 21 | 22 | def load_config(): 23 | base_gen_args_config = OmegaConf.structured(additional_arguments) 24 | 25 | conf_cli = OmegaConf.from_cli() 26 | conf_cli.exp.config_dir = DEFAULT_CONFIG_DIR 27 | if not conf_cli.get('exp', False): 28 | raise ValueError("No config") 29 | 30 | config_path = os.path.join(conf_cli.exp.config_dir, conf_cli.exp.config) 31 | conf_file = OmegaConf.load(config_path) 32 | 33 | conf_generator_args = conf_file.generator_args 34 | 35 | generator_args = get_generator_args( 36 | conf_file.training.generator, 37 | base_gen_args_config.generator_args[conf_file.training.generator], 38 | conf_generator_args 39 | ) 40 | 41 | gen_args = OmegaConf.create({ 42 | 'generator_args': generator_args 43 | }) 44 | 45 | config = OmegaConf.merge(conf_file, conf_cli) 46 | config = OmegaConf.merge(config, gen_args) 47 | return config 48 | -------------------------------------------------------------------------------- /SimilarDomains/core/utils/class_registry.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing 3 | import omegaconf 4 | import dataclasses 5 | import typing as tp 6 | 7 | 8 | class ClassRegistry: 9 | def __init__(self): 10 | self.classes = dict() 11 | self.args = 
dict() 12 | self.arg_keys = None 13 | 14 | def __getitem__(self, item): 15 | return self.classes[item] 16 | 17 | def make_dataclass_from_func(self, func, name, arg_keys): 18 | args = inspect.signature(func).parameters 19 | args = [ 20 | (k, typing.Any, omegaconf.MISSING) 21 | if v.default is inspect.Parameter.empty 22 | else (k, typing.Optional[typing.Any], None) 23 | if v.default is None 24 | else ( 25 | k, 26 | type(v.default), 27 | dataclasses.field(default=v.default), 28 | ) 29 | for k, v in args.items() 30 | ] 31 | args = [ 32 | arg 33 | for arg in args 34 | if (arg[0] != "self" and arg[0] != "args" and arg[0] != "kwargs") 35 | ] 36 | if arg_keys: 37 | self.arg_keys = arg_keys 38 | arg_classes = dict() 39 | for key in arg_keys: 40 | arg_classes[key] = dataclasses.make_dataclass(key, args) 41 | return dataclasses.make_dataclass( 42 | name, 43 | [ 44 | (k, v, dataclasses.field(default=v())) 45 | for k, v in arg_classes.items() 46 | ], 47 | ) 48 | return dataclasses.make_dataclass(name, args) 49 | 50 | def make_dataclass_from_classes(self): 51 | return dataclasses.make_dataclass( 52 | 'Name', 53 | [ 54 | (k, v, dataclasses.field(default=v())) 55 | for k, v in self.classes.items() 56 | ], 57 | ) 58 | 59 | def make_dataclass_from_args(self): 60 | return dataclasses.make_dataclass( 61 | 'Name', 62 | [ 63 | (k, v, dataclasses.field(default=v())) 64 | for k, v in self.args.items() 65 | ], 66 | ) 67 | 68 | def _add_single_obj(self, obj, name, arg_keys): 69 | self.classes[name] = obj 70 | if inspect.isfunction(obj): 71 | self.args[name] = self.make_dataclass_from_func( 72 | obj, name, arg_keys 73 | ) 74 | elif inspect.isclass(obj): 75 | self.args[name] = self.make_dataclass_from_func( 76 | obj.__init__, name, arg_keys 77 | ) 78 | 79 | def add_to_registry(self, names: tp.Union[str, tp.List[str]], arg_keys=None): 80 | if not isinstance(names, list): 81 | names = [names] 82 | 83 | def decorator(obj): 84 | for name in names: 85 | self._add_single_obj(obj, name, arg_keys) 86 | 87 | return obj 88 | return decorator 89 | 90 | def __contains__(self, name: str): 91 | return name in self.args.keys() 92 | 93 | def __repr__(self): 94 | return f"{list(self.args.keys())}" 95 | -------------------------------------------------------------------------------- /SimilarDomains/core/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def cosine_loss(x, y): 7 | return 1.0 - F.cosine_similarity(x, y) 8 | 9 | 10 | def mse_loss(x, y): 11 | return F.mse_loss(x, y) 12 | 13 | 14 | def mae_loss(x, y): 15 | return F.l1_loss(x, y) 16 | 17 | 18 | def get_tril_elemets(matrix_torch: torch.Tensor): 19 | flat = torch.tril(matrix_torch, diagonal=-1).flatten() 20 | return flat[torch.nonzero(flat)] 21 | 22 | 23 | def get_tril_elements_mask(linear_size): 24 | mask = np.zeros((linear_size, linear_size), dtype=np.bool) 25 | mask[np.tril_indices_from(mask)] = True 26 | np.fill_diagonal(mask, False) 27 | return mask 28 | 29 | 30 | def flatten_with_non_diagonal(input_matix: torch.Tensor): 31 | linear_matrix_size = input_matix.size(0) 32 | 33 | non_diag = input_matix.flatten()[1:].view(linear_matrix_size - 1, linear_matrix_size + 1)[:, :-1] 34 | return non_diag.flatten() 35 | -------------------------------------------------------------------------------- /SimilarDomains/core/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 
| def resample_single_vector(target_vector, cos_lower_bound, n_vectors=1): 5 | """ 6 | Resample one vector 'n_vectors' times with lower bound of cos 'cos_lower_bound' 7 | 8 | Parameters 9 | ---------- 10 | target_vector : torch.Tensor with size() == (1, dim) || (dim) 11 | center of resampling 12 | cos_lower_bound : float 13 | lower bound of cos of resampled vectors 14 | n_vectors : int 15 | number of resampled vectors 16 | 17 | Returns 18 | ------- 19 | omega : torch.Tensor, size [n_vectors, vector_dim] 20 | resampled vectors with cos with target_vector higher than thr_cos 21 | """ 22 | 23 | if target_vector.ndim == 1: 24 | target_vector = target_vector.unsqueeze(0) 25 | 26 | _, dim = target_vector.size() 27 | 28 | u = target_vector / target_vector.norm(dim=-1, keepdim=True) 29 | u = u.repeat(n_vectors, 1) 30 | r = torch.rand_like(u) * 2 - 1 31 | uperp = torch.stack([r[i] - (torch.dot(r[i], u[i]) * u[i]) for i in range(u.size(0))]) 32 | uperp = uperp / uperp.norm(dim=1, keepdim=True) 33 | 34 | cos_theta = torch.rand(n_vectors, device=target_vector.device) * (1 - cos_lower_bound) + cos_lower_bound 35 | cos_theta = cos_theta.unsqueeze(1).repeat(1, target_vector.size(1)) 36 | omega = cos_theta * u + torch.sqrt(1 - cos_theta ** 2) * uperp 37 | 38 | return omega 39 | 40 | 41 | def resample_batch_vectors(target_vector, cos_lower_bound): 42 | """ 43 | Resample 'b' vector 'n_vectors' times with lower bound of cos 'cos_lower_bound' 44 | 45 | Parameters 46 | ---------- 47 | target_vector : torch.Tensor with size() == (b, dim) 48 | center of resampling 49 | cos_lower_bound : float 50 | lower bound of cos of resampled vectors 51 | 52 | Returns 53 | ------- 54 | omega : torch.Tensor, size [n_vectors, vector_dim] 55 | resampled vectors with cos with target_vector higher than thr_cos 56 | """ 57 | 58 | b, dim = target_vector.size() 59 | u = target_vector / target_vector.norm(dim=-1, keepdim=True) 60 | r = torch.rand_like(u) * 2 - 1 61 | uperp = torch.stack([r[i] - (torch.dot(r[i], u[i]) * u[i]) for i in range(u.size(0))]) 62 | uperp = uperp / uperp.norm(dim=1, keepdim=True) 63 | 64 | cos_theta = torch.rand(b, device=target_vector.device) * (1 - cos_lower_bound) + cos_lower_bound 65 | cos_theta = cos_theta.unsqueeze(1).repeat(1, target_vector.size(1)) 66 | omega = cos_theta * u + torch.sqrt(1 - cos_theta ** 2) * uperp 67 | 68 | return omega 69 | 70 | 71 | def resample_batch_templated_embeddings(embeddings, cos_lower_bound): 72 | if embeddings.ndim == 2: 73 | return resample_batch_vectors(embeddings, cos_lower_bound) 74 | 75 | batch, templates, dim = embeddings.shape 76 | embeddings = embeddings.view(-1, dim) 77 | resampled_embeddings = resample_batch_vectors(embeddings, cos_lower_bound) 78 | 79 | resampled_embeddings = resampled_embeddings.view(batch, templates, dim).contiguous() 80 | return resampled_embeddings 81 | 82 | 83 | def convex_hull(target_vectors, alphas): 84 | """ 85 | calculate convex hull with 'alphas' (1 > alpha > 0, \sum alphas = 1) for target vectors 86 | 87 | Parameters 88 | ---------- 89 | target_vectors : torch.Tensor 90 | set of vectors for which convex hull element is calculated. 91 | Size: [b, dim1, dim2] 92 | 93 | alphas : torch.Tensor 94 | appropriate alphas for which element from convex hull will be calculated. 
95 | Size: [b, b] 96 | 97 | Returns 98 | ------- 99 | convex_hull_element : torch.Tensor 100 | single element from convex hull 101 | 102 | """ 103 | 104 | convex_hull_element = (target_vectors.unsqueeze(0) * alphas.unsqueeze(2).unsqueeze(3)).sum(dim=1) 105 | convex_hull_element /= convex_hull_element.clone().norm(dim=-1, keepdim=True) 106 | return convex_hull_element 107 | 108 | 109 | def convex_hull_small(target_vectors, alphas): 110 | """ 111 | calculate convex hull with 'alphas' (1 > alpha > 0, \sum alphas = 1) for target vectors 112 | 113 | Parameters 114 | ---------- 115 | target_vectors : torch.Tensor 116 | set of vectors for which convex hull element is calculated. 117 | Size: [b, dim1, dim2] 118 | 119 | alphas : torch.Tensor 120 | appropriate alphas for which element from convex hull will be calculated. 121 | Size: [b, b] 122 | 123 | Returns 124 | ------- 125 | convex_hull_element : torch.Tensor 126 | single element from convex hull 127 | 128 | """ 129 | 130 | convex_hull_element = (target_vectors.unsqueeze(0) * alphas.unsqueeze(2)).sum(dim=1) 131 | convex_hull_element /= convex_hull_element.clone().norm(dim=-1, keepdim=True) 132 | return convex_hull_element 133 | -------------------------------------------------------------------------------- /SimilarDomains/core/utils/notebook.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from ..uda_models import uda_models 3 | 4 | 5 | def load_config_runtime( 6 | config_path 7 | ): 8 | cfg = OmegaConf.load(config_path) 9 | 10 | # process generator args from default and config 11 | gen_args = OmegaConf.create({ 12 | 'generator_args': { 13 | cfg.training.generator: OmegaConf.merge( 14 | OmegaConf.structured(uda_models.make_dataclass_from_args())[cfg.training.generator], 15 | OmegaConf.structured(cfg.generator_args) 16 | ) 17 | } 18 | }) 19 | cfg.generator_args.clear() 20 | cfg = OmegaConf.merge(cfg, gen_args) 21 | 22 | return cfg 23 | -------------------------------------------------------------------------------- /SimilarDomains/core/utils/reading_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .example_utils import Inferencer 4 | 5 | 6 | def read_weights(exp_weights_root, idx=None): 7 | if idx is not None: 8 | model_path = exp_weights_root / f'models/models_{idx}.pt' 9 | return torch.load(model_path) 10 | models_path = list((exp_weights_root / 'models').iterdir()) 11 | model_path = sorted(models_path, key=lambda x: int(x.stem.split('_')[1]))[-1] 12 | return torch.load(model_path) 13 | 14 | 15 | def get_model(exp_weights_root, idx=None): 16 | ckpt = read_weights(exp_weights_root, idx) 17 | return Inferencer(ckpt, device) -------------------------------------------------------------------------------- /SimilarDomains/core/utils/text_templates.py: -------------------------------------------------------------------------------- 1 | imagenet_templates = [ 2 | 'a bad photo of a {}.', 3 | 'a sculpture of a {}.', 4 | 'a photo of the hard to see {}.', 5 | 'a low resolution photo of the {}.', 6 | 'a rendering of a {}.', 7 | 'graffiti of a {}.', 8 | 'a bad photo of the {}.', 9 | 'a cropped photo of the {}.', 10 | 'a tattoo of a {}.', 11 | 'the embroidered {}.', 12 | 'a photo of a hard to see {}.', 13 | 'a bright photo of a {}.', 14 | 'a photo of a clean {}.', 15 | 'a photo of a dirty {}.', 16 | 'a dark photo of the {}.', 17 | 'a drawing of a {}.', 18 | 'a photo of my {}.', 19 | 'the plastic {}.', 20 | 'a 
photo of the cool {}.', 21 | 'a close-up photo of a {}.', 22 | 'a black and white photo of the {}.', 23 | 'a painting of the {}.', 24 | 'a painting of a {}.', 25 | 'a pixelated photo of the {}.', 26 | 'a sculpture of the {}.', 27 | 'a bright photo of the {}.', 28 | 'a cropped photo of a {}.', 29 | 'a plastic {}.', 30 | 'a photo of the dirty {}.', 31 | 'a jpeg corrupted photo of a {}.', 32 | 'a blurry photo of the {}.', 33 | 'a photo of the {}.', 34 | 'a good photo of the {}.', 35 | 'a rendering of the {}.', 36 | 'a {} in a video game.', 37 | 'a photo of one {}.', 38 | 'a doodle of a {}.', 39 | 'a close-up photo of the {}.', 40 | 'a photo of a {}.', 41 | 'the origami {}.', 42 | 'the {} in a video game.', 43 | 'a sketch of a {}.', 44 | 'a doodle of the {}.', 45 | 'a origami {}.', 46 | 'a low resolution photo of a {}.', 47 | 'the toy {}.', 48 | 'a rendition of the {}.', 49 | 'a photo of the clean {}.', 50 | 'a photo of a large {}.', 51 | 'a rendition of a {}.', 52 | 'a photo of a nice {}.', 53 | 'a photo of a weird {}.', 54 | 'a blurry photo of a {}.', 55 | 'a cartoon {}.', 56 | 'art of a {}.', 57 | 'a sketch of the {}.', 58 | 'a embroidered {}.', 59 | 'a pixelated photo of a {}.', 60 | 'itap of the {}.', 61 | 'a jpeg corrupted photo of the {}.', 62 | 'a good photo of a {}.', 63 | 'a plushie {}.', 64 | 'a photo of the nice {}.', 65 | 'a photo of the small {}.', 66 | 'a photo of the weird {}.', 67 | 'the cartoon {}.', 68 | 'art of the {}.', 69 | 'a drawing of the {}.', 70 | 'a photo of the large {}.', 71 | 'a black and white photo of a {}.', 72 | 'the plushie {}.', 73 | 'a dark photo of a {}.', 74 | 'itap of a {}.', 75 | 'graffiti of the {}.', 76 | 'a toy {}.', 77 | 'itap of my {}.', 78 | 'a photo of a cool {}.', 79 | 'a photo of a small {}.', 80 | 'a tattoo of the {}.', 81 | ] 82 | 83 | part_templates = [ 84 | 'the paw of a {}.', 85 | 'the nose of a {}.', 86 | 'the eye of the {}.', 87 | 'the ears of a {}.', 88 | 'an eye of a {}.', 89 | 'the tongue of a {}.', 90 | 'the fur of the {}.', 91 | 'colorful {} fur.', 92 | 'a snout of a {}.', 93 | 'the teeth of the {}.', 94 | 'the {}s fangs.', 95 | 'a claw of the {}.', 96 | 'the face of the {}', 97 | 'a neck of a {}', 98 | 'the head of the {}', 99 | ] 100 | 101 | imagenet_templates_small = [ 102 | 'a photo of a {}.', 103 | 'a rendering of a {}.', 104 | 'a cropped photo of the {}.', 105 | 'the photo of a {}.', 106 | 'a photo of a clean {}.', 107 | 'a photo of a dirty {}.', 108 | 'a dark photo of the {}.', 109 | 'a photo of my {}.', 110 | 'a photo of the cool {}.', 111 | 'a close-up photo of a {}.', 112 | 'a bright photo of the {}.', 113 | 'a cropped photo of a {}.', 114 | 'a photo of the {}.', 115 | 'a good photo of the {}.', 116 | 'a photo of one {}.', 117 | 'a close-up photo of the {}.', 118 | 'a rendition of the {}.', 119 | 'a photo of the clean {}.', 120 | 'a rendition of a {}.', 121 | 'a photo of a nice {}.', 122 | 'a good photo of a {}.', 123 | 'a photo of the nice {}.', 124 | 'a photo of the small {}.', 125 | 'a photo of the weird {}.', 126 | 'a photo of the large {}.', 127 | 'a photo of a cool {}.', 128 | 'a photo of a small {}.', 129 | ] -------------------------------------------------------------------------------- /SimilarDomains/core/utils/train_log.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import collections 3 | import torch 4 | import time 5 | import datetime 6 | 7 | 8 | def strf_time_delta(td): 9 | td_str = "" 10 | if td.days > 0: 11 | td_str += f"{td.days} days, " if 
td.days > 1 else f"{td.days} day, " 12 | hours = td.seconds // 3600 13 | if hours > 0: 14 | td_str += f"{hours}h " 15 | minutes = (td.seconds // 60) % 60 16 | if minutes > 0: 17 | td_str += f"{minutes}m " 18 | seconds = td.seconds % 60 + td.microseconds * 1e-6 19 | td_str += f"{seconds:.1f}s" 20 | return td_str 21 | 22 | 23 | class Timer: 24 | def __init__(self, info=None, log_event=None): 25 | self.info = info 26 | self.log_event = log_event 27 | 28 | def __enter__(self): 29 | self.start = torch.cuda.Event(enable_timing=True) 30 | self.end = torch.cuda.Event(enable_timing=True) 31 | self.start.record() 32 | return self 33 | 34 | def __exit__(self, exc_type, exc_val, exc_tb): 35 | self.end.record() 36 | torch.cuda.synchronize() 37 | self.duration = self.start.elapsed_time(self.end) / 1000 38 | if self.info: 39 | self.info[f"duration/{self.log_event}"] = self.duration 40 | 41 | 42 | class TimeLog: 43 | def __init__(self, logger, total_num, event): 44 | self.logger = logger 45 | self.total_num = total_num 46 | self.event = event.upper() 47 | self.start = time.time() 48 | 49 | def now(self, current_num): 50 | elapsed = time.time() - self.start 51 | left = self.total_num * elapsed / (current_num + 1) - elapsed 52 | elapsed = strf_time_delta(datetime.timedelta(seconds=elapsed)) 53 | left = strf_time_delta(datetime.timedelta(seconds=left)) 54 | self.logger.log_info( 55 | f"TIME ELAPSED SINCE {self.event} START: {elapsed}" 56 | ) 57 | self.logger.log_info(f"TIME LEFT UNTIL {self.event} END: {left}") 58 | 59 | def end(self): 60 | elapsed = time.time() - self.start 61 | elapsed = strf_time_delta(datetime.timedelta(seconds=elapsed)) 62 | self.logger.log_info( 63 | f"TIME ELAPSED SINCE {self.event} START: {elapsed}" 64 | ) 65 | self.logger.log_info(f"{self.event} ENDS") 66 | 67 | 68 | class MeanTracker(object): 69 | def __init__(self, name): 70 | self.values = [] 71 | self.name = name 72 | 73 | def add(self, val): 74 | self.values.append(float(val)) 75 | 76 | def mean(self): 77 | return np.mean(self.values) 78 | 79 | def flush(self): 80 | mean = self.mean() 81 | self.values = [] 82 | return self.name, mean 83 | 84 | 85 | class _StreamingMean: 86 | def __init__(self, val=None, counts=None): 87 | if val is None: 88 | self.mean = 0.0 89 | self.counts = 0 90 | else: 91 | if isinstance(val, torch.Tensor): 92 | val = val.data.cpu().numpy() 93 | self.mean = val 94 | if counts is not None: 95 | self.counts = counts 96 | else: 97 | self.counts = 1 98 | 99 | def update(self, mean, counts=1): 100 | if isinstance(mean, torch.Tensor): 101 | mean = mean.data.cpu().numpy() 102 | elif isinstance(mean, _StreamingMean): 103 | mean, counts = mean.mean, mean.counts * counts 104 | assert counts >= 0 105 | if counts == 0: 106 | return 107 | total = self.counts + counts 108 | self.mean = self.counts / total * self.mean + counts / total * mean 109 | self.counts = total 110 | 111 | def __add__(self, other): 112 | new = self.__class__(self.mean, self.counts) 113 | if isinstance(other, _StreamingMean): 114 | if other.counts == 0: 115 | return new 116 | else: 117 | new.update(other.mean, other.counts) 118 | else: 119 | new.update(other) 120 | return new 121 | 122 | 123 | class StreamingMeans(collections.defaultdict): 124 | def __init__(self): 125 | super().__init__(_StreamingMean) 126 | 127 | def __setitem__(self, key, value): 128 | if isinstance(value, _StreamingMean): 129 | super().__setitem__(key, value) 130 | else: 131 | super().__setitem__(key, _StreamingMean(value)) 132 | 133 | def update(self, *args, **kwargs): 134 | 
for_update = dict(*args, **kwargs) 135 | for k, v in for_update.items(): 136 | self[k].update(v) 137 | 138 | def to_dict(self, prefix=""): 139 | return dict((prefix + k, v.mean) for k, v in self.items()) 140 | 141 | def to_str(self): 142 | return ", ".join([f"{k} = {v:.3f}" for k, v in self.to_dict().items()]) -------------------------------------------------------------------------------- /SimilarDomains/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import click 3 | import subprocess 4 | 5 | from pathlib import Path 6 | 7 | 8 | def download_curl(source: str, destination: str) -> None: 9 | subprocess.run( 10 | ['curl', '-L', '-k', source, '-o', destination], 11 | stdout=subprocess.DEVNULL 12 | ) 13 | 14 | 15 | def untar(path: str, destination: str = None): 16 | command = ['tar', '-xzvf', path] 17 | 18 | if destination is not None: 19 | command += ['-C', destination] 20 | if not os.path.exists(destination): 21 | os.makedirs(destination) 22 | subprocess.run(command, stdout=subprocess.DEVNULL) 23 | 24 | 25 | def download_gdrive(file_id: str, destination: str) -> None: 26 | subprocess.run(['gdown', '--id', file_id, '-O', destination]) 27 | 28 | 29 | def unzip(path: str, res_path: str = None): 30 | command = ['unzip', path] 31 | 32 | if res_path is not None: 33 | command += ['-d', res_path] 34 | subprocess.run(command, stdout=subprocess.DEVNULL) 35 | 36 | 37 | def bzip2(path: str): 38 | subprocess.run(['bzip2', '-d', path]) 39 | 40 | 41 | def rm_file(path: str): 42 | subprocess.run(['rm', path]) 43 | 44 | 45 | class Setup: 46 | def __init__(self): 47 | self.root = Path(__file__).parent 48 | self.pretrained_root = Path(__file__).parent / 'pretrained' 49 | self.pretrained_root.mkdir(exist_ok=True) 50 | 51 | def _download(self, data): 52 | 53 | if data.get('root_located', False): 54 | root = self.root 55 | else: 56 | root = self.pretrained_root 57 | 58 | file_dest = str(root / data['name']) 59 | 60 | if 'link' in data: 61 | download_curl(data['link'], file_dest) 62 | elif 'id' in data: 63 | download_gdrive(data['id'], file_dest) 64 | 65 | if file_dest.endswith('bz2'): 66 | bzip2(file_dest) 67 | rm_file(file_dest) 68 | elif file_dest.endswith('tar.gz'): 69 | untar(file_dest, str(root / data['uncompressed_dir'])) 70 | rm_file(file_dest) 71 | elif file_dest.endswith('.zip'): 72 | unzip(file_dest, str(root / data['uncompressed_dir'])) 73 | rm_file(file_dest) 74 | 75 | def setup(self, values): 76 | for value in values: 77 | self._download(SOURCES[value]) 78 | 79 | 80 | SOURCES = { 81 | 'sg2-ffhq': { 82 | 'link': 'https://nxt.2a2i.org/index.php/s/kyR9byFznz5GBTd/download/stylegan2-ffhq-config-f.pt.zip', 83 | 'name': 'stylegan2-ffhq-config-f.pt.zip', 84 | 'uncompressed_dir': '' 85 | }, 86 | 'dlib': { 87 | 'link': 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', 88 | 'name': 'shape_predictor_68_face_landmarks.dat.bz2' 89 | }, 90 | 'restyle_psp': { 91 | 'id': '1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd', 92 | 'name': 'restyle_psp_ffhq_encode.pt' 93 | }, 94 | 'e4e': { 95 | 'link': 'https://nxt.2a2i.org/index.php/s/ey49AsRwgyK77C9/download/e4e_ffhq_encode.pt.zip', 96 | 'name': 'e4e_ffhq_encode.pt.zip', 97 | 'uncompressed_dir': '' 98 | }, 99 | 'clip_means': { 100 | 'link': 'https://nxt.2a2i.org/index.php/s/CbxaqSy6C7sFNW2/download/clip_means.zip', 101 | 'name': 'clip_means.zip', 102 | 'uncompressed_dir': '', 103 | 'root_located': True, 104 | }, 105 | 'ckpt': { 106 | 'link': 
'https://nxt.2a2i.org/index.php/s/eDWLK8rDzSFoxeZ/download/checkpoints.tar.gz', 107 | 'name': 'checkpoints.tar.gz', 108 | 'uncompressed_dir': 'checkpoints_iccv' 109 | }, 110 | 'sg2_tuned': { 111 | 'link': 'https://nxt.2a2i.org/index.php/s/JzwG7gFHaKrHwDt/download/StyleGAN2_ADA.zip', 112 | 'name': 'StyleGAN2_ADA.zip', 113 | 'uncompressed_dir': '' 114 | }, 115 | 'sg2': { 116 | 'link': 'https://nxt.2a2i.org/index.php/s/2K3jbFD3Tg7QmHA/download/StyleGAN2.zip', 117 | 'name': 'StyleGAN2.zip', 118 | 'uncompressed_dir': '' 119 | }, 120 | 'image_domains': { 121 | 'link': 'https://nxt.2a2i.org/index.php/s/ZTBnffeW5TfrJjy/download/image_domains.zip', 122 | 'name': 'image_domains.zip', 123 | 'uncompressed_dir': '', 124 | 'root_located': True, 125 | } 126 | } 127 | 128 | 129 | @click.command() 130 | @click.argument('value', default=None, nargs=-1) 131 | def main(value): 132 | downloader = Setup() 133 | values = value if value else SOURCES.keys() 134 | downloader.setup(values) 135 | 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /SimilarDomains/editing/__init__.py: -------------------------------------------------------------------------------- 1 | from .styleflow.editor import StyleFlowEditor -------------------------------------------------------------------------------- /SimilarDomains/editing/interfacegan_directions/age.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/editing/interfacegan_directions/age.pt -------------------------------------------------------------------------------- /SimilarDomains/editing/interfacegan_directions/gender.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/editing/interfacegan_directions/gender.pt -------------------------------------------------------------------------------- /SimilarDomains/editing/interfacegan_directions/rotation.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/editing/interfacegan_directions/rotation.pt -------------------------------------------------------------------------------- /SimilarDomains/editing/interfacegan_directions/smile.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/editing/interfacegan_directions/smile.pt -------------------------------------------------------------------------------- /SimilarDomains/editing/latent_editor_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import argparse 5 | 6 | 7 | class LatentEditor: 8 | def __init__(self): 9 | self.interfacegan_directions = { 10 | "age": "editing/interfacegan_directions/age.pt", 11 | "smile": "editing/interfacegan_directions/smile.pt", 12 | "rotation": "editing/interfacegan_directions/rotation.pt", 13 | } 14 | 15 | self.interfacegan_directions_tensors = { 16 | name: torch.load(path).cuda() 17 | for name, path in self.interfacegan_directions.items() 18 | } 19 | 20 | def 
get_single_interface_gan_edits_with_direction( 21 | self, start_w, factors, direction 22 | ): 23 | latents_to_display = [] 24 | for factor in factors: 25 | latents_to_display.append( 26 | self.apply_interfacegan( 27 | start_w, self.interfacegan_directions_tensors[direction], factor / 2 28 | ) 29 | ) 30 | return latents_to_display 31 | 32 | def apply_interfacegan(self, latent, direction, factor=1, factor_range=None): 33 | edit_latents = [] 34 | if factor_range is not None: # Apply a range of editing factors. for example, (-5, 5) 35 | for f in range(*factor_range): 36 | edit_latent = latent + f * direction 37 | edit_latents.append(edit_latent) 38 | edit_latents = torch.cat(edit_latents) 39 | else: 40 | edit_latents = latent + factor * direction 41 | return edit_latents -------------------------------------------------------------------------------- /SimilarDomains/editing/styleflow/cnf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchdiffeq import odeint_adjoint 4 | from torchdiffeq import odeint as odeint_normal 5 | 6 | __all__ = ["CNF", "SequentialFlow"] 7 | 8 | 9 | class SequentialFlow(nn.Module): 10 | """A generalized nn.Sequential container for normalizing flows.""" 11 | 12 | def __init__(self, layer_list): 13 | super(SequentialFlow, self).__init__() 14 | self.chain = nn.ModuleList(layer_list) 15 | 16 | def forward(self, x, context, logpx=None, reverse=False, inds=None, integration_times=None): 17 | if inds is None: 18 | if reverse: 19 | inds = range(len(self.chain) - 1, -1, -1) 20 | else: 21 | inds = range(len(self.chain)) 22 | 23 | if logpx is None: 24 | for i in inds: 25 | # print(x.shape) 26 | x = self.chain[i](x, context, logpx, integration_times, reverse) 27 | return x 28 | else: 29 | for i in inds: 30 | x, logpx = self.chain[i](x, context, logpx, integration_times, reverse) 31 | return x, logpx 32 | 33 | 34 | class CNF(nn.Module): 35 | def __init__(self, odefunc, conditional=True, T=1.0, train_T=False, regularization_fns=None, 36 | solver='dopri5', atol=1e-5, rtol=1e-5, use_adjoint=True): 37 | super(CNF, self).__init__() 38 | self.train_T = train_T 39 | self.T = T 40 | if train_T: 41 | self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T)))) 42 | print("Training T :", self.T) 43 | 44 | if regularization_fns is not None and len(regularization_fns) > 0: 45 | raise NotImplementedError("Regularization not supported") 46 | self.use_adjoint = use_adjoint 47 | self.odefunc = odefunc 48 | self.solver = solver 49 | self.atol = atol 50 | self.rtol = rtol 51 | self.test_solver = solver 52 | self.test_atol = atol 53 | self.test_rtol = rtol 54 | self.solver_options = {} 55 | self.conditional = conditional 56 | 57 | def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False): 58 | if logpx is None: 59 | _logpx = torch.zeros(*x.shape[:-1], 1).to(x) 60 | else: 61 | _logpx = logpx 62 | 63 | if self.conditional: 64 | assert context is not None 65 | states = (x, _logpx, context) 66 | atol = [self.atol] * 3 67 | rtol = [self.rtol] * 3 68 | else: 69 | states = (x, _logpx) 70 | atol = [self.atol] * 2 71 | rtol = [self.rtol] * 2 72 | 73 | if integration_times is None: 74 | if self.train_T: 75 | integration_times = torch.stack( 76 | [torch.tensor(0.0).to(x), self.sqrt_end_time * self.sqrt_end_time] 77 | ).to(x) 78 | # print("integration times:", integration_times) 79 | else: 80 | integration_times = torch.tensor([0., self.T], requires_grad=False).to(x) 81 | 82 
| if reverse: 83 | integration_times = _flip(integration_times, 0) 84 | 85 | # Refresh the odefunc statistics. 86 | self.odefunc.before_odeint() 87 | odeint = odeint_adjoint if self.use_adjoint else odeint_normal 88 | if self.training: 89 | state_t = odeint( 90 | self.odefunc, 91 | states, 92 | integration_times.to(x), 93 | atol=atol, 94 | rtol=rtol, 95 | method=self.solver, 96 | options=self.solver_options, 97 | ) 98 | else: 99 | state_t = odeint( 100 | self.odefunc, 101 | states, 102 | integration_times.to(x), 103 | atol=self.test_atol, 104 | rtol=self.test_rtol, 105 | method=self.test_solver, 106 | ) 107 | 108 | if len(integration_times) == 2: 109 | 110 | state_t = tuple(s[1] for s in state_t) 111 | 112 | 113 | 114 | z_t, logpz_t = state_t[:2] 115 | 116 | if logpx is not None: 117 | return z_t, logpz_t 118 | else: 119 | return z_t 120 | 121 | def num_evals(self): 122 | return self.odefunc._num_evals.item() 123 | 124 | 125 | def _flip(x, dim): 126 | indices = [slice(None)] * x.dim() 127 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) 128 | return x[tuple(indices)] 129 | -------------------------------------------------------------------------------- /SimilarDomains/editing/styleflow/diffeq_layers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def weights_init(m): 7 | classname = m.__class__.__name__ 8 | if classname.find('Linear') != -1 or classname.find('Conv') != -1: 9 | nn.init.constant_(m.weight, 0) 10 | nn.init.normal_(m.bias, 0, 0.01) 11 | 12 | 13 | class IgnoreLinear(nn.Module): 14 | def __init__(self, dim_in, dim_out, dim_c): 15 | super(IgnoreLinear, self).__init__() 16 | self._layer = nn.Linear(dim_in, dim_out) 17 | 18 | def forward(self, context, x): 19 | return self._layer(x) 20 | 21 | 22 | class ConcatLinear(nn.Module): 23 | def __init__(self, dim_in, dim_out, dim_c): 24 | super(ConcatLinear, self).__init__() 25 | self._layer = nn.Linear(dim_in + 1 + dim_c, dim_out) 26 | 27 | def forward(self, context, x, c): 28 | if x.dim() == 3: 29 | context = context.unsqueeze(1).expand(-1, x.size(1), -1) 30 | x_context = torch.cat((x, context), dim=2) 31 | return self._layer(x_context) 32 | 33 | 34 | class ConcatLinear_v2(nn.Module): 35 | def __init__(self, dim_in, dim_out, dim_c): 36 | super(ConcatLinear_v2, self).__init__() 37 | self._layer = nn.Linear(dim_in, dim_out) 38 | self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False) 39 | 40 | def forward(self, context, x): 41 | bias = self._hyper_bias(context) 42 | if x.dim() == 3: 43 | bias = bias.unsqueeze(1) 44 | return self._layer(x) + bias 45 | 46 | 47 | class SquashLinear(nn.Module): 48 | def __init__(self, dim_in, dim_out, dim_c): 49 | super(SquashLinear, self).__init__() 50 | self._layer = nn.Linear(dim_in, dim_out) 51 | self._hyper = nn.Linear(1 + dim_c, dim_out) 52 | 53 | def forward(self, context, x): 54 | gate = torch.sigmoid(self._hyper(context)) 55 | if x.dim() == 3: 56 | gate = gate.unsqueeze(1) 57 | return self._layer(x) * gate 58 | 59 | 60 | class ScaleLinear(nn.Module): 61 | def __init__(self, dim_in, dim_out, dim_c): 62 | super(ScaleLinear, self).__init__() 63 | self._layer = nn.Linear(dim_in, dim_out) 64 | self._hyper = nn.Linear(1 + dim_c, dim_out) 65 | 66 | def forward(self, context, x): 67 | gate = self._hyper(context) 68 | if x.dim() == 3: 69 | gate = gate.unsqueeze(1) 70 | return self._layer(x) * gate 71 | 72 | 73 | class ConcatSquashLinear(nn.Module): 74 | def 
__init__(self, dim_in, dim_out, dim_c): 75 | super(ConcatSquashLinear, self).__init__() 76 | self._layer = nn.Linear(dim_in, dim_out) 77 | self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False) 78 | self._hyper_gate = nn.Linear(1 + dim_c, dim_out) 79 | 80 | def forward(self, context, x): 81 | gate = torch.sigmoid(self._hyper_gate(context)) 82 | bias = self._hyper_bias(context) 83 | if x.dim() == 3: 84 | gate = gate.unsqueeze(1) 85 | bias = bias.unsqueeze(1) 86 | ret = self._layer(x) * gate + bias 87 | return ret 88 | 89 | 90 | class ConcatScaleLinear(nn.Module): 91 | def __init__(self, dim_in, dim_out, dim_c): 92 | super(ConcatScaleLinear, self).__init__() 93 | self._layer = nn.Linear(dim_in, dim_out) 94 | self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False) 95 | self._hyper_gate = nn.Linear(1 + dim_c, dim_out) 96 | 97 | def forward(self, context, x): 98 | gate = self._hyper_gate(context) 99 | bias = self._hyper_bias(context) 100 | if x.dim() == 3: 101 | gate = gate.unsqueeze(1) 102 | bias = bias.unsqueeze(1) 103 | ret = self._layer(x) * gate + bias 104 | return ret 105 | -------------------------------------------------------------------------------- /SimilarDomains/editing/styleflow/editor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import numpy as np 4 | 5 | from pathlib import Path 6 | from .flow import build_styleflow_model 7 | 8 | 9 | class StyleFlowEditor: 10 | keep_indexes = [ 11 | 2, 5, 25, 28, 16, 32, 33, 34, 55, 75, 79, 162, 177, 196, 12 | 160, 212, 246, 285, 300, 329, 362, 369, 462, 460, 478, 13 | 551, 583, 643, 879, 852, 914, 999, 976, 627, 844, 237, 52, 301, 599 14 | ] 15 | attr_order = ['Gender', 'Glasses', 'Yaw', 'Pitch', 'Baldness', 'Beard', 'Age', 'Expression'] 16 | lighting_order = ['Left->Right', 'Right->Left', 'Down->Up', 'Up->Down', 'No light', 'Front light'] 17 | attr_degree_list = [1.5, 2.5, 1., 1., 2, 1.7,0.93, 1.] 
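# Note: attr_degree_list scales the requested attribute change in get_edited_pair (same ordering as attr_order); min_dic / max_dic give each attribute's raw value range, used by _invert_to_real to map a normalized edit power back to a real attribute value.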
18 | min_dic = { 19 | 'Gender': 0, 'Glasses': 0, 'Yaw': -20, 'Pitch': -20, 20 | 'Baldness': 0, 'Beard': 0.0, 'Age': 0, 'Expression': 0 21 | } 22 | max_dic = { 23 | 'Gender': 1, 'Glasses': 1, 'Yaw': 20, 'Pitch': 20, 24 | 'Baldness': 1, 'Beard': 1, 'Age': 65, 'Expression': 1 25 | } 26 | 27 | def __init__( 28 | self, data_path, weight_path, device 29 | ): 30 | 31 | self.model = build_styleflow_model(weight_path, device) 32 | self.device = device 33 | 34 | raw_w = pickle.load(open(Path(data_path) / 'sg2latents.pickle', 'rb')) 35 | raw_attr = np.load(Path(data_path) / 'attributes.npy') 36 | raw_lights = np.load(Path(data_path) / 'light.npy') 37 | 38 | self.all_w = np.array(raw_w['Latent'])[self.keep_indexes] 39 | self.all_attr = raw_attr[self.keep_indexes] 40 | self.all_lights = raw_lights[self.keep_indexes] 41 | 42 | self.zero_padding = torch.zeros(1, 18, 1).to(self.device) 43 | self.gap_dic = {i: StyleFlowEditor.max_dic[i] - StyleFlowEditor.min_dic[i] for i in StyleFlowEditor.max_dic} 44 | 45 | def _allocate_entity(self, idx): 46 | self.w_current_np = self.all_w[idx].copy() 47 | self.attr_current = self.all_attr[idx].copy() 48 | self.light_current = self.all_lights[idx].copy() 49 | 50 | self.attr_current_list = [self.attr_current[i][0] for i in range(len(self.attr_order))] 51 | self.light_current_list = [0 for i in range(len(self.lighting_order))] 52 | 53 | array_source = torch.from_numpy(self.attr_current).type(torch.FloatTensor).to(self.device) 54 | array_light = torch.from_numpy(self.light_current).type(torch.FloatTensor).to(self.device) 55 | self.final_array_target_ = torch.cat([array_light, array_source.unsqueeze(0).unsqueeze(-1)], dim=1) 56 | self.initial_w = torch.from_numpy(self.w_current_np).to(self.device) 57 | 58 | def _invert_to_real(self, name, edit_power): 59 | return float(float(edit_power) * self.gap_dic[name] + StyleFlowEditor.min_dic[name]) 60 | 61 | def get_edited_pair(self, attr_idx, edit_power): 62 | fws = self.model(self.initial_w, self.final_array_target_, self.zero_padding) 63 | 64 | real_value = self._invert_to_real(StyleFlowEditor.attr_order[attr_idx], edit_power) 65 | attr_change = real_value - self.attr_current_list[attr_idx] 66 | attr_final = StyleFlowEditor.attr_degree_list[attr_idx] * attr_change + self.attr_current_list[attr_idx] 67 | 68 | final_array_target = self.final_array_target_.clone() 69 | final_array_target[0, attr_idx + 9, 0, 0] = attr_final 70 | 71 | rev = self.model(fws[0], final_array_target, self.zero_padding, True) 72 | 73 | if attr_idx == 0: 74 | rev[0][0][8:] = self.initial_w[0][8:] 75 | elif attr_idx == 1: 76 | rev[0][0][:2] = self.initial_w[0][:2] 77 | rev[0][0][4:] = self.initial_w[0][4:] 78 | elif attr_idx == 2: 79 | rev[0][0][4:] = self.initial_w[0][4:] 80 | elif attr_idx == 3: 81 | rev[0][0][4:] = self.initial_w[0][4:] 82 | elif attr_idx == 4: 83 | rev[0][0][6:] = self.initial_w[0][6:] 84 | elif attr_idx == 5: 85 | rev[0][0][:5] = self.initial_w[0][:5] 86 | rev[0][0][10:] = self.initial_w[0][10:] 87 | elif attr_idx == 6: 88 | rev[0][0][0:4] = self.initial_w[0][0:4] 89 | rev[0][0][8:] = self.initial_w[0][8:] 90 | elif attr_idx == 7: 91 | rev[0][0][:4] = self.initial_w[0][:4] 92 | rev[0][0][6:] = self.initial_w[0][6:] 93 | 94 | return self.initial_w, rev[0].clone() -------------------------------------------------------------------------------- /SimilarDomains/editing/styleflow/flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | from .odefunc import ODEfunc, ODEnet 5 | from 
.normalization import MovingBatchNorm1d 6 | from .cnf import CNF, SequentialFlow 7 | 8 | 9 | def count_nfe(model): 10 | class AccNumEvals(object): 11 | 12 | def __init__(self): 13 | self.num_evals = 0 14 | 15 | def __call__(self, module): 16 | if isinstance(module, CNF): 17 | self.num_evals += module.num_evals() 18 | 19 | accumulator = AccNumEvals() 20 | model.apply(accumulator) 21 | return accumulator.num_evals 22 | 23 | 24 | def count_parameters(model): 25 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 26 | 27 | 28 | def count_total_time(model): 29 | class Accumulator(object): 30 | 31 | def __init__(self): 32 | self.total_time = 0 33 | 34 | def __call__(self, module): 35 | if isinstance(module, CNF): 36 | self.total_time = self.total_time + module.sqrt_end_time * module.sqrt_end_time 37 | 38 | accumulator = Accumulator() 39 | model.apply(accumulator) 40 | return accumulator.total_time 41 | 42 | 43 | def build_model( input_dim, hidden_dims, context_dim, num_blocks, conditional): 44 | def build_cnf(): 45 | diffeq = ODEnet( 46 | hidden_dims=hidden_dims, 47 | input_shape=(input_dim,), 48 | context_dim=context_dim, 49 | layer_type='concatsquash', 50 | nonlinearity='tanh', 51 | ) 52 | odefunc = ODEfunc( 53 | diffeq=diffeq, 54 | ) 55 | cnf = CNF( 56 | odefunc=odefunc, 57 | T=1.0, 58 | train_T=True, 59 | conditional=conditional, 60 | solver='dopri5', 61 | use_adjoint=True, 62 | atol=1e-5, 63 | rtol=1e-5, 64 | ) 65 | return cnf 66 | 67 | chain = [build_cnf() for _ in range(num_blocks)] 68 | bn_layers = [MovingBatchNorm1d(input_dim, bn_lag=0, sync=False) 69 | for _ in range(num_blocks)] 70 | bn_chain = [MovingBatchNorm1d(input_dim, bn_lag=0, sync=False)] 71 | for a, b in zip(chain, bn_layers): 72 | bn_chain.append(a) 73 | bn_chain.append(b) 74 | chain = bn_chain 75 | model = SequentialFlow(chain) 76 | 77 | return model 78 | 79 | 80 | def build_styleflow_model( 81 | ckpt_path, device, 82 | input_dim = 512, 83 | dims = '512-512-512-512-512', 84 | zdim = 17, 85 | num_blocks = 1 86 | ): 87 | dims = tuple(map(int, dims.split("-"))) 88 | model = build_model(input_dim, dims, zdim, num_blocks, True).cuda() 89 | print("Number of trainable parameters of Point CNF: {}".format(count_parameters(model))) 90 | 91 | model.load_state_dict(torch.load(ckpt_path, map_location='cpu')) 92 | model.to(device).eval() 93 | 94 | return model 95 | 96 | 97 | -------------------------------------------------------------------------------- /SimilarDomains/examples/draw_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | from pathlib import Path 6 | from copy import deepcopy 7 | from collections import OrderedDict 8 | 9 | 10 | class IdentityEditor: 11 | def __call__(self, style_space, **kwargs): 12 | return style_space 13 | 14 | def __add__(self, style_editor, **kwargs): 15 | new_editor = deepcopy(style_editor) 16 | return new_editor 17 | 18 | 19 | class StyleEditor: 20 | offsetconv_to_style = dict([ 21 | (0, 0), 22 | (1, 2), 23 | (2, 3), 24 | (3, 5), 25 | (4, 6), 26 | (5, 8), 27 | (6, 9), 28 | (7, 11), 29 | (8, 12), 30 | (9, 14), 31 | (10, 15), 32 | (11, 17), 33 | (12, 18), 34 | (13, 20), 35 | (14, 21), 36 | (15, 23), 37 | (16, 24) 38 | ]) 39 | 40 | def __init__(self, ckpt=None, device='cuda:0', img_size=1024): 41 | self.device = device 42 | if ckpt is None: 43 | return 44 | self._construct_from_ckpt(ckpt) 45 | 46 | def _construct_from_ckpt(self, ckpt): 47 | last_layer_index = 
max([int(k.split('.')[1].split('_')[-1]) for k in ckpt['state_dict'].keys()]) 48 | 49 | self.shifts = { 50 | StyleEditor.offsetconv_to_style[off_idx]: ckpt['state_dict'][f'heads.conv_{off_idx}.params_in'].to(self.device) 51 | for off_idx in range(last_layer_index) 52 | } 53 | 54 | def __call__(self, stspace, power=1.): 55 | answer = {} 56 | for idx, value in enumerate(stspace): 57 | if idx in self.shifts: 58 | answer[idx] = stspace[idx].clone() + power * self.shifts[idx] 59 | else: 60 | answer[idx] = stspace[idx].clone() 61 | 62 | return list(answer.values()) 63 | 64 | def __mul__(self, alpha): 65 | answer = deepcopy(self) 66 | for st_idx in answer.shifts: 67 | answer.shifts[st_idx] = self.shifts[st_idx] * alpha 68 | return answer 69 | 70 | def __add__(self, other): 71 | answer = deepcopy(self) 72 | for st_idx in other.shifts: 73 | answer.shifts[st_idx] = self.shifts[st_idx] + other.shifts[st_idx] 74 | return answer 75 | 76 | def to(self, device): 77 | for st_idx in self.shifts: 78 | self.shifts[st_idx] = self.shifts[st_idx].to(device) 79 | self.device = device 80 | return self 81 | 82 | 83 | def set_seed(seed): 84 | random.seed(seed) 85 | torch.random.manual_seed(seed) 86 | np.random.seed(seed) 87 | 88 | 89 | def morph_g_ema(ckpt1, ckpt2, alpha): 90 | final_ckpt = OrderedDict() 91 | 92 | for key in ckpt1['g_ema']: 93 | final_ckpt[key] = alpha * ckpt1['g_ema'][key] + (1 - alpha) * ckpt2['g_ema'][key] 94 | 95 | return {'g_ema': final_ckpt} 96 | 97 | 98 | w_style_pair = [ 99 | (0, 0), 100 | (1, 1), 101 | (1, 2), 102 | (2, 3), 103 | (3, 4), 104 | (3, 5), 105 | (4, 6), 106 | (5, 7), 107 | (5, 8), 108 | (6, 9), 109 | (7, 10), 110 | (7, 11), 111 | (8, 12), 112 | (9, 13), 113 | (9, 14), 114 | (10, 15), 115 | (11, 16), 116 | (11, 17), 117 | (12, 18), 118 | (13, 19), 119 | (13, 20), 120 | (14, 21), 121 | (15, 22), 122 | (15, 23), 123 | (16, 24), 124 | (17, 25) 125 | ] 126 | 127 | 128 | p_root = Path(__file__).resolve().parent.parent / 'pretrained' 129 | 130 | weights = {p.name.rsplit('_', 1)[0]: p for p in (p_root / 'checkpoints_iccv').iterdir()} 131 | 132 | weights.update({p.name.split('-')[1]: p for p in (p_root / 'StyleGAN2').iterdir()}) 133 | weights.update({'_'.join(p.stem.split('_')[1:3]): p for p in (p_root / 'StyleGAN2_ADA').iterdir()}) 134 | -------------------------------------------------------------------------------- /SimilarDomains/examples/photos/elon_musk.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/examples/photos/elon_musk.jpeg -------------------------------------------------------------------------------- /SimilarDomains/examples/pruned_forward.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "cf00dd04-7abd-4684-9d21-172b9224bb24", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%cd ../" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "f5fa116c-6584-4e7a-8c48-3be0505a7c6f", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# import dnnlib\n", 21 | "import pickle as pkl\n", 22 | "import torch\n", 23 | "import torch.nn as nn\n", 24 | "import torch.nn.functional as F\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "import numpy as np\n", 27 | "import random\n", 28 | "\n", 29 | "from PIL import Image\n", 30 | "from pathlib import 
Path\n", 31 | "from torchvision.transforms import Resize\n", 32 | "\n", 33 | "from core.utils.example_utils import Inferencer, vstack_with_lines, hstack_with_lines, to_im\n", 34 | "from core.utils.image_utils import construct_paper_image_grid\n", 35 | "from core.utils.reading_weights import read_weights\n", 36 | "from core.uda_models import OffsetsTunningGenerator\n", 37 | "from core.sparse_models import SparsedModel\n", 38 | "\n", 39 | "from examples.draw_util import weights" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "f57bd469-21ca-48b1-83de-2e0c3ea6e0a2", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "device = 'cuda:0'\n", 50 | "\n", 51 | "g = OffsetsTunningGenerator(\n", 52 | " checkpoint_path='pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt'\n", 53 | ").patch_layers('s_delta').to(device)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "12acc409-0bcf-4f97-a34c-a237cbba8f6b", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "id": "aca3a6b8-e539-4e1d-8326-7b605408f4c5", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "percentiles = [0.7, 0.8, 0.9, 0.9, 0.95]\n", 72 | "\n", 73 | "\n", 74 | "domain = 'sketch'\n", 75 | "bs = 4\n", 76 | "truncation = 0.8\n", 77 | "\n", 78 | "model = SparsedModel(device, read_weights(weights[domain]))\n", 79 | "z = [torch.randn(bs, 512).to(device)]\n", 80 | "resize = Resize(256)\n", 81 | "\n", 82 | "\n", 83 | "images = []\n", 84 | "for perc in percentiles:\n", 85 | " offsets = model.pruned_offsets(perc)\n", 86 | " im, _ = g(z, offsets=offsets, truncation=truncation)\n", 87 | " images.append(to_im(resize(im.detach()), padding=0))\n", 88 | " \n", 89 | " \n", 90 | "orig_ims, _ = g(z, truncation=truncation)\n", 91 | "images.append(to_im(resize(orig_ims.detach()), padding=0))" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "fe83b74f-e258-4af7-aa56-cbdb5e76da71", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "ext = 2\n", 102 | "\n", 103 | "plt.figure(figsize=(bs * ext, (len(percentiles) + 1) * ext))\n", 104 | "plt.imshow(vstack_with_lines(images, 10))\n", 105 | "\n", 106 | "plt.xticks(\n", 107 | " np.arange(128, bs * 256, 256), \n", 108 | " labels=[f\"id {k}\" for k in range(bs)]\n", 109 | ")\n", 110 | "\n", 111 | "\n", 112 | "plt.yticks(\n", 113 | " np.arange(128, len(percentiles) * (256 + 10) + 256, 256 + 10),\n", 114 | " labels=[f\"{p * 100}% pruned\" for p in percentiles] + ['Original']\n", 115 | ")\n", 116 | "\n", 117 | "\n", 118 | "# plt.axis('off')\n", 119 | "plt.show()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "202c8636-7fb0-4ef8-9aed-39c6fc90fb4f", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "59ab9e67-17e2-4457-9153-e6eb250d365a", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "id": "52fba3a2-7249-4a9a-80f2-3fbf4b12d337", 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [] 145 | } 146 | ], 147 | "metadata": { 148 | "kernelspec": { 149 | "display_name": "Python 3 (ipykernel)", 150 | "language": "python", 151 | "name": "python3" 152 | }, 153 | "language_info": { 154 | "codemirror_mode": { 155 | "name": "ipython", 156 | "version": 3 157 
| }, 158 | "file_extension": ".py", 159 | "mimetype": "text/x-python", 160 | "name": "python", 161 | "nbconvert_exporter": "python", 162 | "pygments_lexer": "ipython3", 163 | "version": "3.7.12" 164 | } 165 | }, 166 | "nbformat": 4, 167 | "nbformat_minor": 5 168 | } 169 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/BigGAN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/gan_models/BigGAN/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/gan_models/BigGAN/generator_config.json: -------------------------------------------------------------------------------- 1 | {"num_epochs": 100, 2 | "BN_eps": 1e-05, 3 | "G_batch_size": 512, 4 | "sv_log_interval": 10, 5 | "shuffle": true, 6 | "batch_size": 256, 7 | "G_mixed_precision": false, 8 | "toggle_grads": true, 9 | "mybn": false, 10 | "augment": false, 11 | "D_B2": 0.999, 12 | "D_attn": "64", 13 | "log_G_spectra": false, 14 | "G_shared": true, 15 | "num_D_steps": 1, 16 | "num_best_copies": 5, 17 | "load_in_mem": false, 18 | "split_D": false, 19 | "sample_npz": true, 20 | "D_B1": 0.0, 21 | "cross_replica": false, 22 | "SN_eps": 1e-06, 23 | "G_lr": 0.0001, 24 | "num_G_SV_itrs": 1, 25 | "pin_memory": true, 26 | "D_mixed_precision": false, 27 | "num_G_SVs": 1, 28 | "G_fp16": false, 29 | "sample_interps": true, 30 | "test_every": 2000, 31 | "sample_random": true, 32 | "num_D_SV_itrs": 1, 33 | "config_from_name": false, 34 | "G_eval_mode": true, 35 | "D_nl": "inplace_relu", 36 | "G_param": "SN", 37 | "num_inception_images": 50000, 38 | "save_every": 1000, 39 | "D_lr": 0.0004, 40 | "sample_inception_metrics": true, 41 | "G_attn": "64", 42 | "G_depth": 1, 43 | "which_train_fn": "GAN", 44 | "norm_style": "bn", 45 | "sample_num_npz": 50000, 46 | "hashname": false, 47 | "sample_sheet_folder_num": -1, 48 | "resume": false, 49 | "D_ortho": 0.0, 50 | "ema_start": 20000, 51 | "num_workers": 8, 52 | "dataset": "I128_hdf5", 53 | "ema": true, 54 | "num_D_accumulations": 8, 55 | "no_fid": false, 56 | "D_fp16": false, 57 | "G_init": "ortho", 58 | "D_init": "ortho", 59 | "D_ch": 96, 60 | "dim_z": 120, 61 | "D_wide": true, 62 | "accumulate_stats": false, 63 | "num_D_SVs": 1, 64 | "G_B1": 0.0, 65 | "use_ema": true, 66 | "pbar": "mine", 67 | "sample_trunc_curves": "0.05_0.05_1.0", 68 | "use_multiepoch_sampler": true, 69 | "num_G_accumulations": 8, 70 | "G_ch": 96, 71 | "G_B2": 0.999, 72 | "D_depth": 1, 73 | "D_param": "SN", 74 | "G_ortho": 0.0, 75 | "seed": 0, 76 | "log_D_spectra": false, 77 | "num_save_copies": 2, 78 | "hier": true, 79 | "G_nl": "inplace_relu", 80 | "skip_init": true, 81 | "sample_sheets": true, 82 | "z_var": 1.0, 83 | "adam_eps": 1e-06, 84 | "experiment_name": "", 85 | "ema_decay": 0.9999, 86 | "model": "BigGAN", 87 | "shared_dim": 128, 88 | "which_best": "IS", 89 | "parallel": true, 90 | "num_standing_accumulations": 16 91 | } -------------------------------------------------------------------------------- /SimilarDomains/gan_models/BigGAN/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import ( 12 | SynchronizedBatchNorm1d, 13 | SynchronizedBatchNorm2d, 14 | SynchronizedBatchNorm3d, 15 | ) 16 | from .replicate import DataParallelWithCallback, patch_replication_callback 17 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/BigGAN/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ["BatchNormReimpl"] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | 28 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 29 | super().__init__() 30 | 31 | self.num_features = num_features 32 | self.eps = eps 33 | self.momentum = momentum 34 | self.weight = nn.Parameter(torch.empty(num_features)) 35 | self.bias = nn.Parameter(torch.empty(num_features)) 36 | self.register_buffer("running_mean", torch.zeros(num_features)) 37 | self.register_buffer("running_var", torch.ones(num_features)) 38 | self.reset_parameters() 39 | 40 | def reset_running_stats(self): 41 | self.running_mean.zero_() 42 | self.running_var.fill_(1) 43 | 44 | def reset_parameters(self): 45 | self.reset_running_stats() 46 | init.uniform_(self.weight) 47 | init.zeros_(self.bias) 48 | 49 | def forward(self, input_): 50 | batchsize, channels, height, width = input_.size() 51 | numel = batchsize * height * width 52 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 53 | sum_ = input_.sum(1) 54 | sum_of_square = input_.pow(2).sum(1) 55 | mean = sum_ / numel 56 | sumvar = sum_of_square - sum_ * mean 57 | 58 | self.running_mean = ( 59 | 1 - self.momentum 60 | ) * self.running_mean + self.momentum * mean.detach() 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | 1 - self.momentum 64 | ) * self.running_var + self.momentum * unbias_var.detach() 65 | 66 | bias_var = sumvar / numel 67 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 68 | output = (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze( 69 | 1 70 | ) * self.weight.unsqueeze(1) + self.bias.unsqueeze(1) 71 | 72 | return ( 73 | output.view(channels, batchsize, height, width) 74 | .permute(1, 0, 2, 3) 75 | .contiguous() 76 | ) 77 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/BigGAN/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | "CallbackContext", 17 | "execute_replication_callbacks", 18 | "DataParallelWithCallback", 19 | "patch_replication_callback", 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, "__data_parallel_replicate__"): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/BigGAN/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = "NaN" 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ("Tensor close check failed\n" "adiff={}\n" "rdiff={}\n").format( 24 | adiff, rdiff 25 | ) 26 | self.assertTrue(torch.allclose(x, y), message) 27 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/BigGAN/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function 5 | import torch.nn as nn 6 | 7 | # Convenience dicts 8 | imsize_dict = { 9 | "I32": 32, 10 | "I32_hdf5": 32, 11 | "I64": 64, 12 | "I64_hdf5": 64, 13 | "I128": 128, 14 | "I128_hdf5": 128, 15 | "I256": 256, 16 | "I256_hdf5": 256, 17 | "C10": 32, 18 | "C100": 32, 19 | } 20 | nclass_dict = { 21 | "I32": 1000, 22 | "I32_hdf5": 1000, 23 | "I64": 1000, 24 | "I64_hdf5": 1000, 25 | "I128": 1000, 26 | "I128_hdf5": 1000, 27 | "I256": 1000, 28 | "I256_hdf5": 1000, 29 | "C10": 10, 30 | "C100": 100, 31 | } 32 | activation_dict = { 33 | "inplace_relu": nn.ReLU(inplace=True), 34 | "relu": nn.ReLU(inplace=False), 35 | "ir": nn.ReLU(inplace=True), 36 | } 37 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/ProgGAN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/gan_models/ProgGAN/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/gan_models/ProgGAN/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | This work is based on the Theano/Lasagne implementation of 5 | Progressive Growing of GANs paper from tkarras: 6 | https://github.com/tkarras/progressive_growing_of_gans 7 | 8 | PyTorch Model definition 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from collections import OrderedDict 16 | 17 | 18 | class PixelNormLayer(nn.Module): 19 | def __init__(self): 20 | super(PixelNormLayer, self).__init__() 21 | 22 | def forward(self, x): 23 | return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) 24 | 25 | 26 | class WScaleLayer(nn.Module): 27 | def __init__(self, size): 28 | super(WScaleLayer, self).__init__() 29 | self.scale = nn.Parameter(torch.randn([1])) 30 | self.b = nn.Parameter(torch.randn(size)) 31 | self.size = size 32 | 33 | def forward(self, x): 34 | x_size = x.size() 35 | x = x * self.scale + self.b.view(1, -1, 1, 1).expand( 36 | x_size[0], self.size, x_size[2], x_size[3]) 37 | 38 | return x 39 | 40 | 41 | class NormConvBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels, kernel_size, padding): 43 | super(NormConvBlock, self).__init__() 44 | self.norm = PixelNormLayer() 45 | self.conv = nn.Conv2d( 46 | in_channels, out_channels, kernel_size, 1, padding, bias=False) 47 | self.wscale = WScaleLayer(out_channels) 48 | 49 | def forward(self, x): 50 | x = self.norm(x) 51 | x = self.conv(x) 52 | x = 
F.leaky_relu(self.wscale(x), negative_slope=0.2) 53 | return x 54 | 55 | 56 | class NormUpscaleConvBlock(nn.Module): 57 | def __init__(self, in_channels, out_channels, kernel_size, padding): 58 | super(NormUpscaleConvBlock, self).__init__() 59 | self.norm = PixelNormLayer() 60 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 61 | self.conv = nn.Conv2d( 62 | in_channels, out_channels, kernel_size, 1, padding, bias=False) 63 | self.wscale = WScaleLayer(out_channels) 64 | 65 | def forward(self, x): 66 | x = self.norm(x) 67 | x = self.up(x) 68 | x = self.conv(x) 69 | x = F.leaky_relu(self.wscale(x), negative_slope=0.2) 70 | return x 71 | 72 | 73 | class Generator(nn.Module): 74 | def __init__(self): 75 | super(Generator, self).__init__() 76 | 77 | self.features = nn.Sequential( 78 | NormConvBlock(512, 512, kernel_size=4, padding=3), 79 | NormConvBlock(512, 512, kernel_size=3, padding=1), 80 | NormUpscaleConvBlock(512, 512, kernel_size=3, padding=1), 81 | NormConvBlock(512, 512, kernel_size=3, padding=1), 82 | NormUpscaleConvBlock(512, 512, kernel_size=3, padding=1), 83 | NormConvBlock(512, 512, kernel_size=3, padding=1), 84 | NormUpscaleConvBlock(512, 512, kernel_size=3, padding=1), 85 | NormConvBlock(512, 512, kernel_size=3, padding=1), 86 | NormUpscaleConvBlock(512, 256, kernel_size=3, padding=1), 87 | NormConvBlock(256, 256, kernel_size=3, padding=1), 88 | NormUpscaleConvBlock(256, 128, kernel_size=3, padding=1), 89 | NormConvBlock(128, 128, kernel_size=3, padding=1), 90 | NormUpscaleConvBlock(128, 64, kernel_size=3, padding=1), 91 | NormConvBlock(64, 64, kernel_size=3, padding=1), 92 | NormUpscaleConvBlock(64, 32, kernel_size=3, padding=1), 93 | NormConvBlock(32, 32, kernel_size=3, padding=1), 94 | NormUpscaleConvBlock(32, 16, kernel_size=3, padding=1), 95 | NormConvBlock(16, 16, kernel_size=3, padding=1)) 96 | 97 | self.output = nn.Sequential(OrderedDict([ 98 | ('norm', PixelNormLayer()), 99 | ('conv', nn.Conv2d(16, 100 | 3, 101 | kernel_size=1, 102 | padding=0, 103 | bias=False)), 104 | ('wscale', WScaleLayer(3)) 105 | ])) 106 | 107 | def forward(self, x): 108 | x = self.features(x) 109 | x = self.output(x) 110 | return x 111 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/SNGAN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/gan_models/SNGAN/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/gan_models/SNGAN/distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class BaseDistribution(nn.Module): 6 | def __init__(self, dim, device='cuda'): 7 | super(BaseDistribution, self).__init__() 8 | self.device = device 9 | self.dim = dim 10 | 11 | def cuda(self, device=None): 12 | super(BaseDistribution, self).cuda(device) 13 | self.device = 'cuda' if device is None else device 14 | 15 | def cpu(self): 16 | super(BaseDistribution, self).cpu() 17 | self.device='cpu' 18 | 19 | def to(self, device): 20 | super(BaseDistribution, self).to(device) 21 | self.device = device 22 | 23 | def forward(self, batch_size): 24 | raise NotImplementedError 25 | 26 | 27 | class NormalDistribution(BaseDistribution): 28 | def __init__(self, dim): 29 | super(NormalDistribution, self).__init__(dim) 30 | 31 | def forward(self, batch_size): 32 
| return torch.randn([batch_size, self.dim]).to(self.device) 33 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/SNGAN/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from gan_models.SNGAN.sn_gen_resnet import SN_RES_GEN_CONFIGS, make_resnet_generator 5 | from gan_models.SNGAN.distribution import NormalDistribution 6 | 7 | 8 | MODELS = { 9 | 'sn_resnet32': 32, 10 | 'sn_resnet64': 64, 11 | } 12 | 13 | 14 | DISTRIBUTIONS = { 15 | 'normal': NormalDistribution, 16 | } 17 | 18 | 19 | class Args: 20 | def __init__(self, **kwargs): 21 | self.nonfixed_noise = False 22 | self.noises_count = 1 23 | self.equal_split = False 24 | self.generator_batch_norm = False 25 | self.gen_sn = False 26 | self.distribution_params = "{}" 27 | 28 | self.__dict__.update(kwargs) 29 | 30 | 31 | def load_model_from_state_dict(root_dir): 32 | args = Args(**json.load(open(os.path.join(root_dir, 'args.json')))) 33 | generator_model_path = os.path.join(root_dir, 'generator.pt') 34 | 35 | try: 36 | image_channels = args.image_channels 37 | except Exception: 38 | image_channels = 3 39 | 40 | gen_config = SN_RES_GEN_CONFIGS[args.model] 41 | generator= make_resnet_generator(gen_config, channels=image_channels, 42 | distribution=NormalDistribution(args.latent_dim), 43 | img_size=MODELS[args.model]) 44 | 45 | generator.load_state_dict( 46 | torch.load(generator_model_path, map_location=torch.device('cpu')), strict=False) 47 | return generator 48 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/SNGAN/sn_gen_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | from gan_models.SNGAN.distribution import NormalDistribution 6 | 7 | 8 | ResNetGenConfig = namedtuple('ResNetGenConfig', ['channels', 'seed_dim']) 9 | SN_RES_GEN_CONFIGS = { 10 | 'sn_resnet32': ResNetGenConfig([256, 256, 256, 256], 4), 11 | 'sn_resnet64': ResNetGenConfig([16 * 64, 8 * 64, 4 * 64, 2 * 64, 64], 4), 12 | } 13 | 14 | 15 | class Reshape(nn.Module): 16 | def __init__(self, target_shape): 17 | super(Reshape, self).__init__() 18 | self.target_shape = target_shape 19 | 20 | def forward(self, input): 21 | return input.view(self.target_shape) 22 | 23 | 24 | class ResBlockGenerator(nn.Module): 25 | def __init__(self, in_channels, out_channels): 26 | super(ResBlockGenerator, self).__init__() 27 | 28 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 29 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1) 30 | 31 | nn.init.xavier_uniform_(self.conv1.weight.data, np.sqrt(2)) 32 | nn.init.xavier_uniform_(self.conv2.weight.data, np.sqrt(2)) 33 | 34 | self.model = nn.Sequential( 35 | nn.BatchNorm2d(in_channels), 36 | nn.ReLU(inplace=True), 37 | nn.Upsample(scale_factor=2), 38 | self.conv1, 39 | nn.BatchNorm2d(out_channels), 40 | nn.ReLU(inplace=True), 41 | self.conv2 42 | ) 43 | 44 | if in_channels == out_channels: 45 | self.bypass = nn.Upsample(scale_factor=2) 46 | else: 47 | self.bypass = nn.Sequential( 48 | nn.Upsample(scale_factor=2), 49 | nn.Conv2d(in_channels, out_channels, 3, 1, padding=1) 50 | ) 51 | nn.init.xavier_uniform_(self.bypass[1].weight.data, 1.0) 52 | 53 | def forward(self, x): 54 | return self.model(x) + self.bypass(x) 55 | 56 | 57 | class GenWrapper(nn.Module): 58 | 
def __init__(self, model, out_img_shape, distribution): 59 | super(GenWrapper, self).__init__() 60 | 61 | self.model = model 62 | self.out_img_shape = out_img_shape 63 | self.distribution = distribution 64 | self.force_no_grad = False 65 | 66 | def cuda(self, device=None): 67 | super(GenWrapper, self).cuda(device) 68 | self.distribution.cuda() 69 | 70 | def forward(self, batch_size): 71 | if self.force_no_grad: 72 | with torch.no_grad(): 73 | img = self.model(self.distribution(batch_size)) 74 | else: 75 | img = self.model(self.distribution(batch_size)) 76 | 77 | img = img.view(img.shape[0], *self.out_img_shape) 78 | return img 79 | 80 | 81 | def make_resnet_generator(resnet_gen_config, img_size=128, channels=3, 82 | distribution=NormalDistribution(128)): 83 | def make_dense(): 84 | dense = nn.Linear( 85 | distribution.dim, resnet_gen_config.seed_dim**2 * resnet_gen_config.channels[0]) 86 | nn.init.xavier_uniform_(dense.weight.data, 1.) 87 | return dense 88 | 89 | def make_final(): 90 | final = nn.Conv2d(resnet_gen_config.channels[-1], channels, 3, stride=1, padding=1) 91 | nn.init.xavier_uniform_(final.weight.data, 1.) 92 | return final 93 | 94 | model_channels = resnet_gen_config.channels 95 | 96 | input_layers = [ 97 | make_dense(), 98 | Reshape([-1, model_channels[0], 4, 4]) 99 | ] 100 | res_blocks = [ 101 | ResBlockGenerator(model_channels[i], model_channels[i + 1]) 102 | for i in range(len(model_channels) - 1) 103 | ] 104 | out_layers = [ 105 | nn.BatchNorm2d(model_channels[-1]), 106 | nn.ReLU(inplace=True), 107 | make_final(), 108 | nn.Tanh() 109 | ] 110 | 111 | model = nn.Sequential(*(input_layers + res_blocks + out_layers)) 112 | 113 | return GenWrapper(model, [channels, img_size, img_size], distribution) 114 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/StyleGAN2/op/__init__.py: -------------------------------------------------------------------------------- 1 | # try: 2 | # from .fused_act import FusedLeakyReLU, fused_leaky_relu 3 | # from .upfirdn2d import upfirdn2d 4 | # except: 5 | from .upfirdn2d_torch_native import upfirdn2d 6 | from .fused_act_torch_native import FusedLeakyReLU, fused_leaky_relu -------------------------------------------------------------------------------- /SimilarDomains/gan_models/StyleGAN2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | 'fused', 12 | sources=[ 13 | os.path.join(module_path, 'fused_bias_act.cpp'), 14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 15 | ], 16 | ) 17 | 18 | 19 | class FusedLeakyReLUFunctionBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, out, negative_slope, scale): 22 | ctx.save_for_backward(out) 23 | ctx.negative_slope = negative_slope 24 | ctx.scale = scale 25 | 26 | empty = grad_output.new_empty(0) 27 | 28 | grad_input = fused.fused_bias_act( 29 | grad_output, empty, out, 3, 1, negative_slope, scale 30 | ) 31 | 32 | dim = [0] 33 | 34 | if grad_input.ndim > 2: 35 | dim += list(range(2, grad_input.ndim)) 36 | 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | return grad_input, grad_bias 40 | 41 | @staticmethod 42 | def backward(ctx, gradgrad_input, gradgrad_bias): 43 | out, = ctx.saved_tensors 44 | gradgrad_out = fused.fused_bias_act( 45 | 
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 46 | ) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 | class FusedLeakyReLUFunction(Function): 52 | @staticmethod 53 | def forward(ctx, input, bias, negative_slope, scale): 54 | empty = input.new_empty(0) 55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 56 | ctx.save_for_backward(out) 57 | ctx.negative_slope = negative_slope 58 | ctx.scale = scale 59 | 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | out, = ctx.saved_tensors 65 | 66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 67 | grad_output, out, ctx.negative_slope, ctx.scale 68 | ) 69 | 70 | return grad_input, grad_bias, None, None 71 | 72 | 73 | class FusedLeakyReLU(nn.Module): 74 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | 81 | def forward(self, input): 82 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 83 | 84 | 85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 87 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/StyleGAN2/op/fused_act_torch_native.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | module_path = os.path.dirname(__file__) 8 | 9 | 10 | class FusedLeakyReLU(nn.Module): 11 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 12 | super().__init__() 13 | 14 | self.bias = nn.Parameter(torch.zeros(channel)) 15 | self.negative_slope = negative_slope 16 | self.scale = scale 17 | 18 | def forward(self, input): 19 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 20 | 21 | 22 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 23 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 24 | if input.ndim == 3: 25 | return ( 26 | F.leaky_relu( 27 | input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope 28 | ) 29 | * scale 30 | ) 31 | else: 32 | return ( 33 | F.leaky_relu( 34 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope 35 | ) 36 | * scale 37 | ) 38 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/StyleGAN2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | 
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /SimilarDomains/gan_models/StyleGAN2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /SimilarDomains/gan_models/StyleGAN2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /SimilarDomains/gan_models/StyleGAN2/op/upfirdn2d_torch_native.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | module_path = os.path.dirname(__file__) 8 | 9 | 10 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 11 | out = upfirdn2d_native( 12 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 13 | ) 14 | 15 | return out 16 | 17 | 18 | def upfirdn2d_native( 19 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 20 | ): 21 | _, channel, in_h, in_w = input.shape 22 | input = input.reshape(-1, in_h, in_w, 1) 23 | 24 | _, in_h, in_w, minor = input.shape 25 | kernel_h, kernel_w = kernel.shape 26 | 27 | out = input.view(-1, in_h, 1, in_w, 1, minor) 28 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 29 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 30 | 31 | out = F.pad( 32 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 33 | ) 34 | out = out[ 35 | :, 36 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 37 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 38 | :, 39 | ] 40 | 41 | out = out.permute(0, 3, 1, 2) 42 | out = out.reshape( 43 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 44 | ) 45 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 46 | out = F.conv2d(out, w) 47 | out = out.reshape( 48 | -1, 49 | minor, 50 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 51 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 52 | ) 53 | out = 
out.permute(0, 2, 3, 1) 54 | out = out[:, ::down_y, ::down_x, :] 55 | 56 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 57 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 58 | 59 | return out.view(-1, channel, out_h, out_w) 60 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/gan_models/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/gan_models/gan_load.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from gan_models.BigGAN import BigGAN, utils 7 | from gan_models.ProgGAN.model import Generator as ProgGenerator 8 | from gan_models.SNGAN.load import load_model_from_state_dict 9 | 10 | try: 11 | from gan_models.StyleGAN2.model import Discriminator as StyleGan2Discriminator 12 | from gan_models.StyleGAN2.model import Generator as StyleGAN2Generator 13 | except Exception as e: 14 | print('StyleGAN2 load fail: {}'.format(e)) 15 | 16 | from core.utils.class_registry import ClassRegistry 17 | 18 | generator_registry = ClassRegistry() 19 | 20 | 21 | class ConditionedBigGAN(nn.Module): 22 | def __init__(self, big_gan, target_classes=(239, )): 23 | super(ConditionedBigGAN, self).__init__() 24 | self.big_gan = big_gan 25 | self.target_classes = nn.Parameter(torch.tensor(target_classes, dtype=torch.int64), 26 | requires_grad=False) 27 | 28 | self.dim_z = self.big_gan.dim_z 29 | 30 | def set_classes(self, cl): 31 | try: 32 | cl[0] 33 | except Exception: 34 | cl = [cl] 35 | self.target_classes.data = torch.tensor(cl, dtype=torch.int64) 36 | 37 | def mixed_classes(self, batch_size): 38 | device = next(self.parameters()).device 39 | if len(self.target_classes.data.shape) == 0: 40 | return self.target_classes.repeat(batch_size).cuda() 41 | else: 42 | return torch.from_numpy( 43 | np.random.choice(self.target_classes.cpu(), [batch_size])).to(device) 44 | 45 | def forward(self, z, classes=None): 46 | if classes is None: 47 | classes = self.mixed_classes(z.shape[0]).to(z.device) 48 | 49 | cl_emb = self.big_gan.shared(classes).to(z.device) 50 | return self.big_gan(z, cl_emb) 51 | 52 | 53 | class StyleGAN2Wrapper(nn.Module): 54 | def __init__(self, g, shift_in_w): 55 | super(StyleGAN2Wrapper, self).__init__() 56 | self.style_gan2 = g 57 | self.shift_in_w = shift_in_w 58 | self.dim_z = 512 59 | self.dim_shift = self.style_gan2.style_dim if shift_in_w else self.dim_z 60 | 61 | def forward(self, input, input_is_latent=False): 62 | return self.style_gan2([input], input_is_latent=input_is_latent)[0] 63 | 64 | def gen_shifted(self, z, shift): 65 | if self.shift_in_w: 66 | w = self.style_gan2.get_latent(z) 67 | return self.forward(w + shift, input_is_latent=True) 68 | else: 69 | return self.forward(z + shift, input_is_latent=False) 70 | 71 | 72 | @generator_registry.add_func_to_registry("stylegan2") 73 | def make_style_gan2(size, weights, latent_dim=512, n_layers_mlp=8, shift_in_w=True): 74 | G = StyleGAN2Generator(size, latent_dim, n_layers_mlp) 75 | G.load_state_dict(torch.load(weights, map_location='cpu')['g_ema']) 76 | G.cuda().eval() 77 | 78 | return StyleGAN2Wrapper(G, shift_in_w=shift_in_w) 79 | 80 | 81 | def 
make_style_gan2_discriminator(size, weights_path): 82 | D = StyleGan2Discriminator(size) 83 | D.load_state_dict(torch.load(weights_path, map_location='cpu')['d']) 84 | return D 85 | 86 | 87 | @generator_registry.add_func_to_registry("biggan") 88 | def make_big_gan(config_path, weights_path, target_classes): 89 | with open(config_path, 'r') as f: 90 | config = json.load(f) 91 | 92 | config['resolution'] = utils.imsize_dict[config['dataset']] 93 | config['n_classes'] = utils.nclass_dict[config['dataset']] 94 | config['G_activation'] = utils.activation_dict[config['G_nl']] 95 | config['D_activation'] = utils.activation_dict[config['D_nl']] 96 | config['skip_init'] = True 97 | config['no_optim'] = True 98 | 99 | G = BigGAN.Generator(**config) 100 | G.load_state_dict(torch.load(weights_path, map_location='cpu'), strict=True) 101 | 102 | return ConditionedBigGAN(G, target_classes).eval() 103 | 104 | 105 | @generator_registry.add_func_to_registry("proggan") 106 | def make_proggan(weights_root): 107 | model = ProgGenerator() 108 | model.load_state_dict(torch.load(weights_root, map_location='cpu')) 109 | model.cuda() 110 | 111 | setattr(model, 'dim_z', [512, 1, 1]) 112 | return model 113 | 114 | 115 | @generator_registry.add_func_to_registry("sn_anime") 116 | def make_sngan(gan_dir): 117 | gan = load_model_from_state_dict(gan_dir) 118 | G = gan.model.eval() 119 | setattr(G, 'dim_z', gan.distribution.dim) 120 | return G 121 | 122 | 123 | @generator_registry.add_func_to_registry("sn_mnist") 124 | def make_sngan(gan_dir): 125 | gan = load_model_from_state_dict(gan_dir) 126 | G = gan.model.eval() 127 | setattr(G, 'dim_z', gan.distribution.dim) 128 | return G 129 | -------------------------------------------------------------------------------- /SimilarDomains/gan_models/gan_with_shift.py: -------------------------------------------------------------------------------- 1 | import types 2 | from functools import wraps 3 | 4 | 5 | def add_forward_with_shift(generator): 6 | def gen_shifted(self, z, shift, *args, **kwargs): 7 | return self.forward(z + shift, *args, **kwargs) 8 | 9 | generator.gen_shifted = types.MethodType(gen_shifted, generator) 10 | generator.dim_shift = generator.dim_z 11 | 12 | 13 | def gan_with_shift(gan_factory): 14 | @wraps(gan_factory) 15 | def wrapper(*args, **kwargs): 16 | gan = gan_factory(*args, **kwargs) 17 | add_forward_with_shift(gan) 18 | return gan 19 | 20 | return wrapper 21 | -------------------------------------------------------------------------------- /SimilarDomains/main.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from trainers import trainer_registry 3 | 4 | from core.utils.common import setup_seed 5 | from core.utils.arguments import load_config 6 | from pprint import pprint 7 | 8 | 9 | def run_experiment(exp_config): 10 | pprint(OmegaConf.to_container(exp_config)) 11 | setup_seed(exp_config.exp.seed) 12 | trainer = trainer_registry[exp_config.exp.trainer](exp_config) 13 | trainer.setup() 14 | trainer.train_loop() 15 | 16 | 17 | def run_experiment_from_ckpt(): 18 | ... 19 | 20 | 21 | if __name__ == '__main__': 22 | base_config = load_config() 23 | 24 | if base_config.get('checkpoint'): 25 | ... 
26 | 27 | run_experiment(base_config) 28 | -------------------------------------------------------------------------------- /SimilarDomains/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | torchdiffeq==0.2.3 3 | torchvision==0.12.0 4 | omegaconf 5 | wandb 6 | ftfy 7 | regex 8 | tqdm 9 | git+https://github.com/openai/CLIP.git 10 | scikit-image 11 | scikit-learn 12 | ipython 13 | dlib 14 | gdown 15 | click 16 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/restyle_encoders/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | from pathlib import Path 5 | 6 | 7 | def get_download_model_command(save_path, file_id, file_name): 8 | """ Get wget download command for downloading the desired model and save to directory ../pretrained_models. """ 9 | if not os.path.exists(save_path): 10 | os.makedirs(save_path) 11 | url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path) 12 | return url 13 | 14 | MODEL_PATHS = { 15 | "ffhq_encode": {"id": "1sw6I2lRIB0MpuJkpc8F5BJiSZrc0hjfE", "name": "restyle_psp_ffhq_encode.pt"}, 16 | "cars_encode": {"id": "1zJHqHRQ8NOnVohVVCGbeYMMr6PDhRpPR", "name": "restyle_psp_cars_encode.pt"}, 17 | "church_encode": {"id": "1bcxx7mw-1z7dzbJI_z7oGpWG1oQAvMaD", "name": "restyle_psp_church_encode.pt"}, 18 | "horse_encode": {"id": "19_sUpTYtJmhSAolKLm3VgI-ptYqd-hgY", "name": "restyle_e4e_horse_encode.pt"}, 19 | "afhq_wild_encode": {"id": "1GyFXVTNDUw3IIGHmGS71ChhJ1Rmslhk7", "name": "restyle_psp_afhq_wild_encode.pt"}, 20 | "toonify": {"id": "1GtudVDig59d4HJ_8bGEniz5huaTSGO_0", "name": "restyle_psp_toonify.pt"} 21 | } 22 | 23 | 24 | if __name__ == "__main__": 25 | exp = 'ffhq_encode' 26 | path = MODEL_PATHS[exp] 27 | path_to_save = Path(os.getcwd()).resolve() / 'pretrained' 28 | download_command = get_download_model_command(str(path_to_save), file_id=path["id"], file_name=path["name"]) 29 | 30 | if not os.path.exists(path_to_save / path['name']) or os.path.getsize(path_to_save / path['name']) < 1000000: 31 | print(f'Downloading ReStyle model for {exp}...') 32 | subprocess.run(f"wget {download_command}", shell=True, check=True) 33 | 34 | # if google drive receives too many requests, we'll reach the quota limit and be unable to download the model 35 | if os.path.getsize(EXPERIMENT_ARGS['model_path']) < 1000000: 36 | raise ValueError("Pretrained model was unable to be downloaded correctly!") 37 | else: 38 | print('Done.') 39 | else: 40 | print(f'ReStyle model for {exp} already exists!') -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/e4e_modules/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/restyle_encoders/e4e_modules/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/e4e_modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LatentCodesDiscriminator(nn.Module): 5 | def __init__(self, style_dim, n_mlp): 6 | super().__init__() 7 | 8 | self.style_dim = style_dim 9 | 10 | layers = [] 11 | for i in range(n_mlp-1): 12 | layers.append( 13 | nn.Linear(style_dim, style_dim) 14 | ) 15 | layers.append(nn.LeakyReLU(0.2)) 16 | layers.append(nn.Linear(512, 1)) 17 | self.mlp = nn.Sequential(*layers) 18 | 19 | def forward(self, w): 20 | return self.mlp(w) 21 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/e4e_modules/latent_codes_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class LatentCodesPool: 6 | """This class implements latent codes buffer that stores previously generated w latent codes. 7 | This buffer enables us to update discriminators using a history of generated w's 8 | rather than the ones produced by the latest encoder. 9 | """ 10 | 11 | def __init__(self, pool_size): 12 | """Initialize the ImagePool class 13 | Parameters: 14 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 15 | """ 16 | self.pool_size = pool_size 17 | if self.pool_size > 0: # create an empty pool 18 | self.num_ws = 0 19 | self.ws = [] 20 | 21 | def query(self, ws): 22 | """Return w's from the pool. 23 | Parameters: 24 | ws: the latest generated w's from the generator 25 | Returns w's from the buffer. 26 | By 50/100, the buffer will return input w's. 27 | By 50/100, the buffer will return w's previously stored in the buffer, 28 | and insert the current w's to the buffer. 
29 | """ 30 | if self.pool_size == 0: # if the buffer size is 0, do nothing 31 | return ws 32 | return_ws = [] 33 | for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) 34 | # w = torch.unsqueeze(image.data, 0) 35 | if w.ndim == 2: 36 | i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate 37 | w = w[i] 38 | self.handle_w(w, return_ws) 39 | return_ws = torch.stack(return_ws, 0) # collect all the images and return 40 | return return_ws 41 | 42 | def handle_w(self, w, return_ws): 43 | if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer 44 | self.num_ws = self.num_ws + 1 45 | self.ws.append(w) 46 | return_ws.append(w) 47 | else: 48 | p = random.uniform(0, 1) 49 | if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer 50 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 51 | tmp = self.ws[random_id].clone() 52 | self.ws[random_id] = w 53 | return_ws.append(tmp) 54 | else: # by another 50% chance, the buffer will return the current image 55 | return_ws.append(w) 56 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/restyle_encoders/encoders/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/encoders/map2style.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from torch.nn import Conv2d, Module 4 | 5 | from gan_models.StyleGAN2.model import EqualLinear 6 | 7 | 8 | class GradualStyleBlock(Module): 9 | def __init__(self, in_c, out_c, spatial): 10 | super(GradualStyleBlock, self).__init__() 11 | self.out_c = out_c 12 | self.spatial = spatial 13 | num_pools = int(np.log2(spatial)) 14 | modules = [] 15 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 16 | nn.LeakyReLU()] 17 | for i in range(num_pools - 1): 18 | modules += [ 19 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 20 | nn.LeakyReLU() 21 | ] 22 | self.convs = nn.Sequential(*modules) 23 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 24 | 25 | def forward(self, x): 26 | x = self.convs(x) 27 | x = x.view(-1, self.out_c) 28 | x = self.linear(x) 29 | return x 30 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = 
get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/encoders/restyle_psp_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 4 | from torchvision.models.resnet import resnet34 5 | 6 | from restyle_encoders.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 7 | from restyle_encoders.encoders.map2style import GradualStyleBlock 8 | 9 | 10 | class BackboneEncoder(Module): 11 | """ 12 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 13 | map of the encoder. This classes uses the simplified architecture applied over an ResNet IRSE-50 backbone. 14 | Note this class is designed to be used for the human facial domain. 
15 | """ 16 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 17 | super(BackboneEncoder, self).__init__() 18 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 19 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 20 | blocks = get_blocks(num_layers) 21 | if mode == 'ir': 22 | unit_module = bottleneck_IR 23 | elif mode == 'ir_se': 24 | unit_module = bottleneck_IR_SE 25 | 26 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 27 | BatchNorm2d(64), 28 | PReLU(64)) 29 | modules = [] 30 | for block in blocks: 31 | for bottleneck in block: 32 | modules.append(unit_module(bottleneck.in_channel, 33 | bottleneck.depth, 34 | bottleneck.stride)) 35 | self.body = Sequential(*modules) 36 | 37 | self.styles = nn.ModuleList() 38 | self.style_count = n_styles 39 | for i in range(self.style_count): 40 | style = GradualStyleBlock(512, 512, 16) 41 | self.styles.append(style) 42 | 43 | def forward(self, x): 44 | x = self.input_layer(x) 45 | x = self.body(x) 46 | latents = [] 47 | for j in range(self.style_count): 48 | latents.append(self.styles[j](x)) 49 | out = torch.stack(latents, dim=1) 50 | return out 51 | 52 | 53 | class ResNetBackboneEncoder(Module): 54 | """ 55 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 56 | map of the encoder. This classes uses the simplified architecture applied over an ResNet34 backbone. 57 | """ 58 | def __init__(self, n_styles=18, opts=None): 59 | super(ResNetBackboneEncoder, self).__init__() 60 | 61 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 62 | self.bn1 = BatchNorm2d(64) 63 | self.relu = PReLU(64) 64 | 65 | resnet_basenet = resnet34(pretrained=True) 66 | blocks = [ 67 | resnet_basenet.layer1, 68 | resnet_basenet.layer2, 69 | resnet_basenet.layer3, 70 | resnet_basenet.layer4 71 | ] 72 | modules = [] 73 | for block in blocks: 74 | for bottleneck in block: 75 | modules.append(bottleneck) 76 | self.body = Sequential(*modules) 77 | 78 | self.styles = nn.ModuleList() 79 | self.style_count = n_styles 80 | for i in range(self.style_count): 81 | style = GradualStyleBlock(512, 512, 16) 82 | self.styles.append(style) 83 | 84 | def forward(self, x): 85 | x = self.conv1(x) 86 | x = self.bn1(x) 87 | x = self.relu(x) 88 | x = self.body(x) 89 | latents = [] 90 | for j in range(self.style_count): 91 | latents.append(self.styles[j](x)) 92 | out = torch.stack(latents, dim=1) 93 | return out 94 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/mtcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/restyle_encoders/mtcnn/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/__init__.py -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.visualization_utils import show_bboxes 2 | from .detector import detect_faces 3 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .get_nets import PNet, RNet, ONet 4 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 5 | from .first_stage import run_first_stage 6 | 7 | 8 | def detect_faces(image, min_face_size=20.0, 9 | thresholds=[0.6, 0.7, 0.8], 10 | nms_thresholds=[0.7, 0.7, 0.7]): 11 | """ 12 | Arguments: 13 | image: an instance of PIL.Image. 14 | min_face_size: a float number. 15 | thresholds: a list of length 3. 16 | nms_thresholds: a list of length 3. 17 | 18 | Returns: 19 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 20 | bounding boxes and facial landmarks. 21 | """ 22 | 23 | # LOAD MODELS 24 | pnet = PNet() 25 | rnet = RNet() 26 | onet = ONet() 27 | onet.eval() 28 | 29 | # BUILD AN IMAGE PYRAMID 30 | width, height = image.size 31 | min_length = min(height, width) 32 | 33 | min_detection_size = 12 34 | factor = 0.707 # sqrt(0.5) 35 | 36 | # scales for scaling the image 37 | scales = [] 38 | 39 | # scales the image so that 40 | # minimum size that we can detect equals to 41 | # minimum face size that we want to detect 42 | m = min_detection_size / min_face_size 43 | min_length *= m 44 | 45 | factor_count = 0 46 | while min_length > min_detection_size: 47 | scales.append(m * factor ** factor_count) 48 | min_length *= factor 49 | factor_count += 1 50 | 51 | # STAGE 1 52 | 53 | # it will be returned 54 | bounding_boxes = [] 55 | 56 | with torch.no_grad(): 57 | # run P-Net on different scales 58 | for s in scales: 59 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 60 | bounding_boxes.append(boxes) 61 | 62 | # collect boxes (and offsets, and scores) from different scales 63 | bounding_boxes = [i for i in bounding_boxes if i is not None] 64 | bounding_boxes = np.vstack(bounding_boxes) 65 | 66 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 67 | bounding_boxes = bounding_boxes[keep] 68 | 69 | # use offsets predicted by pnet to transform bounding boxes 70 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 71 | # shape [n_boxes, 5] 72 | 73 | bounding_boxes = convert_to_square(bounding_boxes) 74 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 75 | 76 | # STAGE 2 77 | 78 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 79 | img_boxes = torch.FloatTensor(img_boxes) 80 | 81 | output = rnet(img_boxes) 82 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 83 | probs = output[1].data.numpy() # shape [n_boxes, 2] 84 | 85 | keep = np.where(probs[:, 1] > thresholds[1])[0] 86 | bounding_boxes = bounding_boxes[keep] 87 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 88 | offsets = offsets[keep] 89 | 90 | keep = nms(bounding_boxes, nms_thresholds[1]) 91 | bounding_boxes = bounding_boxes[keep] 92 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 93 | bounding_boxes = convert_to_square(bounding_boxes) 94 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 95 | 96 | # STAGE 3 97 | 98 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 99 | if len(img_boxes) == 0: 100 | return [], [] 101 | img_boxes = torch.FloatTensor(img_boxes) 102 | output = onet(img_boxes) 103 | landmarks = output[0].data.numpy() 
# shape [n_boxes, 10] 104 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 105 | probs = output[2].data.numpy() # shape [n_boxes, 2] 106 | 107 | keep = np.where(probs[:, 1] > thresholds[2])[0] 108 | bounding_boxes = bounding_boxes[keep] 109 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 110 | offsets = offsets[keep] 111 | landmarks = landmarks[keep] 112 | 113 | # compute landmark points 114 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 115 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 116 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 117 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 118 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 119 | 120 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 121 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 122 | bounding_boxes = bounding_boxes[keep] 123 | landmarks = landmarks[keep] 124 | 125 | return bounding_boxes, landmarks 126 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from PIL import Image 4 | import numpy as np 5 | from .box_utils import nms, _preprocess 6 | 7 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | device = 'cuda:0' 9 | 10 | 11 | def run_first_stage(image, net, scale, threshold): 12 | """Run P-Net, generate bounding boxes, and do NMS. 13 | 14 | Arguments: 15 | image: an instance of PIL.Image. 16 | net: an instance of pytorch's nn.Module, P-Net. 17 | scale: a float number, 18 | scale width and height of the image by this number. 19 | threshold: a float number, 20 | threshold on the probability of a face when generating 21 | bounding boxes from predictions of the net. 22 | 23 | Returns: 24 | a float numpy array of shape [n_boxes, 9], 25 | bounding boxes with scores and offsets (4 + 1 + 4). 26 | """ 27 | 28 | # scale the image and convert it to a float array 29 | width, height = image.size 30 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 31 | img = image.resize((sw, sh), Image.BILINEAR) 32 | img = np.asarray(img, 'float32') 33 | 34 | img = torch.FloatTensor(_preprocess(img)).to(device) 35 | with torch.no_grad(): 36 | output = net(img) 37 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 38 | offsets = output[0].cpu().data.numpy() 39 | # probs: probability of a face at each sliding window 40 | # offsets: transformations to true bounding boxes 41 | 42 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 43 | if len(boxes) == 0: 44 | return None 45 | 46 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 47 | return boxes[keep] 48 | 49 | 50 | def _generate_bboxes(probs, offsets, scale, threshold): 51 | """Generate bounding boxes at places 52 | where there is probably a face. 53 | 54 | Arguments: 55 | probs: a float numpy array of shape [n, m]. 56 | offsets: a float numpy array of shape [1, 4, n, m]. 57 | scale: a float number, 58 | width and height of the image were scaled by this number. 59 | threshold: a float number. 
60 | 61 | Returns: 62 | a float numpy array of shape [n_boxes, 9] 63 | """ 64 | 65 | # applying P-Net is equivalent, in some sense, to 66 | # moving 12x12 window with stride 2 67 | stride = 2 68 | cell_size = 12 69 | 70 | # indices of boxes where there is probably a face 71 | inds = np.where(probs > threshold) 72 | 73 | if inds[0].size == 0: 74 | return np.array([]) 75 | 76 | # transformations of bounding boxes 77 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 78 | # they are defined as: 79 | # w = x2 - x1 + 1 80 | # h = y2 - y1 + 1 81 | # x1_true = x1 + tx1*w 82 | # x2_true = x2 + tx2*w 83 | # y1_true = y1 + ty1*h 84 | # y2_true = y2 + ty2*h 85 | 86 | offsets = np.array([tx1, ty1, tx2, ty2]) 87 | score = probs[inds[0], inds[1]] 88 | 89 | # P-Net is applied to scaled images 90 | # so we need to rescale bounding boxes back 91 | bounding_boxes = np.vstack([ 92 | np.round((stride * inds[1] + 1.0) / scale), 93 | np.round((stride * inds[0] + 1.0) / scale), 94 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 95 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 96 | score, offsets 97 | ]) 98 | # why one is added? 99 | 100 | return bounding_boxes.T 101 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/SimilarDomains/restyle_encoders/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /img/Figure-FewShot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/img/Figure-FewShot.png -------------------------------------------------------------------------------- /img/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/img/diagram.png -------------------------------------------------------------------------------- /img/few_shot_domains.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/img/few_shot_domains.png -------------------------------------------------------------------------------- /img/one_shot_domains.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/img/one_shot_domains.png -------------------------------------------------------------------------------- /img/style_domain_transfer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/img/style_domain_transfer.png -------------------------------------------------------------------------------- /img/titan_armin_joker_pixar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleDomain/44b8370c90df25754ec01e1f880ca02b502e3bce/img/titan_armin_joker_pixar.png 
--------------------------------------------------------------------------------
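
Usage sketches (illustrative only; the following are not files from the repository, and the paths and sizes in them are hypothetical placeholders).

SimilarDomains/gan_models/StyleGAN2/op/__init__.py, as committed above, imports the pure-PyTorch implementations from upfirdn2d_torch_native.py and fused_act_torch_native.py rather than the compiled CUDA extensions. A minimal sketch of what these two ops compute, assuming the SimilarDomains directory is on PYTHONPATH:

import torch
import torch.nn.functional as F

from gan_models.StyleGAN2.op import upfirdn2d, FusedLeakyReLU, fused_leaky_relu

x = torch.randn(1, 3, 16, 16)

# With up=down=1, upfirdn2d reduces to plain FIR filtering; the output size follows
# (H + pad0 + pad1 - kH) // down + 1, which stays 16 for a 3x3 kernel and pad=(1, 1).
blur = torch.tensor([[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]])
blur = blur / blur.sum()
assert upfirdn2d(x, blur, up=1, down=1, pad=(1, 1)).shape == x.shape

# fused_leaky_relu adds a per-channel bias, applies leaky ReLU (slope 0.2), and scales by sqrt(2).
assert torch.allclose(fused_leaky_relu(x, torch.zeros(3)), F.leaky_relu(x, 0.2) * 2 ** 0.5)

# FusedLeakyReLU is the module form of the same op with a learnable per-channel bias.
assert FusedLeakyReLU(3)(x).shape == x.shape

The generator factories in SimilarDomains/gan_models/gan_load.py return wrappers that expose dim_z, dim_shift and gen_shifted. A sketch of the StyleGAN2 path, assuming the StyleGAN2 model import in gan_load.py succeeds, a CUDA device is available (make_style_gan2 calls G.cuda()), and a rosinality-format checkpoint with a 'g_ema' entry exists at the hypothetical path below:

import torch

from gan_models.gan_load import make_style_gan2

G = make_style_gan2(size=1024, weights='pretrained/stylegan2-ffhq-config-f.pt')

z = torch.randn(4, G.dim_z, device='cuda')               # latent codes in Z
images = G(z)                                            # synthesized batch, [4, 3, 1024, 1024]

direction = torch.zeros(1, G.dim_shift, device='cuda')   # e.g. a learned editing direction
edited = G.gen_shifted(z, direction)                     # shift is applied in W since shift_in_w=True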