├── CelebAMask-HQ-attribute-anno.txt ├── LICENSE ├── README.md ├── arguments ├── inference_arguments.py └── training_arguments.py ├── assets ├── dicaprio.png ├── graph.png ├── images_preview.webp ├── keanu.jpeg ├── mask.jpg ├── phase1.png ├── phase2.png ├── potter.jpg └── preview.webp ├── available_directions.txt ├── configs ├── fse_editor_train.yaml ├── fse_inference.yaml ├── fse_inverter_inference.yaml ├── fse_inverter_train.yaml ├── paths.py └── simple_inference.yaml ├── criteria ├── __init__.py ├── id_loss.py ├── id_vit_loss.py ├── lpips │ ├── __init__.py │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── moco_loss.py ├── ms_ssim.py ├── resnet.py └── w_norm.py ├── datasets ├── datasets.py ├── loaders.py └── transforms.py ├── dnnlib ├── __init__.py ├── tflib │ ├── __init__.py │ ├── autosummary.py │ ├── custom_ops.py │ ├── network.py │ ├── ops │ │ ├── __init__.py │ │ ├── fused_bias_act.cu │ │ ├── fused_bias_act.py │ │ ├── upfirdn_2d.cu │ │ └── upfirdn_2d.py │ ├── optimizer.py │ └── tfutil.py └── util.py ├── editings ├── bound │ ├── Eyeglasses_boundary.npy │ ├── Heavy_Makeup_boundary.npy │ └── Smiling_boundary.npy ├── deltaedit │ ├── delta_mapper.py │ ├── editor.py │ └── map_tool.py ├── ganspace.py ├── ganspace_pca │ ├── cars_pca.pt │ ├── church_pca.pt │ └── ffhq_pca.pt ├── interfacegan_directions │ ├── age.pt │ ├── rotation.pt │ └── smile.pt ├── latent_editor.py └── styleclip │ ├── __init__.py │ ├── global_mapper_data │ ├── S_mean_std │ ├── delta_i_c.npy │ └── templates.txt │ ├── mapper │ ├── __init__.py │ ├── gloabl_mapper.py │ ├── latent_mappers.py │ └── styleclip_mapper.py │ └── models │ ├── __init__.py │ └── stylegan2 │ ├── __init__.py │ ├── model.py │ └── op │ ├── __init__.py │ ├── fused_act.py │ └── upfirdn2d.py ├── env_install.sh ├── metrics └── metrics.py ├── models ├── farl │ ├── __init__.py │ └── farl.py ├── hyperinverter │ ├── encoders │ │ ├── __init__.py │ │ ├── fpn_encoders.py │ │ ├── helpers.py │ │ └── model_irse.py │ ├── hypernetwork.py │ ├── stylegan2 │ │ ├── __init__.py │ │ ├── model.py │ │ └── op │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ ├── fused_bias_act.cpp │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.cpp │ │ │ ├── upfirdn2d.py │ │ │ └── upfirdn2d_kernel.cu │ ├── stylegan2_ada.py │ └── weight_shapes.py ├── methods.py ├── mtcnn │ ├── __init__.py │ ├── mtcnn.py │ └── mtcnn_pytorch │ │ ├── __init__.py │ │ └── src │ │ ├── __init__.py │ │ ├── align_trans.py │ │ ├── box_utils.py │ │ ├── detector.py │ │ ├── first_stage.py │ │ ├── get_nets.py │ │ ├── matlab_cp2tform.py │ │ ├── visualization_utils.py │ │ └── weights │ │ ├── onet.npy │ │ ├── pnet.npy │ │ └── rnet.npy └── psp │ ├── __init__.py │ ├── encoders │ ├── __init__.py │ ├── feature_resnet.py │ ├── helpers.py │ ├── model_irse.py │ └── psp_encoders.py │ └── stylegan2 │ ├── __init__.py │ ├── model.py │ └── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── notebook ├── StyleFeatureEditor_inference.ipynb └── images │ ├── gosling.jpg │ ├── robert.png │ ├── robert_aligned_mask.jpg │ ├── scarlet.jpg │ ├── smith.jpg │ └── watson.jpeg ├── requirements.txt ├── runners ├── base_runner.py ├── inference_runners.py ├── simple_runner.py └── training_runners.py ├── scripts ├── align_all_parallel.py ├── calculate_metrics.py ├── fid_calculation.py ├── inference.py ├── simple_inference.py └── train.py ├── training ├── __init__.py ├── loggers.py ├── losses.py └── optimizers.py └── utils ├── __init__.py ├── 
class_registry.py ├── common_utils.py ├── data_utils.py ├── model_utils.py └── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops ├── __init__.py ├── bias_act.cpp ├── bias_act.cu ├── bias_act.h ├── bias_act.py ├── conv2d_gradfix.py ├── conv2d_resample.py ├── conv2d_resample_new.py ├── fma.py ├── upfirdn2d.cpp ├── upfirdn2d.cu ├── upfirdn2d.h └── upfirdn2d.py └── persistence.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AIRI - Artificial Intelligence Research Institute 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /arguments/inference_arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from pathlib import Path 4 | from typing import Optional, List, Tuple, Dict 5 | from dataclasses import dataclass, field 6 | from omegaconf import OmegaConf, MISSING 7 | from utils.class_registry import ClassRegistry 8 | from models.methods import methods_registry 9 | from metrics.metrics import metrics_registry 10 | 11 | 12 | 13 | args = ClassRegistry() 14 | 15 | 16 | @args.add_to_registry("exp") 17 | @dataclass 18 | class ExperimentArgs: 19 | config_dir: str = str(Path(__file__).resolve().parent / "configs") 20 | config: str = MISSING 21 | output_dir: str = "results_dir" 22 | seed: int = 1 23 | root: str = os.getenv("EXP_ROOT", ".") 24 | domain: str = "human_faces" 25 | wandb: bool = False 26 | 27 | 28 | @args.add_to_registry("data") 29 | @dataclass 30 | class DataArgs: 31 | inference_dir: str = "" 32 | transform: str = "face_1024" 33 | 34 | 35 | @args.add_to_registry("inference") 36 | @dataclass 37 | class InferenceArgs: 38 | inference_runner: str = "base_inference_runner" 39 | editings_data: Dict = field(default_factory=lambda: {}) 40 | 41 | 42 | @args.add_to_registry("model") 43 | @dataclass 44 | class ModelArgs: 45 | method: str = "fse_full" 46 | device: str = "0" 47 | batch_size: int = 4 48 | workers: int = 4 49 | checkpoint_path: str = "" 50 | 51 | 52 | 53 | MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs") 54 | args.add_to_registry("methods_args")(MethodsArgs) 55 | 56 | MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs") 57 | args.add_to_registry("metrics")(MetricsArgs) 58 | 59 | 60 | 61 | Args = args.make_dataclass_from_classes("Args") 62 | 63 | 64 | def load_config(): 
65 | config = OmegaConf.structured(Args) 66 | 67 | conf_cli = OmegaConf.from_cli() 68 | config.exp.config = conf_cli.exp.config 69 | config.exp.config_dir = conf_cli.exp.config_dir 70 | 71 | config_path = os.path.join(config.exp.config_dir, config.exp.config) 72 | conf_file = OmegaConf.load(config_path) 73 | config = OmegaConf.merge(config, conf_file) 74 | for method in list(config.methods_args.keys()): 75 | if method != config.model.method: 76 | config.methods_args.__delattr__(method) 77 | 78 | config = OmegaConf.merge(config, conf_cli) 79 | 80 | return config 81 | -------------------------------------------------------------------------------- /arguments/training_arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from training.losses import disc_losses 3 | from training.optimizers import optimizers 4 | from pathlib import Path 5 | from typing import Optional, List, Tuple, Dict 6 | from dataclasses import dataclass, field 7 | from omegaconf import OmegaConf, MISSING 8 | from utils.class_registry import ClassRegistry 9 | from models.methods import methods_registry 10 | from metrics.metrics import metrics_registry 11 | 12 | 13 | args = ClassRegistry() 14 | 15 | 16 | @args.add_to_registry("exp") 17 | @dataclass 18 | class ExperimentArgs: 19 | config_dir: str = str(Path(__file__).resolve().parent / "configs") 20 | config: str = MISSING 21 | exp_dir: str = "experiments" 22 | name: str = MISSING 23 | seed: int = 1 24 | root: str = os.getenv("EXP_ROOT", ".") 25 | wandb: bool = True 26 | wandb_project: str = "sfe" 27 | domain: str = "human_faces" 28 | 29 | 30 | @args.add_to_registry("data") 31 | @dataclass 32 | class DataArgs: 33 | special_dir: str = MISSING 34 | transform: str = "face_1024" 35 | input_train_dir: str = MISSING 36 | input_val_dir: str = MISSING 37 | 38 | 39 | @args.add_to_registry("train") 40 | @dataclass 41 | class TrainingArgs: 42 | train_runner: str = "base_training_runner" 43 | encoder_optimizer: str = "ranger" 44 | disc_optimizer: str = "adam" 45 | resume_path: str = "" 46 | val_metrics: List[str] = field( 47 | default_factory=lambda: ["msssim", "lpips", "l2", "fid"] 48 | ) 49 | start_step: int = 0 50 | steps: int = 300000 51 | log_step: int = 500 52 | checkpoint_step: int = 15000 53 | val_step: int = 15000 54 | train_dis: bool = False 55 | dis_train_start_step: int = 150000 56 | bs_used_before_adv_loss: int = 8 57 | disc_edits: List[str] = field( 58 | default_factory=lambda: [] 59 | ) 60 | 61 | @args.add_to_registry("model") 62 | @dataclass 63 | class ModelArgs: 64 | method: str = "fse_full" 65 | device: str = "0" 66 | batch_size: int = 4 67 | workers: int = 4 68 | checkpoint_path: str = "" 69 | 70 | 71 | @args.add_to_registry("encoder_losses") 72 | @dataclass 73 | class EncoderLossesArgs: 74 | l2: float = 0.0 75 | lpips: float = 0.0 76 | lpips_scale: float = 0.0 77 | id: float = 0.0 78 | moco: float = 0.0 79 | adv: float = 0.0 80 | feat_rec: float = 0.0 81 | feat_rec_l1: float = 0.0 82 | l2_latent: float = 0.0 83 | id_vit: float = 0.0 84 | 85 | 86 | MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs") 87 | args.add_to_registry("methods_args")(MethodsArgs) 88 | 89 | DiscLossesArgs = disc_losses.make_dataclass_from_args("DiscLossesArgs") 90 | args.add_to_registry("disc_losses")(DiscLossesArgs) 91 | 92 | OptimizersArgs = optimizers.make_dataclass_from_args("OptimizersArgs") 93 | args.add_to_registry("optimizers")(OptimizersArgs) 94 | 95 | MetricsArgs = 
metrics_registry.make_dataclass_from_args("MetricsArgs") 96 | args.add_to_registry("metrics")(MetricsArgs) 97 | 98 | 99 | Args = args.make_dataclass_from_classes("Args") 100 | 101 | 102 | def load_config(): 103 | config = OmegaConf.structured(Args) 104 | 105 | conf_cli = OmegaConf.from_cli() 106 | config.exp.config = conf_cli.exp.config 107 | config.exp.config_dir = conf_cli.exp.config_dir 108 | 109 | config_path = os.path.join(config.exp.config_dir, config.exp.config) 110 | conf_file = OmegaConf.load(config_path) 111 | config = OmegaConf.merge(config, conf_file) 112 | 113 | for method in list(config.methods_args.keys()): 114 | if method != config.model.method: 115 | config.methods_args.__delattr__(method) 116 | 117 | config = OmegaConf.merge(config, conf_cli) 118 | 119 | return config 120 | -------------------------------------------------------------------------------- /assets/dicaprio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/assets/dicaprio.png -------------------------------------------------------------------------------- /assets/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/assets/graph.png -------------------------------------------------------------------------------- /assets/images_preview.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/assets/images_preview.webp -------------------------------------------------------------------------------- /assets/keanu.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/assets/keanu.jpeg -------------------------------------------------------------------------------- /assets/mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/assets/mask.jpg -------------------------------------------------------------------------------- /assets/phase1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/assets/phase1.png -------------------------------------------------------------------------------- /assets/phase2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/assets/phase2.png -------------------------------------------------------------------------------- /assets/potter.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/assets/potter.jpg -------------------------------------------------------------------------------- /assets/preview.webp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/assets/preview.webp -------------------------------------------------------------------------------- /available_directions.txt: -------------------------------------------------------------------------------- 1 | direction method\name | direction effect | approximate direction | additional comments 2 | | during positive edit | power range to | 3 | | (during negative edit | editing works well | 4 | | effect is reversed) | (found empirically) | 5 | 6 | 7 | fs_directions: 8 | fs_glasses | add glasses | [-20; 30] | may open mouth 9 | fs_smiling | add smile | [-10; 10] 10 | fs_makeup | add make-up | [-10; 15] | works poorly with men 11 | 12 | ganspace_directions: 13 | eye_openness | close eyes | [-30; 45] 14 | trimmed_beard | remove beard | [-30; 30] | works poorly with women 15 | lipstick | add lipstick | [-30; 30] | works poorly with men 16 | face_roundness | make face rounder | [-20; 15] 17 | nose_length | decrease nose length | [-30; 30] | may open mouth 18 | eyebrow_thickness | decrease eyebrow | [-20; 20] 19 | | thickness | 20 | displeased | add sadness | [-10; 10] 21 | 22 | interfacegan_directions: 23 | age | increase age | [-10; 10] 24 | smile | add smile | [-5 ; 5 ] | may cause hair artefacts, better use fs_smiling 25 | rotation | turn face right | [-7 ; 7 ] 26 | 27 | styleclip_directions: 28 | afro | afro hairstyle | [0; 0.14] 29 | angry | make angrier | [0; 0.14] | cause background artefacts 30 | bobcut | bobcut hairstyle | [0; 0.18] | cause background artefacts 31 | bowlcut | bowlcut hairstyle | [0; 0.14] | cause background artefacts 32 | mohawk | mohawk hairstyle | [0; 0.10] | cause background artefacts 33 | curly_hair | add curls | [0; 0.12] 34 | purple_hair | dye hair purple | [0; 0.12] 35 | surprised | make more surprised | [0; 0.10] 36 | beyonce | make similar to ... | [0; 0.12] 37 | hilary_clinton | make similar to ... | [0; 0.10] 38 | depp | make similar to ... | [0; 0.12] 39 | taylor_swift | make similar to ... | [0; 0.10] 40 | trump | make similar to ... | [0; 0.10] 41 | zuckerberg | make similar to ... | [0; 0.10] 42 | 43 | stylespace_directions: 44 | black hair | darken hair | [-7; 10] 45 | blond hair | darken hair | [-10; 7] | yes, positive power darkens hair 46 | grey hair | make hair colored | [-7 ; 7] | negative means make hair grey 47 | wavy hair | add hair length | [-7 ; 7] 48 | receding hairline | remove baldness | [-10; 10] 49 | smiling | remove smile | [-10; 10] | better use fs_smiling 50 | sideburns | remove sideburns | [-7 ; 7] | works poorly with women 51 | goatee | remove goatee | [-7 ; 7] | works poorly with women 52 | earrings | add earrings | [0 ; 15] | works poorly with men 53 | glasses | remove glasses | [-10; 10] 54 | gender | add femininity | [-10; 7] | better use global mapper 55 | 56 | You can also use directions from text prompts via StyleClip Global Mapper (https://arxiv.org/abs/2103.17249). 57 | Such directions look as follows: "styleclip_global_{neutral prompt}_{target prompt}_{disentanglement}" where 58 | neutral prompt -- some neutral description of the original image (e.g. "a face") 59 | target prompt -- text that contains the desired edit (e.g. "a smiling face") 60 | disentanglement -- positive number; the larger it is, the more related attributes will also be changed (e.g.
61 | for grey hair editing, wrinkle, skin colour and glasses may also be edited) 62 | 63 | Example: "styleclip_global_a face_a face with black hair_0.18" -------------------------------------------------------------------------------- /configs/fse_editor_train.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | train_runner: fse_editor 3 | train_dis: True 4 | dis_train_start_step: 45000 5 | steps: 500000 6 | 7 | 8 | model: 9 | method: fse_full 10 | device: "0" 11 | batch_size: 8 12 | 13 | 14 | encoder_losses: 15 | l2: 1.0 16 | lpips: 0.8 17 | id: 0.1 18 | adv: 0.01 19 | 20 | 21 | disc_losses: 22 | main: 23 | coef: 1 24 | r1: 25 | coef: 10 26 | 27 | 28 | optimizers: 29 | ranger: 30 | lr: 0.0001 -------------------------------------------------------------------------------- /configs/fse_inference.yaml: -------------------------------------------------------------------------------- 1 | inference: 2 | inference_runner: fse_inference_runner 3 | 4 | editings_data: { 5 | "age": [ 2, 3, 4, 5, 6, 7, 8], 6 | "fs_glasses": [5, 10, 15, 20, 25, 30], 7 | "fs_smiling": [-10, -9, -8, -7, -6, -5, -4,-3], 8 | "styleclip_global_face with hair_face with fire hair_0.10": [5, 9] 9 | } 10 | 11 | 12 | model: 13 | method: fse_full 14 | device: "0" 15 | batch_size: 8 16 | -------------------------------------------------------------------------------- /configs/fse_inverter_inference.yaml: -------------------------------------------------------------------------------- 1 | inference: 2 | inference_runner: fse_inverter_inference_runner 3 | editings_data: { 4 | "age": [-7, -5, -3, 3, 5, 7, 10], 5 | "fs_makeup": [5, 10, 15], 6 | "afro": [0.03, 0.07, 0.085, 0.1], 7 | "angry": [0.07, 0.1, 0.12], 8 | "purple_hair": [0.07, 0.1, 0.12], 9 | "glasses": [-7, 5], 10 | "face_roundness": [-17, -12, -8, 8, 12, 17], 11 | "rotation": [-5.0, -3.0, -1.0, 1.0, 3.0, 5.0], 12 | "bobcut": [0.07, 0.12, 0.18], 13 | "bowlcut": [0.07, 0.14], 14 | "mohawk": [0.07, 0.10], 15 | "blond hair ": [-8, -4, 4, 8], 16 | "fs_smiling": [-9, -6, -3, 3, 6, 9] 17 | } 18 | 19 | 20 | model: 21 | method: fse_inverter 22 | device: "0" 23 | batch_size: 8 24 | -------------------------------------------------------------------------------- /configs/fse_inverter_train.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | train_runner: fse_inverter 3 | train_dis: True 4 | dis_train_start_step: 45000 5 | 6 | 7 | model: 8 | method: fse_inverter 9 | device: "0" 10 | batch_size: 8 11 | 12 | 13 | encoder_losses: 14 | l2: 1.0 15 | lpips: 0.8 16 | id: 0.1 17 | adv: 0.01 18 | feat_rec: 0.01 19 | 20 | 21 | disc_losses: 22 | main: 23 | coef: 1 24 | r1: 25 | coef: 10 26 | 27 | 28 | optimizers: 29 | ranger: 30 | lr: 0.0001 -------------------------------------------------------------------------------- /configs/paths.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict, fields 2 | 3 | 4 | models_dir = "pretrained_models/" 5 | 6 | 7 | @dataclass 8 | class DefaultPathsClass: 9 | psp_path: str = models_dir + "psp_ffhq_encode.pt" 10 | e4e_path: str = models_dir + "e4e_ffhq_encode.pt" 11 | farl_path:str = models_dir + "face_parsing.farl.lapa.main_ema_136500_jit191.pt" 12 | mobile_net_pth: str = models_dir + "mobilenet0.25_Final.pth" 13 | ir_se50_path: str = models_dir + "model_ir_se50.pth" 14 | stylegan_weights: str = models_dir + "stylegan2-ffhq-config-f.pt" 15 | stylegan_car_weights: str = models_dir + 
"stylegan2-car-config-f-new.pkl" 16 | stylegan_weights_pkl: str = models_dir + "stylegan2-ffhq-config-f.pkl" 17 | arcface_model_path: str = models_dir + "iresnet50-7f187506.pth" 18 | moco: str = models_dir + "moco_v2_800ep_pretrain.pt" 19 | curricular_face_path: str = models_dir + "CurricularFace_Backbone.pth" 20 | mtcnn: str = models_dir + "mtcnn" 21 | landmark: str = models_dir + "79999_iter.pth" 22 | 23 | def __iter__(self): 24 | for field in fields(self): 25 | yield field.name, getattr(self, field.name) 26 | 27 | 28 | DefaultPaths = DefaultPathsClass() 29 | -------------------------------------------------------------------------------- /configs/simple_inference.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | domain: human_faces 3 | 4 | model: 5 | method: fse_full 6 | device: "0" 7 | batch_size: 1 8 | workers: 1 9 | 10 | methods_args: 11 | fse_full: 12 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/criteria/__init__.py -------------------------------------------------------------------------------- /criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from configs.paths import DefaultPaths 4 | from models.psp.encoders.model_irse import Backbone 5 | 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self): 9 | super(IDLoss, self).__init__() 10 | print("Loading ResNet ArcFace") 11 | self.facenet = Backbone( 12 | input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se" 13 | ) 14 | self.facenet.load_state_dict(torch.load(DefaultPaths.ir_se50_path)) 15 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 16 | self.facenet = self.facenet.cuda().eval() 17 | 18 | def extract_feats(self, x): 19 | x = x[:, :, 35:223, 32:220] # Crop interesting region 20 | x = self.face_pool(x) 21 | x_feats = self.facenet(x.cuda()) 22 | return x_feats 23 | 24 | 25 | def forward(self, y_hat, y): 26 | n_samples = y.shape[0] 27 | y_feats = self.extract_feats(y) 28 | y_hat_feats = self.extract_feats(y_hat) 29 | 30 | y_feats = y_feats.detach() 31 | loss = 0 32 | count = 0 33 | for i in range(n_samples): 34 | diff_target = y_hat_feats[i].dot(y_feats[i]) 35 | loss += 1 - diff_target 36 | count += 1 37 | 38 | return loss / count 39 | -------------------------------------------------------------------------------- /criteria/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/criteria/lpips/__init__.py -------------------------------------------------------------------------------- /criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 
15 | """ 16 | 17 | def __init__(self, net_type: str = "vgg", version: str = "0.1"): 18 | assert version in ["0.1"], "v0.1 is only supported now" 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from criteria.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == "alex": 14 | return AlexNet() 15 | elif net_type == "squeeze": 16 | return SqueezeNet() 17 | elif net_type == "vgg": 18 | return VGG16() 19 | else: 20 | raise NotImplementedError("choose net_type from [alex, squeeze, vgg].") 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__( 26 | [ 27 | nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False)) 28 | for nc in n_channels_list 29 | ] 30 | ) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | "mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 43 | ) 44 | self.register_buffer( 45 | "std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 46 | ) 47 | 48 | def set_requires_grad(self, state: bool): 49 | for param in chain(self.parameters(), self.buffers()): 50 | param.requires_grad = state 51 | 52 | def z_score(self, x: torch.Tensor): 53 | return (x - self.mean) / self.std 54 | 55 | def forward(self, x: torch.Tensor): 56 | x = self.z_score(x) 57 | 58 | output = [] 59 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 60 | x = layer(x) 61 | if i in self.target_layers: 62 | output.append(normalize_activation(x)) 63 | if len(output) == len(self.target_layers): 64 | break 65 | return output 66 | 67 | 68 | class SqueezeNet(BaseNet): 69 | def __init__(self): 70 | super(SqueezeNet, self).__init__() 71 | 72 | self.layers = models.squeezenet1_1(True).features 73 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 74 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 75 | 76 | self.set_requires_grad(False) 77 | 78 | 79 | class AlexNet(BaseNet): 80 | def __init__(self): 81 | super(AlexNet, self).__init__() 82 | 83 | self.layers = models.alexnet(True).features 84 | self.target_layers = [2, 5, 8, 10, 12] 85 | self.n_channels_list = [64, 192, 384, 256, 256] 86 | 87 | self.set_requires_grad(False) 88 | 89 | 90 | class VGG16(BaseNet): 91 | def __init__(self): 92 | super(VGG16, self).__init__() 93 | 94 | self.layers = models.vgg16(True).features 95 | self.target_layers = [4, 9, 16, 23, 30] 96 | self.n_channels_list = [64, 128, 256, 512, 512] 97 | 98 | 
self.set_requires_grad(False) 99 | -------------------------------------------------------------------------------- /criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = "alex", version: str = "0.1"): 12 | # build url 13 | url = ( 14 | "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/" 15 | + f"master/lpips/weights/v{version}/{net_type}.pth" 16 | ) 17 | 18 | # download 19 | old_state_dict = torch.hub.load_state_dict_from_url( 20 | url, 21 | progress=True, 22 | map_location=None if torch.cuda.is_available() else torch.device("cpu"), 23 | ) 24 | 25 | # rename keys 26 | new_state_dict = OrderedDict() 27 | for key, val in old_state_dict.items(): 28 | new_key = key 29 | new_key = new_key.replace("lin", "") 30 | new_key = new_key.replace("model.", "") 31 | new_state_dict[new_key] = val 32 | 33 | return new_state_dict 34 | -------------------------------------------------------------------------------- /criteria/moco_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from configs.paths import DefaultPaths 5 | 6 | 7 | class MocoLoss(nn.Module): 8 | def __init__(self): 9 | super(MocoLoss, self).__init__() 10 | print("Loading MOCO model from path: {}".format(DefaultPaths.moco)) 11 | self.model = self.__load_model() 12 | self.model.cuda() 13 | self.model.eval() 14 | 15 | @staticmethod 16 | def __load_model(): 17 | import torchvision.models as models 18 | 19 | model = models.__dict__["resnet50"]() 20 | # freeze all layers but the last fc 21 | for name, param in model.named_parameters(): 22 | if name not in ["fc.weight", "fc.bias"]: 23 | param.requires_grad = False 24 | checkpoint = torch.load(DefaultPaths.moco, map_location="cpu") 25 | state_dict = checkpoint["state_dict"] 26 | # rename moco pre-trained keys 27 | for k in list(state_dict.keys()): 28 | # retain only encoder_q up to before the embedding layer 29 | if k.startswith("module.encoder_q") and not k.startswith( 30 | "module.encoder_q.fc" 31 | ): 32 | # remove prefix 33 | state_dict[k[len("module.encoder_q.") :]] = state_dict[k] 34 | # delete renamed or unused k 35 | del state_dict[k] 36 | msg = model.load_state_dict(state_dict, strict=False) 37 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 38 | # remove output layer 39 | model = nn.Sequential(*list(model.children())[:-1]).cuda() 40 | return model 41 | 42 | def extract_feats(self, x): 43 | x = F.interpolate(x, size=224) 44 | x_feats = self.model(x) 45 | x_feats = nn.functional.normalize(x_feats, dim=1) 46 | x_feats = x_feats.squeeze() 47 | return x_feats 48 | 49 | 50 | def forward(self, y_hat, y): 51 | n_samples = y.shape[0] 52 | y_feats = self.extract_feats(y) 53 | y_hat_feats = self.extract_feats(y_hat) 54 | y_feats = y_feats.detach() 55 | loss = 0 56 | count = 0 57 | for i in range(n_samples): 58 | diff_target = y_hat_feats[i].dot(y_feats[i]) 59 | loss += 1 - diff_target 60 | count += 1 61 | 62 | return loss / count 63 | -------------------------------------------------------------------------------- /criteria/ms_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn.functional as F 3 | from math import exp 4 | import numpy as np 5 | 6 | """ 7 | Taken from https://github.com/jorge-pessoa/pytorch-msssim 8 | """ 9 | 10 | 11 | def gaussian(window_size, sigma): 12 | gauss = torch.Tensor( 13 | [ 14 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 15 | for x in range(window_size) 16 | ] 17 | ) 18 | return gauss / gauss.sum() 19 | 20 | 21 | def create_window(window_size, channel=1): 22 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 23 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 24 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 25 | return window 26 | 27 | 28 | def ssim( 29 | img1, 30 | img2, 31 | window_size=11, 32 | window=None, 33 | size_average=True, 34 | full=False, 35 | val_range=None, 36 | ): 37 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 38 | if val_range is None: 39 | if torch.max(img1) > 128: 40 | max_val = 255 41 | else: 42 | max_val = 1 43 | 44 | if torch.min(img1) < -0.5: 45 | min_val = -1 46 | else: 47 | min_val = 0 48 | L = max_val - min_val 49 | else: 50 | L = val_range 51 | 52 | padd = 0 53 | (_, channel, height, width) = img1.size() 54 | if window is None: 55 | real_size = min(window_size, height, width) 56 | window = create_window(real_size, channel=channel).to(img1.device) 57 | 58 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 59 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 60 | 61 | mu1_sq = mu1.pow(2) 62 | mu2_sq = mu2.pow(2) 63 | mu1_mu2 = mu1 * mu2 64 | 65 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 66 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 67 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 68 | 69 | C1 = (0.01 * L) ** 2 70 | C2 = (0.03 * L) ** 2 71 | 72 | v1 = 2.0 * sigma12 + C2 73 | v2 = sigma1_sq + sigma2_sq + C2 74 | cs = v1 / v2 # contrast sensitivity 75 | 76 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 77 | 78 | if size_average: 79 | cs = cs.mean() 80 | ret = ssim_map.mean() 81 | else: 82 | cs = cs.mean(1).mean(1).mean(1) 83 | ret = ssim_map.mean(1).mean(1).mean(1) 84 | 85 | if full: 86 | return ret, cs 87 | return ret 88 | 89 | 90 | def msssim( 91 | img1, img2, window_size=11, size_average=True, val_range=None, normalize=None 92 | ): 93 | device = img1.device 94 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 95 | levels = weights.size()[0] 96 | ssims = [] 97 | mcs = [] 98 | for _ in range(levels): 99 | sim, cs = ssim( 100 | img1, 101 | img2, 102 | window_size=window_size, 103 | size_average=size_average, 104 | full=True, 105 | val_range=val_range, 106 | ) 107 | 108 | # Relu normalize (not compliant with original definition) 109 | if normalize == "relu": 110 | ssims.append(torch.relu(sim)) 111 | mcs.append(torch.relu(cs)) 112 | else: 113 | ssims.append(sim) 114 | mcs.append(cs) 115 | 116 | img1 = F.avg_pool2d(img1, (2, 2)) 117 | img2 = F.avg_pool2d(img2, (2, 2)) 118 | 119 | ssims = torch.stack(ssims) 120 | mcs = torch.stack(mcs) 121 | 122 | # Simple normalize (not compliant with original definition) 123 | if normalize == "simple" or normalize == True: 124 | ssims = (ssims + 1) / 2 125 | mcs = (mcs + 1) / 2 126 | 127 | pow1 = mcs**weights 128 | pow2 = ssims**weights 129 | 130 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 131 | output = 
torch.prod(pow1[:-1]) * pow2[-1] 132 | return output 133 | 134 | 135 | # Classes to re-use window 136 | class SSIM(torch.nn.Module): 137 | def __init__(self, window_size=11, size_average=True, val_range=None): 138 | super(SSIM, self).__init__() 139 | self.window_size = window_size 140 | self.size_average = size_average 141 | self.val_range = val_range 142 | 143 | # Assume 1 channel for SSIM 144 | self.channel = 1 145 | self.window = create_window(window_size) 146 | 147 | def forward(self, img1, img2): 148 | (_, channel, _, _) = img1.size() 149 | 150 | if channel == self.channel and self.window.dtype == img1.dtype: 151 | window = self.window 152 | else: 153 | window = ( 154 | create_window(self.window_size, channel) 155 | .to(img1.device) 156 | .type(img1.dtype) 157 | ) 158 | self.window = window 159 | self.channel = channel 160 | 161 | return ssim( 162 | img1, 163 | img2, 164 | window=window, 165 | window_size=self.window_size, 166 | size_average=self.size_average, 167 | ) 168 | 169 | 170 | class MSSSIM(torch.nn.Module): 171 | def __init__(self, window_size=11, size_average=True, channel=3): 172 | super(MSSSIM, self).__init__() 173 | self.window_size = window_size 174 | self.size_average = size_average 175 | self.channel = channel 176 | 177 | def forward(self, img1, img2): 178 | return msssim( 179 | img1, img2, window_size=self.window_size, size_average=self.size_average 180 | ) 181 | -------------------------------------------------------------------------------- /criteria/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as modelzoo 5 | 6 | import os 7 | 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = 
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | #state_dict = modelzoo.load_url(resnet18_url) 84 | state_dict = torch.load('pretrained_models/resnet18-5c106cde.pth') 85 | self_state_dict = self.state_dict() 86 | for k, v in state_dict.items(): 87 | if 'fc' in k: continue 88 | self_state_dict.update({k: v}) 89 | self.load_state_dict(self_state_dict) 90 | 91 | def get_params(self): 92 | wd_params, nowd_params = [], [] 93 | for name, module in self.named_modules(): 94 | if isinstance(module, (nn.Linear, nn.Conv2d)): 95 | wd_params.append(module.weight) 96 | if not module.bias is None: 97 | nowd_params.append(module.bias) 98 | elif isinstance(module, nn.BatchNorm2d): 99 | nowd_params += list(module.parameters()) 100 | return wd_params, nowd_params 101 | -------------------------------------------------------------------------------- /criteria/w_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class WNormLoss(nn.Module): 6 | def __init__(self, start_from_latent_avg=True): 7 | super(WNormLoss, self).__init__() 8 | self.start_from_latent_avg = start_from_latent_avg 9 | 10 | def forward(self, latent, latent_avg=None): 11 | if self.start_from_latent_avg: 12 | latent = latent - latent_avg 13 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] 14 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | from utils import data_utils 5 | from torchvision import transforms 6 | 7 | 8 | class ImageDataset(Dataset): 9 | def __init__(self, root, transform=None): 10 | self.paths = sorted(data_utils.make_dataset(root)) 11 | self.transform = transform 12 | 13 | def __len__(self): 14 | return len(self.paths) 15 | 16 | def __getitem__(self, index): 17 | path = self.paths[index] 18 | image = Image.open(path).convert("RGB") 19 | 20 | if self.transform: 21 | image = self.transform(image) 22 | return image 23 | 24 | 25 | class CelebaAttributeDataset(Dataset): 26 | def __init__(self, images_root, attr, transform=None, attributes_root="", use_attr=True): 27 | self.paths = data_utils.make_dataset(images_root) 28 | self.transform = transform 29 | with open(attributes_root, "r") as f: 30 | lines = f.readlines() 31 | 32 | attr_num = -1 33 | for i, data_attr in enumerate(lines[1].split(" ")): 34 | if data_attr.strip() == attr.strip(): 35 | attr_num = i 36 | break 37 | assert attr_num > -1, f"Can not find attribute {attr}" 38 | 39 | filtred_paths = [] 40 | for path in self.paths: 41 | pic_num = int(path.split("/")[-1].replace(".jpg", "").replace(".png", "")) 42 | pic_attrs = lines[pic_num + 2].strip().split(" ") 43 | pic_attrs = pic_attrs[2:] 44 | if use_attr and pic_attrs[attr_num] == "1" or not use_attr and 
pic_attrs[attr_num] == "-1": 45 | filtred_paths.append(path) 46 | self.paths = sorted(filtred_paths) 47 | 48 | def __len__(self): 49 | return len(self.paths) 50 | 51 | def __getitem__(self, index): 52 | from_path = self.paths[index] 53 | image = Image.open(from_path).convert("RGB") 54 | 55 | if self.transform: 56 | image = self.transform(image) 57 | return image 58 | 59 | 60 | class FIDDataset(Dataset): 61 | def __init__(self, files, transforms=None): 62 | self.files = files 63 | self.transforms = transforms 64 | 65 | def __len__(self): 66 | return len(self.files) 67 | 68 | def __getitem__(self, i): 69 | file = self.files[i] 70 | image = file.convert("RGB") 71 | 72 | if self.transforms is not None: 73 | image = self.transforms(image) 74 | 75 | return image 76 | 77 | 78 | class MetricsPathsDataset(Dataset): 79 | def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None, return_path=False, ignore=[]): 80 | self.pairs = [] 81 | self.paths = [] 82 | self.names = [] 83 | 84 | for f in os.listdir(root_path): 85 | if f not in ignore: 86 | self.names.append(f) 87 | image_path = os.path.join(root_path, f) 88 | gt_path = os.path.join(gt_dir, f) 89 | if f.endswith(".jpg") or f.endswith(".png"): 90 | self.pairs.append([image_path, gt_path.replace(".png", ".jpg"), None]) 91 | self.paths.append(image_path) 92 | self.transform = transform 93 | self.transform_train = transform_train 94 | self.return_path = return_path 95 | 96 | def __len__(self): 97 | return len(self.pairs) 98 | 99 | def __getitem__(self, index): 100 | from_path, to_path, _ = self.pairs[index] 101 | from_im = Image.open(from_path).convert("RGB") 102 | to_im = Image.open(to_path).convert("RGB") 103 | 104 | if self.transform: 105 | to_im = self.transform(to_im) 106 | from_im = self.transform(from_im) 107 | 108 | if not self.return_path: 109 | return from_im, to_im 110 | else: 111 | return from_im, to_im, self.names[index] 112 | 113 | 114 | class MetricsDataDataset(Dataset): 115 | def __init__( 116 | self, paths, target_data, fake_data, transform=None, transform_train=None 117 | ): 118 | self.fake_data = fake_data 119 | self.target_data = target_data 120 | self.paths = paths 121 | self.transform = transform 122 | self.transform_train = transform_train 123 | 124 | def __len__(self): 125 | return len(self.fake_data) 126 | 127 | def __getitem__(self, index): 128 | 129 | target_im = self.target_data[index] 130 | fake_im = self.fake_data[index] 131 | 132 | if self.transform: 133 | fake_im = self.transform(fake_im) 134 | target_im = self.transform(target_im) 135 | 136 | return target_im, fake_im 137 | -------------------------------------------------------------------------------- /datasets/loaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | 4 | class InfiniteLoader(DataLoader): 5 | def __init__( 6 | self, 7 | *args, 8 | num_workers=0, 9 | pin_memory=True, 10 | is_infinite = True, 11 | **kwargs, 12 | ): 13 | super().__init__( 14 | *args, 15 | multiprocessing_context="fork" if num_workers > 0 else None, 16 | num_workers=num_workers, 17 | pin_memory=pin_memory, 18 | **kwargs, 19 | ) 20 | self.dataset_iterator = super().__iter__() 21 | self.is_infinite = is_infinite 22 | 23 | def __iter__(self): 24 | return self 25 | 26 | def __next__(self): 27 | try: 28 | x = next(self.dataset_iterator) 29 | except StopIteration: 30 | self.dataset_iterator = super().__iter__() 31 | if self.is_infinite: 32 | x = next(self.dataset_iterator) 33 | else: 34 
| raise StopIteration 35 | 36 | return x 37 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torchvision.transforms as transforms 3 | from utils.class_registry import ClassRegistry 4 | 5 | transforms_registry = ClassRegistry() 6 | 7 | 8 | class TransformsConfig(object): 9 | def __init__(self): 10 | pass 11 | 12 | @abstractmethod 13 | def get_transforms(self): 14 | pass 15 | 16 | class FaceTransforms(TransformsConfig): 17 | def __init__(self): 18 | super(FaceTransforms, self).__init__() 19 | self.image_size = None 20 | 21 | def get_transforms(self): 22 | transforms_dict = { 23 | "train": transforms.Compose( 24 | [ 25 | transforms.Resize(self.image_size), 26 | transforms.RandomHorizontalFlip(0.5), 27 | transforms.ToTensor(), 28 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 29 | ] 30 | ), 31 | "test": transforms.Compose( 32 | [ 33 | transforms.Resize(self.image_size), 34 | transforms.ToTensor(), 35 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 36 | ] 37 | ) 38 | } 39 | return transforms_dict 40 | 41 | 42 | @transforms_registry.add_to_registry(name="face_256") 43 | class Face256Transforms(FaceTransforms): 44 | def __init__(self): 45 | super(Face256Transforms, self).__init__() 46 | self.image_size = (256, 256) 47 | 48 | 49 | @transforms_registry.add_to_registry(name="face_1024") 50 | class Face1024Transforms(FaceTransforms): 51 | def __init__(self): 52 | super(Face1024Transforms, self).__init__() 53 | self.image_size = (1024, 1024) 54 | 55 | 56 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | 11 | 12 | __all__ = ["EasyDict", "make_cache_dir_path"] 13 | -------------------------------------------------------------------------------- /dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | from . import autosummary 8 | from . import network 9 | from . import optimizer 10 | from . import tfutil 11 | from . import custom_ops 12 | 13 | from .tfutil import * 14 | from .network import Network 15 | 16 | from .optimizer import Optimizer 17 | 18 | from .custom_ops import get_plugin 19 | -------------------------------------------------------------------------------- /dnnlib/tflib/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | # empty 8 | -------------------------------------------------------------------------------- /editings/bound/Eyeglasses_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/bound/Eyeglasses_boundary.npy -------------------------------------------------------------------------------- /editings/bound/Heavy_Makeup_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/bound/Heavy_Makeup_boundary.npy -------------------------------------------------------------------------------- /editings/bound/Smiling_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/bound/Smiling_boundary.npy -------------------------------------------------------------------------------- /editings/deltaedit/delta_mapper.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import Module 6 | import torch.nn.functional as F 7 | 8 | from models.psp.stylegan2.model import EqualLinear, PixelNorm 9 | 10 | class Mapper(Module): 11 | 12 | def __init__(self, in_channel=512, out_channel=512, norm=True, num_layers=4): 13 | super(Mapper, self).__init__() 14 | 15 | layers = [PixelNorm()] if norm else [] 16 | 17 | layers.append(EqualLinear(in_channel, out_channel, lr_mul=0.01, activation='fused_lrelu')) 18 | for _ in range(num_layers-1): 19 | layers.append(EqualLinear(out_channel, out_channel, lr_mul=0.01, activation='fused_lrelu')) 20 | self.mapping = nn.Sequential(*layers) 21 | 22 | def forward(self, x): 23 | x = self.mapping(x) 24 | return x 25 | 26 | class DeltaMapper(Module): 27 | 28 | def __init__(self): 29 | super(DeltaMapper, self).__init__() 30 | 31 | #Style Module(sm) 32 | self.sm_coarse = Mapper(512, 512) 33 | self.sm_medium = Mapper(512, 512) 34 | self.sm_fine = Mapper(2464, 2464) 35 | 36 | #Condition Module(cm) 37 | self.cm_coarse = Mapper(1024, 512) 38 | self.cm_medium = Mapper(1024, 512) 39 | self.cm_fine = Mapper(1024, 2464) 40 | 41 | #Fusion Module(fm) 42 | self.fm_coarse = Mapper(512*2, 512, norm=False) 43 | self.fm_medium = Mapper(512*2, 512, norm=False) 44 | self.fm_fine = Mapper(2464*2, 2464, norm=False) 45 | 46 | def forward(self, sspace_feat, clip_feat): 47 | 48 | s_coarse = sspace_feat[:, :3*512].view(-1,3,512) 49 | s_medium = sspace_feat[:, 3*512:7*512].view(-1,4,512) 50 | s_fine = sspace_feat[:, 7*512:] #channels:2464 51 | 52 | s_coarse = self.sm_coarse(s_coarse) 53 | s_medium = self.sm_medium(s_medium) 54 | s_fine = self.sm_fine(s_fine) 55 | 56 | c_coarse = self.cm_coarse(clip_feat) 57 | c_medium = self.cm_medium(clip_feat) 58 | c_fine = self.cm_fine(clip_feat) 59 | 60 | x_coarse = torch.cat([s_coarse, torch.stack([c_coarse]*3, dim=1)], dim=2) #[b,3,1024] 61 | x_medium = torch.cat([s_medium, torch.stack([c_medium]*4, dim=1)], dim=2) #[b,4,1024] 62 | x_fine = torch.cat([s_fine, c_fine], dim=1) #[b,2464*2] 63 | 64 | x_coarse = self.fm_coarse(x_coarse) 65 | x_coarse = x_coarse.view(-1,3*512) 66 | 67 | x_medium = 
self.fm_medium(x_medium) 68 | x_medium = x_medium.view(-1,4*512) 69 | 70 | x_fine = self.fm_fine(x_fine) 71 | 72 | out = torch.cat([x_coarse, x_medium, x_fine], dim=1) 73 | return out -------------------------------------------------------------------------------- /editings/deltaedit/editor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | import copy 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | from editings.deltaedit import map_tool 8 | from editings.deltaedit.delta_mapper import DeltaMapper 9 | 10 | 11 | STYLE_DIM = [512] * 10 + [256, 256, 128, 128, 64, 64, 32] 12 | 13 | 14 | def GetBoundary(fs3, dt, threshold): 15 | tmp = np.dot(fs3, dt) 16 | select = np.abs(tmp) < threshold 17 | return select 18 | 19 | def improved_ds(ds, select): 20 | ds_imp = copy.copy(ds) 21 | ds_imp[select] = 0 22 | ds_imp = ds_imp.unsqueeze(0) 23 | return ds_imp 24 | 25 | 26 | class DeltaEditor: 27 | def __init__(self): 28 | device = "cuda" 29 | self.fs3 = np.load("pretrained_models/fs3.npy") 30 | np.set_printoptions(suppress=True) 31 | 32 | self.net = DeltaMapper() 33 | net_ckpt = torch.load("pretrained_models/delta_mapper.pt") 34 | self.net.load_state_dict(net_ckpt) 35 | self.net = self.net.to(device).eval() 36 | 37 | self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device) 38 | self.avg_pool = torch.nn.AdaptiveAvgPool2d((224, 224)) 39 | self.upsample = torch.nn.Upsample(scale_factor=7) 40 | 41 | def get_delta_s(self, neutral, target, trash, orig_image, start_s): 42 | with torch.no_grad(): 43 | classnames = [target, neutral] 44 | dt = map_tool.GetDt(classnames, self.clip_model) 45 | select = GetBoundary(self.fs3, dt, trash) 46 | dt = torch.Tensor(dt).cuda() 47 | dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5) 48 | 49 | img_gen_for_clip = self.avg_pool(orig_image) 50 | c_latents = self.clip_model.encode_image(img_gen_for_clip.cuda()) 51 | c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float() 52 | 53 | delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1) 54 | fake_delta_s = self.net(torch.cat(start_s, dim=-1), delta_c) 55 | improved_fake_delta_s = improved_ds(fake_delta_s[0], select) 56 | return torch.split(improved_fake_delta_s, STYLE_DIM, dim=-1) 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /editings/deltaedit/map_tool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | import os 4 | import numpy as np 5 | 6 | imagenet_templates = [ 7 | 'a bad photo of a {}.', 8 | # 'a photo of many {}.', 9 | 'a sculpture of a {}.', 10 | 'a photo of the hard to see {}.', 11 | 'a low resolution photo of the {}.', 12 | 'a rendering of a {}.', 13 | 'graffiti of a {}.', 14 | 'a bad photo of the {}.', 15 | 'a cropped photo of the {}.', 16 | 'a tattoo of a {}.', 17 | 'the embroidered {}.', 18 | 'a photo of a hard to see {}.', 19 | 'a bright photo of a {}.', 20 | 'a photo of a clean {}.', 21 | 'a photo of a dirty {}.', 22 | 'a dark photo of the {}.', 23 | 'a drawing of a {}.', 24 | 'a photo of my {}.', 25 | 'the plastic {}.', 26 | 'a photo of the cool {}.', 27 | 'a close-up photo of a {}.', 28 | 'a black and white photo of the {}.', 29 | 'a painting of the {}.', 30 | 'a painting of a {}.', 31 | 'a pixelated photo of the {}.', 32 | 'a sculpture of the {}.', 33 | 'a bright photo of the {}.', 34 | 'a cropped photo of a {}.', 35 | 'a plastic {}.', 36 | 'a photo of the dirty 
{}.', 37 | 'a jpeg corrupted photo of a {}.', 38 | 'a blurry photo of the {}.', 39 | 'a photo of the {}.', 40 | 'a good photo of the {}.', 41 | 'a rendering of the {}.', 42 | 'a {} in a video game.', 43 | 'a photo of one {}.', 44 | 'a doodle of a {}.', 45 | 'a close-up photo of the {}.', 46 | 'a photo of a {}.', 47 | 'the origami {}.', 48 | 'the {} in a video game.', 49 | 'a sketch of a {}.', 50 | 'a doodle of the {}.', 51 | 'a origami {}.', 52 | 'a low resolution photo of a {}.', 53 | 'the toy {}.', 54 | 'a rendition of the {}.', 55 | 'a photo of the clean {}.', 56 | 'a photo of a large {}.', 57 | 'a rendition of a {}.', 58 | 'a photo of a nice {}.', 59 | 'a photo of a weird {}.', 60 | 'a blurry photo of a {}.', 61 | 'a cartoon {}.', 62 | 'art of a {}.', 63 | 'a sketch of the {}.', 64 | 'a embroidered {}.', 65 | 'a pixelated photo of a {}.', 66 | 'itap of the {}.', 67 | 'a jpeg corrupted photo of the {}.', 68 | 'a good photo of a {}.', 69 | 'a plushie {}.', 70 | 'a photo of the nice {}.', 71 | 'a photo of the small {}.', 72 | 'a photo of the weird {}.', 73 | 'the cartoon {}.', 74 | 'art of the {}.', 75 | 'a drawing of the {}.', 76 | 'a photo of the large {}.', 77 | 'a black and white photo of a {}.', 78 | 'the plushie {}.', 79 | 'a dark photo of a {}.', 80 | 'itap of a {}.', 81 | 'graffiti of the {}.', 82 | 'a toy {}.', 83 | 'itap of my {}.', 84 | 'a photo of a cool {}.', 85 | 'a photo of a small {}.', 86 | 'a tattoo of the {}.', 87 | ] 88 | 89 | def zeroshot_classifier(classnames, templates,model): 90 | with torch.no_grad(): 91 | zeroshot_weights = [] 92 | for classname in classnames: 93 | texts = [template.format(classname) for template in templates] #format with class 94 | texts = clip.tokenize(texts).cuda() #tokenize 95 | class_embeddings = model.encode_text(texts) #embed with text encoder 96 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 97 | class_embedding = class_embeddings.mean(dim=0) 98 | class_embedding /= class_embedding.norm() 99 | zeroshot_weights.append(class_embedding) 100 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() 101 | return zeroshot_weights 102 | 103 | def GetDt(classnames,model): 104 | text_features=zeroshot_classifier(classnames, imagenet_templates,model).t() 105 | 106 | dt=text_features[0]-text_features[1] 107 | dt=dt.cpu().numpy() 108 | 109 | return dt 110 | -------------------------------------------------------------------------------- /editings/ganspace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def edit(latents, pca, edit_directions): 5 | edit_latents = [] 6 | for latent in latents: 7 | for pca_idx, start, end, strength in edit_directions: 8 | delta = get_delta(pca, latent, pca_idx, strength) 9 | delta_padded = torch.zeros(latent.shape).to("cuda") 10 | delta_padded[start:end] += delta.repeat(end - start, 1) 11 | edit_latents.append(latent + delta_padded) 12 | return torch.stack(edit_latents) 13 | 14 | 15 | def get_delta(pca, latent, idx, strength): 16 | w_centered = latent - pca["mean"].to("cuda") 17 | lat_comp = pca["comp"].to("cuda") 18 | lat_std = pca["std"].to("cuda") 19 | w_coord = ( 20 | torch.sum(w_centered[0].reshape(-1) * lat_comp[idx].reshape(-1)) / lat_std[idx] 21 | ) 22 | delta = (strength - w_coord) * lat_comp[idx] * lat_std[idx] 23 | return delta 24 | 25 | 26 | def edit_latent(latent, pca, edit_direction): 27 | pca_idx, start, end, strength = edit_direction 28 | delta = get_delta(pca, latent, pca_idx, strength) 29 | delta_padded = 
torch.zeros(latent.shape).to("cuda") 30 | delta_padded[start:end] += delta.repeat(end - start, 1) 31 | edit_latent = latent + delta_padded 32 | return edit_latent 33 | -------------------------------------------------------------------------------- /editings/ganspace_pca/cars_pca.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/ganspace_pca/cars_pca.pt -------------------------------------------------------------------------------- /editings/ganspace_pca/church_pca.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/ganspace_pca/church_pca.pt -------------------------------------------------------------------------------- /editings/ganspace_pca/ffhq_pca.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/ganspace_pca/ffhq_pca.pt -------------------------------------------------------------------------------- /editings/interfacegan_directions/age.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/interfacegan_directions/age.pt -------------------------------------------------------------------------------- /editings/interfacegan_directions/rotation.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/interfacegan_directions/rotation.pt -------------------------------------------------------------------------------- /editings/interfacegan_directions/smile.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/interfacegan_directions/smile.pt -------------------------------------------------------------------------------- /editings/styleclip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/styleclip/__init__.py -------------------------------------------------------------------------------- /editings/styleclip/global_mapper_data/S_mean_std: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/styleclip/global_mapper_data/S_mean_std -------------------------------------------------------------------------------- /editings/styleclip/global_mapper_data/delta_i_c.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/styleclip/global_mapper_data/delta_i_c.npy -------------------------------------------------------------------------------- /editings/styleclip/global_mapper_data/templates.txt: 
-------------------------------------------------------------------------------- 1 | a bad photo of a {}. 2 | a sculpture of a {}. 3 | a photo of the hard to see {}. 4 | a low resolution photo of the {}. 5 | a rendering of a {}. 6 | graffiti of a {}. 7 | a bad photo of the {}. 8 | a cropped photo of the {}. 9 | a tattoo of a {}. 10 | the embroidered {}. 11 | a photo of a hard to see {}. 12 | a bright photo of a {}. 13 | a photo of a clean {}. 14 | a photo of a dirty {}. 15 | a dark photo of the {}. 16 | a drawing of a {}. 17 | a photo of my {}. 18 | the plastic {}. 19 | a photo of the cool {}. 20 | a close-up photo of a {}. 21 | a black and white photo of the {}. 22 | a painting of the {}. 23 | a painting of a {}. 24 | a pixelated photo of the {}. 25 | a sculpture of the {}. 26 | a bright photo of the {}. 27 | a cropped photo of a {}. 28 | a plastic {}. 29 | a photo of the dirty {}. 30 | a jpeg corrupted photo of a {}. 31 | a blurry photo of the {}. 32 | a photo of the {}. 33 | a good photo of the {}. 34 | a rendering of the {}. 35 | a {} in a video game. 36 | a photo of one {}. 37 | a doodle of a {}. 38 | a close-up photo of the {}. 39 | a photo of a {}. 40 | the origami {}. 41 | the {} in a video game. 42 | a sketch of a {}. 43 | a doodle of the {}. 44 | a origami {}. 45 | a low resolution photo of a {}. 46 | the toy {}. 47 | a rendition of the {}. 48 | a photo of the clean {}. 49 | a photo of a large {}. 50 | a rendition of a {}. 51 | a photo of a nice {}. 52 | a photo of a weird {}. 53 | a blurry photo of a {}. 54 | a cartoon {}. 55 | art of a {}. 56 | a sketch of the {}. 57 | a embroidered {}. 58 | a pixelated photo of a {}. 59 | itap of the {}. 60 | a jpeg corrupted photo of the {}. 61 | a good photo of a {}. 62 | a plushie {}. 63 | a photo of the nice {}. 64 | a photo of the small {}. 65 | a photo of the weird {}. 66 | the cartoon {}. 67 | art of the {}. 68 | a drawing of the {}. 69 | a photo of the large {}. 70 | a black and white photo of a {}. 71 | the plushie {}. 72 | a dark photo of a {}. 73 | itap of a {}. 74 | graffiti of the {}. 75 | a toy {}. 76 | itap of my {}. 77 | a photo of a cool {}. 78 | a photo of a small {}. 79 | a tattoo of the {}. 
-------------------------------------------------------------------------------- /editings/styleclip/mapper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/styleclip/mapper/__init__.py -------------------------------------------------------------------------------- /editings/styleclip/mapper/gloabl_mapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | import copy 4 | 5 | """ 6 | Modified from HyperStyle repository 7 | https://github.com/yuval-alaluf/hyperstyle/blob/main/editing/styleclip/global_direction.py 8 | """ 9 | STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128, 128] + [64, 64, 64] + [32, 32] 10 | 11 | TORGB_INDICES = list(range(1, len(STYLESPACE_DIMENSIONS), 3)) 12 | STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in TORGB_INDICES][:11] 13 | 14 | def features_channels_to_s(s_without_torgb, s_std): 15 | s = [] 16 | start_index_features = 0 17 | for c in range(len(STYLESPACE_DIMENSIONS)): 18 | if c in STYLESPACE_INDICES_WITHOUT_TORGB: 19 | end_index_features = start_index_features + STYLESPACE_DIMENSIONS[c] 20 | s_i = s_without_torgb[start_index_features:end_index_features] * s_std[c] 21 | start_index_features = end_index_features 22 | else: 23 | s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).cuda() 24 | s_i = s_i.view(1, 1, -1, 1, 1) 25 | s.append(s_i) 26 | return s 27 | 28 | class StyleCLIPGlobalDirection: 29 | 30 | def __init__(self, delta_i_c, s_std, text_prompts_templates): 31 | super(StyleCLIPGlobalDirection, self).__init__() 32 | self.delta_i_c = delta_i_c 33 | self.s_std = s_std 34 | self.text_prompts_templates = text_prompts_templates 35 | self.clip_model, _ = clip.load("ViT-B/32", device="cuda") 36 | 37 | def get_delta_s(self, neutral_text, target_text, beta): 38 | delta_i = self.get_delta_i([target_text, neutral_text]).float() 39 | r_c = torch.matmul(self.delta_i_c, delta_i) 40 | delta_s = copy.copy(r_c) 41 | channels_to_zero = torch.abs(r_c) < beta 42 | delta_s[channels_to_zero] = 0 43 | max_channel_value = torch.abs(delta_s).max() 44 | if max_channel_value > 0: 45 | delta_s /= max_channel_value 46 | direction = features_channels_to_s(delta_s, self.s_std) 47 | return direction 48 | 49 | def get_delta_i(self, text_prompts): 50 | try: # Check if loaded 51 | delta_i = getattr(self, f"{text_prompts[0]}_{text_prompts[1]}") 52 | except: 53 | text_features = self._get_averaged_text_features(text_prompts) 54 | delta_t = text_features[0] - text_features[1] 55 | delta_i = delta_t / torch.norm(delta_t) 56 | setattr(self, f"{text_prompts[0]}_{text_prompts[1]}", delta_i) 57 | return delta_i 58 | 59 | def _get_averaged_text_features(self, text_prompts): 60 | with torch.no_grad(): 61 | text_features_list = [] 62 | for text_prompt in text_prompts: 63 | formatted_text_prompts = [template.format(text_prompt) for template in self.text_prompts_templates] # format with class 64 | formatted_text_prompts = clip.tokenize(formatted_text_prompts).cuda() # tokenize 65 | text_embeddings = self.clip_model.encode_text(formatted_text_prompts) # embed with text encoder 66 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 67 | text_embedding = text_embeddings.mean(dim=0) 68 | text_embedding /= text_embedding.norm() 69 | text_features_list.append(text_embedding) 70 | text_features = 
torch.stack(text_features_list, dim=1).cuda() 71 | return text_features.t() -------------------------------------------------------------------------------- /editings/styleclip/mapper/latent_mappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from editings.styleclip.models.stylegan2.model import EqualLinear, PixelNorm 3 | from torch import nn 4 | from torch.nn import Module 5 | 6 | 7 | class Mapper(Module): 8 | def __init__(self, opts): 9 | super(Mapper, self).__init__() 10 | 11 | self.opts = opts 12 | layers = [PixelNorm()] 13 | 14 | for i in range(4): 15 | layers.append(EqualLinear(512, 512, lr_mul=0.01, activation="fused_lrelu")) 16 | 17 | self.mapping = nn.Sequential(*layers) 18 | 19 | def forward(self, x): 20 | x = self.mapping(x) 21 | return x 22 | 23 | 24 | class SingleMapper(Module): 25 | def __init__(self, opts): 26 | super(SingleMapper, self).__init__() 27 | 28 | self.opts = opts 29 | 30 | self.mapping = Mapper(opts) 31 | 32 | def forward(self, x): 33 | out = self.mapping(x) 34 | return out 35 | 36 | 37 | class LevelsMapper(Module): 38 | def __init__(self, opts): 39 | super(LevelsMapper, self).__init__() 40 | 41 | self.opts = opts 42 | 43 | if not opts.no_coarse_mapper: 44 | self.course_mapping = Mapper(opts) 45 | if not opts.no_medium_mapper: 46 | self.medium_mapping = Mapper(opts) 47 | if not opts.no_fine_mapper: 48 | self.fine_mapping = Mapper(opts) 49 | 50 | def forward(self, x): 51 | x_coarse = x[:, :4, :] 52 | x_medium = x[:, 4:8, :] 53 | x_fine = x[:, 8:, :] 54 | 55 | if not self.opts.no_coarse_mapper: 56 | x_coarse = self.course_mapping(x_coarse) 57 | else: 58 | x_coarse = torch.zeros_like(x_coarse) 59 | if not self.opts.no_medium_mapper: 60 | x_medium = self.medium_mapping(x_medium) 61 | else: 62 | x_medium = torch.zeros_like(x_medium) 63 | if not self.opts.no_fine_mapper: 64 | x_fine = self.fine_mapping(x_fine) 65 | else: 66 | x_fine = torch.zeros_like(x_fine) 67 | 68 | out = torch.cat([x_coarse, x_medium, x_fine], dim=1) 69 | 70 | return out 71 | -------------------------------------------------------------------------------- /editings/styleclip/mapper/styleclip_mapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from editings.styleclip.mapper import latent_mappers 3 | from torch import nn 4 | 5 | 6 | def get_keys(d, name): 7 | if "state_dict" in d: 8 | d = d["state_dict"] 9 | d_filt = {k[len(name) + 1 :]: v for k, v in d.items() if k[: len(name)] == name} 10 | return d_filt 11 | 12 | 13 | class StyleCLIPMapper(nn.Module): 14 | def __init__(self, opts): 15 | super(StyleCLIPMapper, self).__init__() 16 | self.opts = opts 17 | # Define architecture 18 | self.mapper = self.set_mapper() 19 | 20 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 21 | 22 | # Load weights if needed 23 | self.load_weights() 24 | 25 | def set_mapper(self): 26 | if self.opts.mapper_type == "SingleMapper": 27 | mapper = latent_mappers.SingleMapper(self.opts) 28 | elif self.opts.mapper_type == "LevelsMapper": 29 | mapper = latent_mappers.LevelsMapper(self.opts) 30 | else: 31 | raise Exception("{} is not a valid mapper".format(self.opts.mapper_type)) 32 | return mapper 33 | 34 | def load_weights(self): 35 | if self.opts.checkpoint_path is not None: 36 | ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu") 37 | self.mapper.load_state_dict(get_keys(ckpt, "mapper"), strict=True) 38 | -------------------------------------------------------------------------------- 
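The StyleCLIP global-direction pieces above (features_channels_to_s, StyleCLIPGlobalDirection, and the mapper classes) are only definitions. The sketch below shows one plausible way to wire them to the bundled global_mapper_data files; it is not part of the repository. The file paths match the tree, but the exact on-disk layout of S_mean_std (assumed here to torch-load as a (mean, std) pair), the prompt strings, and the beta value are illustrative assumptions, and the code needs a CUDA device plus the clip package, just like the classes themselves.

import numpy as np
import torch

# Filename spelling ("gloabl_mapper") is as in the repository.
from editings.styleclip.mapper.gloabl_mapper import StyleCLIPGlobalDirection

DATA_DIR = "editings/styleclip/global_mapper_data"

# Relevance matrix between CLIP embedding channels and StyleSpace channels.
delta_i_c = torch.from_numpy(np.load(f"{DATA_DIR}/delta_i_c.npy")).float().cuda()

# Assumption: S_mean_std deserializes to a (mean, std) pair; only std is needed here.
_, s_std = torch.load(f"{DATA_DIR}/S_mean_std")

with open(f"{DATA_DIR}/templates.txt") as f:
    templates = [line.strip() for line in f if line.strip()]

direction_model = StyleCLIPGlobalDirection(delta_i_c, s_std, templates)

# beta thresholds |r_c|: a larger beta zeroes more StyleSpace channels, giving a sparser edit.
# The neutral/target prompts and beta=0.11 are illustrative values only.
delta_s = direction_model.get_delta_s("face", "face with glasses", beta=0.11)
# delta_s is a list of per-layer StyleSpace offsets, each shaped (1, 1, C, 1, 1).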
/editings/styleclip/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/styleclip/models/__init__.py -------------------------------------------------------------------------------- /editings/styleclip/models/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/editings/styleclip/models/stylegan2/__init__.py -------------------------------------------------------------------------------- /editings/styleclip/models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | 4 | 5 | __all__ = ["FusedLeakyReLU", "fused_leaky_relu", "upfirdn2d"] 6 | -------------------------------------------------------------------------------- /editings/styleclip/models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | module_path = os.path.dirname(__file__) 9 | 10 | 11 | class FusedLeakyReLU(nn.Module): 12 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 13 | super().__init__() 14 | 15 | self.bias = nn.Parameter(torch.zeros(channel)) 16 | self.negative_slope = negative_slope 17 | self.scale = scale 18 | 19 | def forward(self, input): 20 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 21 | 22 | 23 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 24 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 25 | input = input.cuda() 26 | if input.ndim == 3: 27 | return ( 28 | F.leaky_relu( 29 | input + bias.view(1, *rest_dim, bias.shape[0]), 30 | negative_slope=negative_slope, 31 | ) 32 | * scale 33 | ) 34 | else: 35 | return ( 36 | F.leaky_relu( 37 | input + bias.view(1, bias.shape[0], *rest_dim), 38 | negative_slope=negative_slope, 39 | ) 40 | * scale 41 | ) 42 | -------------------------------------------------------------------------------- /editings/styleclip/models/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | module_path = os.path.dirname(__file__) 8 | 9 | 10 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 11 | out = upfirdn2d_native( 12 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 13 | ) 14 | 15 | return out 16 | 17 | 18 | def upfirdn2d_native( 19 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 20 | ): 21 | _, channel, in_h, in_w = input.shape 22 | input = input.reshape(-1, in_h, in_w, 1) 23 | 24 | _, in_h, in_w, minor = input.shape 25 | kernel_h, kernel_w = kernel.shape 26 | 27 | out = input.view(-1, in_h, 1, in_w, 1, minor) 28 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 29 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 30 | 31 | out = F.pad( 32 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 33 | ) 34 | out = out[ 35 | :, 36 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 37 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 38 | :, 39 | ] 40 | 41 
| out = out.permute(0, 3, 1, 2) 42 | out = out.reshape( 43 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 44 | ) 45 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 46 | out = F.conv2d(out, w) 47 | out = out.reshape( 48 | -1, 49 | minor, 50 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 51 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 52 | ) 53 | out = out.permute(0, 2, 3, 1) 54 | out = out[:, ::down_y, ::down_x, :] 55 | 56 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 57 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 58 | 59 | return out.view(-1, channel, out_h, out_w) 60 | -------------------------------------------------------------------------------- /env_install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # conda create -n sfe python=3.10 4 | 5 | 6 | conda install ninja -y 7 | pip3 install torch==1.12.0+cu113 torchvision==0.13.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html 8 | pip install -r requirements.txt -------------------------------------------------------------------------------- /models/farl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/farl/__init__.py -------------------------------------------------------------------------------- /models/farl/farl.py: -------------------------------------------------------------------------------- 1 | import facer 2 | from facer.face_parsing import FaRLFaceParser 3 | from facer.face_detection import RetinaFaceDetector 4 | from facer.face_detection.retinaface import RetinaFace 5 | from configs.paths import DefaultPaths 6 | import torch.backends.cudnn as cudnn 7 | import numpy as np 8 | 9 | import torch 10 | from torch import nn 11 | 12 | import torchvision.transforms as transforms 13 | 14 | import torch.nn.functional as F 15 | 16 | def my_fp_init(self, model_path=DefaultPaths.farl_path): 17 | super(FaRLFaceParser, self).__init__() 18 | self.conf_name = 'lapa/448' 19 | self.net = torch.jit.load(model_path) 20 | self.eval() 21 | 22 | FaRLFaceParser.__init__ = my_fp_init 23 | 24 | def remove_prefix(state_dict, prefix): 25 | """ Old style model is stored with all names of parameters sharing common prefix 'module.' 
""" 26 | def f(x): return x.split(prefix, 1)[-1] if x.startswith(prefix) else x 27 | return {f(key): value for key, value in state_dict.items()} 28 | 29 | def check_keys(model, pretrained_state_dict): 30 | ckpt_keys = set(pretrained_state_dict.keys()) 31 | model_keys = set(model.state_dict().keys()) 32 | used_pretrained_keys = model_keys & ckpt_keys 33 | assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint" 34 | return True 35 | 36 | def load_model(model, pretrained_path, load_to_cpu, network: str): 37 | if load_to_cpu: 38 | pretrained_dict = torch.load( 39 | pretrained_path, map_location=lambda storage, loc: storage 40 | ) 41 | else: 42 | device = torch.cuda.current_device() 43 | pretrained_dict = torch.load( 44 | pretrained_path, map_location=lambda storage, loc: storage.cuda(device) 45 | ) 46 | if "state_dict" in pretrained_dict.keys(): 47 | pretrained_dict = remove_prefix( 48 | pretrained_dict["state_dict"], "module.") 49 | else: 50 | pretrained_dict = remove_prefix(pretrained_dict, "module.") 51 | check_keys(model, pretrained_dict) 52 | model.load_state_dict(pretrained_dict, strict=False) 53 | return model 54 | 55 | def load_net(model_path): 56 | cfg = { 57 | "name": "mobilenet0.25", 58 | "min_sizes": [[16, 32], [64, 128], [256, 512]], 59 | "steps": [8, 16, 32], 60 | "variance": [0.1, 0.2], 61 | "clip": False, 62 | "loc_weight": 2.0, 63 | "gpu_train": True, 64 | "batch_size": 32, 65 | "ngpu": 1, 66 | "epoch": 250, 67 | "decay1": 190, 68 | "decay2": 220, 69 | "image_size": 640, 70 | "pretrain": True, 71 | "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3}, 72 | "in_channel": 32, 73 | "out_channel": 64, 74 | } 75 | # net and model 76 | net = RetinaFace(cfg=cfg, phase="test").cuda() 77 | net = load_model(net, model_path, True, network="mobilenet") 78 | net.eval() 79 | cudnn.benchmark = True 80 | # net = net.to(device) 81 | return net 82 | 83 | def my_fd_init(self, model_path=DefaultPaths.mobile_net_pth, trash=0.8): 84 | super(RetinaFaceDetector, self).__init__() 85 | self.conf_name = 'mobilenet' 86 | self.threshold=trash 87 | self.net = load_net(model_path) 88 | self.eval() 89 | 90 | RetinaFaceDetector.__init__ = my_fd_init 91 | 92 | class TargetMask(nn.Module): 93 | def __init__(self, tfm=True): 94 | super().__init__() 95 | self.face_detector = RetinaFaceDetector(trash=0.8).cuda().eval() 96 | self.face_parser = FaRLFaceParser().cuda().eval() 97 | self.to_farl = transforms.Compose( 98 | [ 99 | transforms.Normalize([0., 0., 0.], [2., 2., 2.]), 100 | transforms.Normalize([-0.5, -0.5, -0.5], [1., 1., 1.]), 101 | ] 102 | ) 103 | self.tfm = tfm 104 | self.sigm = torch.nn.Sigmoid() 105 | 106 | def get_u_idxs(self, all_indexes): 107 | res = [] 108 | for i in range(all_indexes[-1] + 1): 109 | res.append(((all_indexes == i).nonzero(as_tuple=True)[0][0])) 110 | return torch.tensor(res) 111 | 112 | def get_mask(self, y, threshold=0.5): 113 | #print(y.type(), y.shape, y.max(), y.min()) 114 | y = y.long() 115 | faces = self.face_detector(y) 116 | #print(len(faces['image_ids'])) 117 | faces = self.face_parser(y, faces) 118 | 119 | seg_logits = faces['seg']['logits'] 120 | seg_probs = seg_logits.softmax(dim=1) 121 | 122 | uniq_idx = self.get_u_idxs(faces['image_ids']) 123 | 124 | chroma_mask = (seg_probs[uniq_idx, 0, :, :] >= threshold).to(y.dtype).unsqueeze(1) 125 | return chroma_mask 126 | 127 | def forward(self, x, y): 128 | if self.tfm: 129 | mask_y = self.get_mask(255. * self.to_farl(y)) 130 | else: 131 | mask_y = self.get_mask(255. 
* y) 132 | return (1 - mask_y) * x + mask_y * y 133 | 134 | def forward2(self, x, y): 135 | batch = (255. * self.to_farl(y)).long() 136 | try: 137 | faces = self.face_detector(batch) 138 | assert len(faces['image_ids']) != 0 139 | except: 140 | for trash in [0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]: 141 | try: 142 | new_masker = Masker(trash=trash) 143 | faces = new_masker.face_detector(batch) 144 | assert len(faces['image_ids']) != 0 145 | break 146 | except: 147 | pass 148 | assert len(faces['image_ids']) != 0 149 | faces = self.face_parser(batch, faces) 150 | farl_mask = self.sigm(faces['seg']['logits'][:, 0]) 151 | farl_mask = (farl_mask >= 0.995).float()[0] 152 | 153 | return (1 - farl_mask) * x + farl_mask * y 154 | 155 | 156 | class Masker(nn.Module): 157 | def __init__(self, trash=0.8): 158 | super().__init__() 159 | self.face_detector = RetinaFaceDetector(trash=trash).cuda().eval() 160 | self.face_parser = FaRLFaceParser().cuda().eval() 161 | 162 | def get_u_idxs(self, all_indexes): 163 | res = [] 164 | for i in range(all_indexes[-1] + 1): 165 | res.append(((all_indexes == i).nonzero(as_tuple=True)[0][0])) 166 | return torch.tensor(res) 167 | 168 | def get_mask(self, y, threshold=0.5): 169 | faces = self.face_detector(y) 170 | faces = self.face_parser(y, faces) 171 | 172 | seg_logits = faces['seg']['logits'] 173 | seg_probs = seg_logits.softmax(dim=1) 174 | 175 | uniq_idx = self.get_u_idxs(faces['image_ids']) 176 | 177 | chroma_mask = (seg_probs[uniq_idx, 0, :, :] >= threshold).to(y.dtype).unsqueeze(1) 178 | return chroma_mask 179 | 180 | def forward(self, x): 181 | return self.get_mask(255. * x).repeat(1, 3, 1, 1) 182 | -------------------------------------------------------------------------------- /models/hyperinverter/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/hyperinverter/encoders/__init__.py -------------------------------------------------------------------------------- /models/hyperinverter/encoders/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | from torch.nn import ( 5 | AdaptiveAvgPool2d, 6 | BatchNorm2d, 7 | Conv2d, 8 | MaxPool2d, 9 | Module, 10 | PReLU, 11 | ReLU, 12 | Sequential, 13 | Sigmoid, 14 | ) 15 | 16 | 17 | """ 18 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 19 | """ 20 | 21 | 22 | class Flatten(Module): 23 | def forward(self, input): 24 | return input.view(input.size(0), -1) 25 | 26 | 27 | def l2_norm(input, axis=1): 28 | norm = torch.norm(input, 2, axis, True) 29 | output = torch.div(input, norm) 30 | return output 31 | 32 | 33 | class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])): 34 | """A named tuple describing a ResNet block.""" 35 | 36 | 37 | def get_block(in_channel, depth, num_units, stride=2): 38 | return [Bottleneck(in_channel, depth, stride)] + [ 39 | Bottleneck(depth, depth, 1) for i in range(num_units - 1) 40 | ] 41 | 42 | 43 | def get_blocks(num_layers): 44 | if num_layers == 50: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=4), 48 | get_block(in_channel=128, depth=256, num_units=14), 49 | get_block(in_channel=256, depth=512, num_units=3), 50 | ] 51 | elif num_layers == 100: 52 | blocks = [ 53 | get_block(in_channel=64, depth=64, 
num_units=3), 54 | get_block(in_channel=64, depth=128, num_units=13), 55 | get_block(in_channel=128, depth=256, num_units=30), 56 | get_block(in_channel=256, depth=512, num_units=3), 57 | ] 58 | elif num_layers == 152: 59 | blocks = [ 60 | get_block(in_channel=64, depth=64, num_units=3), 61 | get_block(in_channel=64, depth=128, num_units=8), 62 | get_block(in_channel=128, depth=256, num_units=36), 63 | get_block(in_channel=256, depth=512, num_units=3), 64 | ] 65 | else: 66 | raise ValueError( 67 | "Invalid number of layers: {}. Must be one of [50, 100, 152]".format( 68 | num_layers 69 | ) 70 | ) 71 | return blocks 72 | 73 | 74 | class SEModule(Module): 75 | def __init__(self, channels, reduction): 76 | super(SEModule, self).__init__() 77 | self.avg_pool = AdaptiveAvgPool2d(1) 78 | self.fc1 = Conv2d( 79 | channels, channels // reduction, kernel_size=1, padding=0, bias=False 80 | ) 81 | self.relu = ReLU(inplace=True) 82 | self.fc2 = Conv2d( 83 | channels // reduction, channels, kernel_size=1, padding=0, bias=False 84 | ) 85 | self.sigmoid = Sigmoid() 86 | 87 | def forward(self, x): 88 | module_input = x 89 | x = self.avg_pool(x) 90 | x = self.fc1(x) 91 | x = self.relu(x) 92 | x = self.fc2(x) 93 | x = self.sigmoid(x) 94 | return module_input * x 95 | 96 | 97 | class bottleneck_IR(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth), 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | ) 114 | 115 | def forward(self, x): 116 | shortcut = self.shortcut_layer(x) 117 | res = self.res_layer(x) 118 | return res + shortcut 119 | 120 | 121 | class bottleneck_IR_SE(Module): 122 | def __init__(self, in_channel, depth, stride): 123 | super(bottleneck_IR_SE, self).__init__() 124 | if in_channel == depth: 125 | self.shortcut_layer = MaxPool2d(1, stride) 126 | else: 127 | self.shortcut_layer = Sequential( 128 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 129 | BatchNorm2d(depth), 130 | ) 131 | self.res_layer = Sequential( 132 | BatchNorm2d(in_channel), 133 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 134 | PReLU(depth), 135 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 136 | BatchNorm2d(depth), 137 | SEModule(depth, 16), 138 | ) 139 | 140 | def forward(self, x): 141 | shortcut = self.shortcut_layer(x) 142 | res = self.res_layer(x) 143 | return res + shortcut 144 | -------------------------------------------------------------------------------- /models/hyperinverter/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from models.hyperinverter.encoders.helpers import ( 2 | Flatten, 3 | bottleneck_IR, 4 | bottleneck_IR_SE, 5 | get_blocks, 6 | l2_norm, 7 | ) 8 | from torch.nn import ( 9 | BatchNorm1d, 10 | BatchNorm2d, 11 | Conv2d, 12 | Dropout, 13 | Linear, 14 | Module, 15 | PReLU, 16 | Sequential, 17 | ) 18 | 19 | 20 | """ 21 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 22 | """ 23 | 24 | 25 | class Backbone(Module): 26 | def __init__(self, input_size, num_layers, mode="ir", drop_ratio=0.4, affine=True): 27 | 
super(Backbone, self).__init__() 28 | assert input_size in [112, 224], "input_size should be 112 or 224" 29 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 30 | assert mode in ["ir", "ir_se"], "mode should be ir or ir_se" 31 | blocks = get_blocks(num_layers) 32 | if mode == "ir": 33 | unit_module = bottleneck_IR 34 | elif mode == "ir_se": 35 | unit_module = bottleneck_IR_SE 36 | self.input_layer = Sequential( 37 | Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64) 38 | ) 39 | if input_size == 112: 40 | self.output_layer = Sequential( 41 | BatchNorm2d(512), 42 | Dropout(drop_ratio), 43 | Flatten(), 44 | Linear(512 * 7 * 7, 512), 45 | BatchNorm1d(512, affine=affine), 46 | ) 47 | else: 48 | self.output_layer = Sequential( 49 | BatchNorm2d(512), 50 | Dropout(drop_ratio), 51 | Flatten(), 52 | Linear(512 * 14 * 14, 512), 53 | BatchNorm1d(512, affine=affine), 54 | ) 55 | 56 | modules = [] 57 | for block in blocks: 58 | for bottleneck in block: 59 | modules.append( 60 | unit_module( 61 | bottleneck.in_channel, bottleneck.depth, bottleneck.stride 62 | ) 63 | ) 64 | self.body = Sequential(*modules) 65 | 66 | def forward(self, x): 67 | x = self.input_layer(x) 68 | x = self.body(x) 69 | x = self.output_layer(x) 70 | return l2_norm(x) 71 | 72 | 73 | def IR_50(input_size): 74 | """Constructs a ir-50 model.""" 75 | model = Backbone(input_size, num_layers=50, mode="ir", drop_ratio=0.4, affine=False) 76 | return model 77 | 78 | 79 | def IR_101(input_size): 80 | """Constructs a ir-101 model.""" 81 | model = Backbone( 82 | input_size, num_layers=100, mode="ir", drop_ratio=0.4, affine=False 83 | ) 84 | return model 85 | 86 | 87 | def IR_152(input_size): 88 | """Constructs a ir-152 model.""" 89 | model = Backbone( 90 | input_size, num_layers=152, mode="ir", drop_ratio=0.4, affine=False 91 | ) 92 | return model 93 | 94 | 95 | def IR_SE_50(input_size): 96 | """Constructs a ir_se-50 model.""" 97 | model = Backbone( 98 | input_size, num_layers=50, mode="ir_se", drop_ratio=0.4, affine=False 99 | ) 100 | return model 101 | 102 | 103 | def IR_SE_101(input_size): 104 | """Constructs a ir_se-101 model.""" 105 | model = Backbone( 106 | input_size, num_layers=100, mode="ir_se", drop_ratio=0.4, affine=False 107 | ) 108 | return model 109 | 110 | 111 | def IR_SE_152(input_size): 112 | """Constructs a ir_se-152 model.""" 113 | model = Backbone( 114 | input_size, num_layers=152, mode="ir_se", drop_ratio=0.4, affine=False 115 | ) 116 | return model 117 | -------------------------------------------------------------------------------- /models/hyperinverter/hypernetwork.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def shape_to_num_params(shapes): 8 | return torch.sum(torch.tensor([torch.prod(s) for s in shapes])).int().item() 9 | 10 | 11 | class WeightRegressor(nn.Module): 12 | """Regressing features to convolution weight kernel""" 13 | 14 | def __init__(self, input_dim, hidden_dim, kernel_size=3, out_channels=16, in_channels=16): 15 | super().__init__() 16 | self.input_dim = input_dim 17 | self.hidden_dim = hidden_dim 18 | self.kernel_size = kernel_size 19 | self.out_channels = out_channels 20 | self.in_channels = in_channels 21 | 22 | # Feature Transformer 23 | self.fusion = nn.Sequential( 24 | nn.Conv2d(2 * self.input_dim, self.input_dim, kernel_size=1, padding=0, stride=1, bias=True), 25 | nn.InstanceNorm2d(self.input_dim), 26 | nn.ReLU(), 27 
| ) 28 | self.feature_extractor = nn.Sequential( 29 | nn.Conv2d(self.input_dim, 64, kernel_size=3, padding=1, stride=1, bias=True), 30 | nn.ReLU(), 31 | nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True), 32 | nn.ReLU(), 33 | nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True), 34 | nn.ReLU(), 35 | nn.Conv2d(64, self.hidden_dim, kernel_size=4, stride=2, padding=1, bias=True), 36 | nn.ReLU(), 37 | ) 38 | 39 | # Linear Mapper 40 | self.w1 = nn.Parameter(torch.randn((self.hidden_dim, self.in_channels * self.hidden_dim))) 41 | self.b1 = nn.Parameter(torch.randn((self.in_channels * self.hidden_dim))) 42 | self.w2 = nn.Parameter(torch.randn((self.hidden_dim, self.out_channels * self.kernel_size * self.kernel_size))) 43 | self.b2 = nn.Parameter(torch.randn((self.out_channels * self.kernel_size * self.kernel_size))) 44 | 45 | self.weight_init() 46 | 47 | def weight_init(self): 48 | nn.init.kaiming_normal_(self.w1) 49 | nn.init.zeros_(self.b1) 50 | nn.init.kaiming_normal_(self.w2) 51 | nn.init.zeros_(self.b2) 52 | 53 | def forward(self, w_image_codes, w_bar_codes): 54 | bs = w_image_codes.size(0) 55 | 56 | # Feature Transformation 57 | out = self.fusion(torch.cat((w_image_codes, w_bar_codes), 1)) 58 | out = self.feature_extractor(out) 59 | out = out.view(bs, -1) 60 | 61 | # Linear map to weights 62 | out = torch.matmul(out, self.w1) + self.b1 63 | out = out.view(bs, self.in_channels, self.hidden_dim) 64 | out = torch.matmul(out, self.w2) + self.b2 65 | kernel = out.view(bs, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size) 66 | 67 | return kernel 68 | 69 | 70 | class Hypernetwork(nn.Module): 71 | def __init__(self, input_dim=512, hidden_dim=64, target_shape=None): 72 | super().__init__() 73 | self.input_dim = input_dim 74 | self.hidden_dim = hidden_dim 75 | self.target_shape = target_shape 76 | 77 | num_predicted_weights = 0 78 | weight_regressors = OrderedDict() 79 | for layer_name in target_shape: 80 | new_layer_name = "_".join(layer_name.split(".")) 81 | shape = target_shape[layer_name]["shape"] 82 | 83 | # without consider bias 84 | if len(shape) == 4: 85 | out_channels, in_channels, kernel_size = shape[:3] 86 | else: 87 | out_channels, in_channels = shape 88 | kernel_size = 1 89 | 90 | num_predicted_weights += shape_to_num_params([torch.tensor(list(shape))]) 91 | weight_regressors[new_layer_name] = WeightRegressor( 92 | input_dim=self.input_dim, 93 | hidden_dim=self.hidden_dim, 94 | kernel_size=kernel_size, 95 | out_channels=out_channels, 96 | in_channels=in_channels, 97 | ) 98 | self.weight_regressors = nn.ModuleDict(weight_regressors) 99 | self.num_predicted_weights = num_predicted_weights 100 | 101 | def forward(self, w_image_codes, w_bar_codes): 102 | bs = w_image_codes.size(0) 103 | out_weights = {} 104 | 105 | for layer_name in self.weight_regressors: 106 | ori_layer_name = ".".join(layer_name.split("_")) 107 | w_idx = self.target_shape[ori_layer_name]["w_idx"] 108 | weights = self.weight_regressors[layer_name]( 109 | w_image_codes[:, w_idx, :, :, :], w_bar_codes[:, w_idx, :, :, :] 110 | ) 111 | out_weights[ori_layer_name] = weights.view(bs, *list(self.target_shape[ori_layer_name]["shape"])) 112 | 113 | return out_weights -------------------------------------------------------------------------------- /models/hyperinverter/stylegan2/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/hyperinverter/stylegan2/__init__.py -------------------------------------------------------------------------------- /models/hyperinverter/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | 4 | 5 | __all__ = ["FusedLeakyReLU", "fused_leaky_relu", "upfirdn2d"] 6 | -------------------------------------------------------------------------------- /models/hyperinverter/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | "fused", 12 | sources=[ 13 | os.path.join(module_path, "fused_bias_act.cpp"), 14 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class FusedLeakyReLUFunctionBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, out, negative_slope, scale): 22 | ctx.save_for_backward(out) 23 | ctx.negative_slope = negative_slope 24 | ctx.scale = scale 25 | 26 | empty = grad_output.new_empty(0) 27 | 28 | grad_input = fused.fused_bias_act( 29 | grad_output, empty, out, 3, 1, negative_slope, scale 30 | ) 31 | 32 | dim = [0] 33 | 34 | if grad_input.ndim > 2: 35 | dim += list(range(2, grad_input.ndim)) 36 | 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | return grad_input, grad_bias 40 | 41 | @staticmethod 42 | def backward(ctx, gradgrad_input, gradgrad_bias): 43 | (out,) = ctx.saved_tensors 44 | gradgrad_out = fused.fused_bias_act( 45 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 46 | ) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 | class FusedLeakyReLUFunction(Function): 52 | @staticmethod 53 | def forward(ctx, input, bias, negative_slope, scale): 54 | empty = input.new_empty(0) 55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 56 | ctx.save_for_backward(out) 57 | ctx.negative_slope = negative_slope 58 | ctx.scale = scale 59 | 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | (out,) = ctx.saved_tensors 65 | 66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 67 | grad_output, out, ctx.negative_slope, ctx.scale 68 | ) 69 | 70 | return grad_input, grad_bias, None, None 71 | 72 | 73 | class FusedLeakyReLU(nn.Module): 74 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | 81 | def forward(self, input): 82 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 83 | 84 | 85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 87 | -------------------------------------------------------------------------------- /models/hyperinverter/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int 
grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/hyperinverter/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/hyperinverter/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /models/hyperinverter/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.nn import functional as F 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | up_x, up_y = up 25 | down_x, down_y = down 26 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 27 | 28 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 29 | 30 | grad_input = upfirdn2d_op.upfirdn2d( 31 | grad_output, 32 | grad_kernel, 33 | down_x, 34 | down_y, 35 | up_x, 36 | up_y, 37 | g_pad_x0, 38 | g_pad_x1, 39 | g_pad_y0, 40 | g_pad_y1, 41 | ) 42 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 43 | 44 | ctx.save_for_backward(kernel) 45 | 46 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 47 | 48 | ctx.up_x = up_x 49 | ctx.up_y = up_y 50 | ctx.down_x = down_x 51 | ctx.down_y = down_y 52 | ctx.pad_x0 = pad_x0 53 | ctx.pad_x1 = pad_x1 54 | ctx.pad_y0 = pad_y0 55 | ctx.pad_y1 = pad_y1 56 | ctx.in_size = in_size 57 | ctx.out_size = out_size 58 | 59 | return grad_input 60 | 61 | @staticmethod 62 | def backward(ctx, gradgrad_input): 63 | (kernel,) = ctx.saved_tensors 
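        # Double-backward: apply the forward-orientation upfirdn2d (saved, unflipped kernel,
        # original up/down factors and padding) to the incoming gradient-of-gradient.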
64 | 65 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 66 | 67 | gradgrad_out = upfirdn2d_op.upfirdn2d( 68 | gradgrad_input, 69 | kernel, 70 | ctx.up_x, 71 | ctx.up_y, 72 | ctx.down_x, 73 | ctx.down_y, 74 | ctx.pad_x0, 75 | ctx.pad_x1, 76 | ctx.pad_y0, 77 | ctx.pad_y1, 78 | ) 79 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 80 | gradgrad_out = gradgrad_out.view( 81 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 82 | ) 83 | 84 | return gradgrad_out, None, None, None, None, None, None, None, None 85 | 86 | 87 | class UpFirDn2d(Function): 88 | @staticmethod 89 | def forward(ctx, input, kernel, up, down, pad): 90 | up_x, up_y = up 91 | down_x, down_y = down 92 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 93 | 94 | kernel_h, kernel_w = kernel.shape 95 | batch, channel, in_h, in_w = input.shape 96 | ctx.in_size = input.shape 97 | 98 | input = input.reshape(-1, in_h, in_w, 1) 99 | 100 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 101 | 102 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 103 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 104 | ctx.out_size = (out_h, out_w) 105 | 106 | ctx.up = (up_x, up_y) 107 | ctx.down = (down_x, down_y) 108 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 109 | 110 | g_pad_x0 = kernel_w - pad_x0 - 1 111 | g_pad_y0 = kernel_h - pad_y0 - 1 112 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 113 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 114 | 115 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 116 | 117 | out = upfirdn2d_op.upfirdn2d( 118 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 119 | ) 120 | # out = out.view(major, out_h, out_w, minor) 121 | out = out.view(-1, channel, out_h, out_w) 122 | 123 | return out 124 | 125 | @staticmethod 126 | def backward(ctx, grad_output): 127 | kernel, grad_kernel = ctx.saved_tensors 128 | 129 | grad_input = UpFirDn2dBackward.apply( 130 | grad_output, 131 | kernel, 132 | grad_kernel, 133 | ctx.up, 134 | ctx.down, 135 | ctx.pad, 136 | ctx.g_pad, 137 | ctx.in_size, 138 | ctx.out_size, 139 | ) 140 | 141 | return grad_input, None, None, None, None 142 | 143 | 144 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 145 | out = UpFirDn2d.apply( 146 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 147 | ) 148 | 149 | return out 150 | 151 | 152 | def upfirdn2d_native( 153 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 154 | ): 155 | _, in_h, in_w, minor = input.shape 156 | kernel_h, kernel_w = kernel.shape 157 | 158 | out = input.view(-1, in_h, 1, in_w, 1, minor) 159 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 160 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 161 | 162 | out = F.pad( 163 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 164 | ) 165 | out = out[ 166 | :, 167 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 168 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 169 | :, 170 | ] 171 | 172 | out = out.permute(0, 3, 1, 2) 173 | out = out.reshape( 174 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 175 | ) 176 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 177 | out = F.conv2d(out, w) 178 | out = out.reshape( 179 | -1, 180 | minor, 181 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 182 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 183 | ) 184 | out 
= out.permute(0, 2, 3, 1) 185 | 186 | return out[:, ::down_y, ::down_x, :] 187 | -------------------------------------------------------------------------------- /models/mtcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/mtcnn/__init__.py -------------------------------------------------------------------------------- /models/mtcnn/mtcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet 5 | from models.mtcnn.mtcnn_pytorch.src.box_utils import ( 6 | nms, 7 | calibrate_box, 8 | get_image_boxes, 9 | convert_to_square, 10 | ) 11 | from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage 12 | from models.mtcnn.mtcnn_pytorch.src.align_trans import ( 13 | get_reference_facial_points, 14 | warp_and_crop_face, 15 | ) 16 | 17 | device = "cuda:0" 18 | 19 | 20 | class MTCNN: 21 | def __init__(self): 22 | self.pnet = PNet().to(device) 23 | self.rnet = RNet().to(device) 24 | self.onet = ONet().to(device) 25 | self.pnet.eval() 26 | self.rnet.eval() 27 | self.onet.eval() 28 | self.refrence = get_reference_facial_points(default_square=True) 29 | 30 | def align(self, img): 31 | _, landmarks = self.detect_faces(img) 32 | if len(landmarks) == 0: 33 | return None, None 34 | facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)] 35 | warped_face, tfm = warp_and_crop_face( 36 | np.array(img), facial5points, self.refrence, crop_size=(112, 112) 37 | ) 38 | return Image.fromarray(warped_face), tfm 39 | 40 | def align_multi(self, img, limit=None, min_face_size=30.0): 41 | boxes, landmarks = self.detect_faces(img, min_face_size) 42 | if limit: 43 | boxes = boxes[:limit] 44 | landmarks = landmarks[:limit] 45 | faces = [] 46 | tfms = [] 47 | for landmark in landmarks: 48 | facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)] 49 | warped_face, tfm = warp_and_crop_face( 50 | np.array(img), facial5points, self.refrence, crop_size=(112, 112) 51 | ) 52 | faces.append(Image.fromarray(warped_face)) 53 | tfms.append(tfm) 54 | return boxes, faces, tfms 55 | 56 | def detect_faces( 57 | self, 58 | image, 59 | min_face_size=20.0, 60 | thresholds=[0.15, 0.25, 0.35], 61 | nms_thresholds=[0.7, 0.7, 0.7], 62 | ): 63 | """ 64 | Arguments: 65 | image: an instance of PIL.Image. 66 | min_face_size: a float number. 67 | thresholds: a list of length 3. 68 | nms_thresholds: a list of length 3. 69 | 70 | Returns: 71 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 72 | bounding boxes and facial landmarks. 
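        Note: the default thresholds (0.15, 0.25, 0.35) here are far lower than the
        (0.6, 0.7, 0.8) defaults used by detector.detect_faces, so each stage keeps
        many more low-confidence candidate boxes.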
73 | """ 74 | 75 | # BUILD AN IMAGE PYRAMID 76 | width, height = image.size 77 | min_length = min(height, width) 78 | 79 | min_detection_size = 12 80 | factor = 0.707 # sqrt(0.5) 81 | 82 | # scales for scaling the image 83 | scales = [] 84 | 85 | # scales the image so that 86 | # minimum size that we can detect equals to 87 | # minimum face size that we want to detect 88 | m = min_detection_size / min_face_size 89 | min_length *= m 90 | 91 | factor_count = 0 92 | while min_length > min_detection_size: 93 | scales.append(m * factor**factor_count) 94 | min_length *= factor 95 | factor_count += 1 96 | 97 | # STAGE 1 98 | 99 | # it will be returned 100 | bounding_boxes = [] 101 | 102 | with torch.no_grad(): 103 | # run P-Net on different scales 104 | for s in scales: 105 | boxes = run_first_stage( 106 | image, self.pnet, scale=s, threshold=thresholds[0] 107 | ) 108 | bounding_boxes.append(boxes) 109 | 110 | # collect boxes (and offsets, and scores) from different scales 111 | bounding_boxes = [i for i in bounding_boxes if i is not None] 112 | bounding_boxes = np.vstack(bounding_boxes) 113 | 114 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 115 | bounding_boxes = bounding_boxes[keep] 116 | 117 | # use offsets predicted by pnet to transform bounding boxes 118 | bounding_boxes = calibrate_box( 119 | bounding_boxes[:, 0:5], bounding_boxes[:, 5:] 120 | ) 121 | # shape [n_boxes, 5] 122 | 123 | bounding_boxes = convert_to_square(bounding_boxes) 124 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 125 | 126 | # STAGE 2 127 | 128 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 129 | img_boxes = torch.FloatTensor(img_boxes).to(device) 130 | 131 | output = self.rnet(img_boxes) 132 | offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] 133 | probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] 134 | 135 | keep = np.where(probs[:, 1] > thresholds[1])[0] 136 | bounding_boxes = bounding_boxes[keep] 137 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 138 | offsets = offsets[keep] 139 | 140 | keep = nms(bounding_boxes, nms_thresholds[1]) 141 | bounding_boxes = bounding_boxes[keep] 142 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 143 | bounding_boxes = convert_to_square(bounding_boxes) 144 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 145 | 146 | # STAGE 3 147 | 148 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 149 | if len(img_boxes) == 0: 150 | return [], [] 151 | img_boxes = torch.FloatTensor(img_boxes).to(device) 152 | output = self.onet(img_boxes) 153 | landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] 154 | offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] 155 | probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] 156 | 157 | keep = np.where(probs[:, 1] > thresholds[2])[0] 158 | bounding_boxes = bounding_boxes[keep] 159 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 160 | offsets = offsets[keep] 161 | landmarks = landmarks[keep] 162 | 163 | # compute landmark points 164 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 165 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 166 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 167 | landmarks[:, 0:5] = ( 168 | np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 169 | ) 170 | landmarks[:, 5:10] = ( 171 | np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 172 | ) 173 | 174 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 175 | keep = 
nms(bounding_boxes, nms_thresholds[2], mode="min") 176 | bounding_boxes = bounding_boxes[keep] 177 | landmarks = landmarks[keep] 178 | 179 | return bounding_boxes, landmarks 180 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/mtcnn/mtcnn_pytorch/__init__.py -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization_utils import show_bboxes 2 | from .detector import detect_faces 3 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from .get_nets import PNet, RNet, ONet 5 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from .first_stage import run_first_stage 7 | 8 | 9 | def detect_faces( 10 | image, 11 | min_face_size=20.0, 12 | thresholds=[0.6, 0.7, 0.8], 13 | nms_thresholds=[0.7, 0.7, 0.7], 14 | ): 15 | """ 16 | Arguments: 17 | image: an instance of PIL.Image. 18 | min_face_size: a float number. 19 | thresholds: a list of length 3. 20 | nms_thresholds: a list of length 3. 21 | 22 | Returns: 23 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 24 | bounding boxes and facial landmarks. 25 | """ 26 | 27 | # LOAD MODELS 28 | pnet = PNet() 29 | rnet = RNet() 30 | onet = ONet() 31 | onet.eval() 32 | 33 | # BUILD AN IMAGE PYRAMID 34 | width, height = image.size 35 | min_length = min(height, width) 36 | 37 | min_detection_size = 12 38 | factor = 0.707 # sqrt(0.5) 39 | 40 | # scales for scaling the image 41 | scales = [] 42 | 43 | # scales the image so that 44 | # minimum size that we can detect equals to 45 | # minimum face size that we want to detect 46 | m = min_detection_size / min_face_size 47 | min_length *= m 48 | 49 | factor_count = 0 50 | while min_length > min_detection_size: 51 | scales.append(m * factor**factor_count) 52 | min_length *= factor 53 | factor_count += 1 54 | 55 | # STAGE 1 56 | 57 | # it will be returned 58 | bounding_boxes = [] 59 | 60 | with torch.no_grad(): 61 | # run P-Net on different scales 62 | for s in scales: 63 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 64 | bounding_boxes.append(boxes) 65 | 66 | # collect boxes (and offsets, and scores) from different scales 67 | bounding_boxes = [i for i in bounding_boxes if i is not None] 68 | bounding_boxes = np.vstack(bounding_boxes) 69 | 70 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 71 | bounding_boxes = bounding_boxes[keep] 72 | 73 | # use offsets predicted by pnet to transform bounding boxes 74 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 75 | # shape [n_boxes, 5] 76 | 77 | bounding_boxes = convert_to_square(bounding_boxes) 78 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 79 | 80 | # STAGE 2 81 | 82 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 83 | img_boxes = torch.FloatTensor(img_boxes) 84 | 85 | output = rnet(img_boxes) 86 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 87 | 
probs = output[1].data.numpy() # shape [n_boxes, 2] 88 | 89 | keep = np.where(probs[:, 1] > thresholds[1])[0] 90 | bounding_boxes = bounding_boxes[keep] 91 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 92 | offsets = offsets[keep] 93 | 94 | keep = nms(bounding_boxes, nms_thresholds[1]) 95 | bounding_boxes = bounding_boxes[keep] 96 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 97 | bounding_boxes = convert_to_square(bounding_boxes) 98 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 99 | 100 | # STAGE 3 101 | 102 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 103 | if len(img_boxes) == 0: 104 | return [], [] 105 | img_boxes = torch.FloatTensor(img_boxes) 106 | output = onet(img_boxes) 107 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 108 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 109 | probs = output[2].data.numpy() # shape [n_boxes, 2] 110 | 111 | keep = np.where(probs[:, 1] > thresholds[2])[0] 112 | bounding_boxes = bounding_boxes[keep] 113 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 114 | offsets = offsets[keep] 115 | landmarks = landmarks[keep] 116 | 117 | # compute landmark points 118 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 119 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 120 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 121 | landmarks[:, 0:5] = ( 122 | np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 123 | ) 124 | landmarks[:, 5:10] = ( 125 | np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 126 | ) 127 | 128 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 129 | keep = nms(bounding_boxes, nms_thresholds[2], mode="min") 130 | bounding_boxes = bounding_boxes[keep] 131 | landmarks = landmarks[keep] 132 | 133 | return bounding_boxes, landmarks 134 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import math 4 | from PIL import Image 5 | import numpy as np 6 | from .box_utils import nms, _preprocess 7 | 8 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | device = "cuda:0" 10 | 11 | 12 | def run_first_stage(image, net, scale, threshold): 13 | """Run P-Net, generate bounding boxes, and do NMS. 14 | 15 | Arguments: 16 | image: an instance of PIL.Image. 17 | net: an instance of pytorch's nn.Module, P-Net. 18 | scale: a float number, 19 | scale width and height of the image by this number. 20 | threshold: a float number, 21 | threshold on the probability of a face when generating 22 | bounding boxes from predictions of the net. 23 | 24 | Returns: 25 | a float numpy array of shape [n_boxes, 9], 26 | bounding boxes with scores and offsets (4 + 1 + 4). 
27 | """ 28 | 29 | # scale the image and convert it to a float array 30 | width, height = image.size 31 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 32 | img = image.resize((sw, sh), Image.BILINEAR) 33 | img = np.asarray(img, "float32") 34 | 35 | img = torch.FloatTensor(_preprocess(img)).to(device) 36 | with torch.no_grad(): 37 | output = net(img) 38 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 39 | offsets = output[0].cpu().data.numpy() 40 | # probs: probability of a face at each sliding window 41 | # offsets: transformations to true bounding boxes 42 | 43 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 44 | if len(boxes) == 0: 45 | return None 46 | 47 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 48 | return boxes[keep] 49 | 50 | 51 | def _generate_bboxes(probs, offsets, scale, threshold): 52 | """Generate bounding boxes at places 53 | where there is probably a face. 54 | 55 | Arguments: 56 | probs: a float numpy array of shape [n, m]. 57 | offsets: a float numpy array of shape [1, 4, n, m]. 58 | scale: a float number, 59 | width and height of the image were scaled by this number. 60 | threshold: a float number. 61 | 62 | Returns: 63 | a float numpy array of shape [n_boxes, 9] 64 | """ 65 | 66 | # applying P-Net is equivalent, in some sense, to 67 | # moving 12x12 window with stride 2 68 | stride = 2 69 | cell_size = 12 70 | 71 | # indices of boxes where there is probably a face 72 | inds = np.where(probs > threshold) 73 | 74 | if inds[0].size == 0: 75 | return np.array([]) 76 | 77 | # transformations of bounding boxes 78 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 79 | # they are defined as: 80 | # w = x2 - x1 + 1 81 | # h = y2 - y1 + 1 82 | # x1_true = x1 + tx1*w 83 | # x2_true = x2 + tx2*w 84 | # y1_true = y1 + ty1*h 85 | # y2_true = y2 + ty2*h 86 | 87 | offsets = np.array([tx1, ty1, tx2, ty2]) 88 | score = probs[inds[0], inds[1]] 89 | 90 | # P-Net is applied to scaled images 91 | # so we need to rescale bounding boxes back 92 | bounding_boxes = np.vstack( 93 | [ 94 | np.round((stride * inds[1] + 1.0) / scale), 95 | np.round((stride * inds[0] + 1.0) / scale), 96 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 97 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 98 | score, 99 | offsets, 100 | ] 101 | ) 102 | # why one is added? 103 | 104 | return bounding_boxes.T 105 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/get_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | from configs.paths import DefaultPaths 6 | from pathlib import Path 7 | import numpy as np 8 | 9 | PNET_PATH = Path(DefaultPaths.mtcnn) / "pnet.npy" 10 | ONET_PATH = Path(DefaultPaths.mtcnn) / "onet.npy" 11 | RNET_PATH = Path(DefaultPaths.mtcnn) / "rnet.npy" 12 | 13 | 14 | class Flatten(nn.Module): 15 | def __init__(self): 16 | super(Flatten, self).__init__() 17 | 18 | def forward(self, x): 19 | """ 20 | Arguments: 21 | x: a float tensor with shape [batch_size, c, h, w]. 22 | Returns: 23 | a float tensor with shape [batch_size, c*h*w]. 
24 | """ 25 | 26 | # without this pretrained model isn't working 27 | x = x.transpose(3, 2).contiguous() 28 | 29 | return x.view(x.size(0), -1) 30 | 31 | 32 | class PNet(nn.Module): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | # suppose we have input with size HxW, then 37 | # after first layer: H - 2, 38 | # after pool: ceil((H - 2)/2), 39 | # after second conv: ceil((H - 2)/2) - 2, 40 | # after last conv: ceil((H - 2)/2) - 4, 41 | # and the same for W 42 | 43 | self.features = nn.Sequential( 44 | OrderedDict( 45 | [ 46 | ("conv1", nn.Conv2d(3, 10, 3, 1)), 47 | ("prelu1", nn.PReLU(10)), 48 | ("pool1", nn.MaxPool2d(2, 2, ceil_mode=True)), 49 | ("conv2", nn.Conv2d(10, 16, 3, 1)), 50 | ("prelu2", nn.PReLU(16)), 51 | ("conv3", nn.Conv2d(16, 32, 3, 1)), 52 | ("prelu3", nn.PReLU(32)), 53 | ] 54 | ) 55 | ) 56 | 57 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 58 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 59 | 60 | weights = np.load(PNET_PATH, allow_pickle=True)[()] 61 | for n, p in self.named_parameters(): 62 | p.data = torch.FloatTensor(weights[n]) 63 | 64 | def forward(self, x): 65 | """ 66 | Arguments: 67 | x: a float tensor with shape [batch_size, 3, h, w]. 68 | Returns: 69 | b: a float tensor with shape [batch_size, 4, h', w']. 70 | a: a float tensor with shape [batch_size, 2, h', w']. 71 | """ 72 | x = self.features(x) 73 | a = self.conv4_1(x) 74 | b = self.conv4_2(x) 75 | a = F.softmax(a, dim=-1) 76 | return b, a 77 | 78 | 79 | class RNet(nn.Module): 80 | def __init__(self): 81 | super().__init__() 82 | 83 | self.features = nn.Sequential( 84 | OrderedDict( 85 | [ 86 | ("conv1", nn.Conv2d(3, 28, 3, 1)), 87 | ("prelu1", nn.PReLU(28)), 88 | ("pool1", nn.MaxPool2d(3, 2, ceil_mode=True)), 89 | ("conv2", nn.Conv2d(28, 48, 3, 1)), 90 | ("prelu2", nn.PReLU(48)), 91 | ("pool2", nn.MaxPool2d(3, 2, ceil_mode=True)), 92 | ("conv3", nn.Conv2d(48, 64, 2, 1)), 93 | ("prelu3", nn.PReLU(64)), 94 | ("flatten", Flatten()), 95 | ("conv4", nn.Linear(576, 128)), 96 | ("prelu4", nn.PReLU(128)), 97 | ] 98 | ) 99 | ) 100 | 101 | self.conv5_1 = nn.Linear(128, 2) 102 | self.conv5_2 = nn.Linear(128, 4) 103 | 104 | weights = np.load(RNET_PATH, allow_pickle=True)[()] 105 | for n, p in self.named_parameters(): 106 | p.data = torch.FloatTensor(weights[n]) 107 | 108 | def forward(self, x): 109 | """ 110 | Arguments: 111 | x: a float tensor with shape [batch_size, 3, h, w]. 112 | Returns: 113 | b: a float tensor with shape [batch_size, 4]. 114 | a: a float tensor with shape [batch_size, 2]. 
115 | """ 116 | x = self.features(x) 117 | a = self.conv5_1(x) 118 | b = self.conv5_2(x) 119 | a = F.softmax(a, dim=-1) 120 | return b, a 121 | 122 | 123 | class ONet(nn.Module): 124 | def __init__(self): 125 | super().__init__() 126 | 127 | self.features = nn.Sequential( 128 | OrderedDict( 129 | [ 130 | ("conv1", nn.Conv2d(3, 32, 3, 1)), 131 | ("prelu1", nn.PReLU(32)), 132 | ("pool1", nn.MaxPool2d(3, 2, ceil_mode=True)), 133 | ("conv2", nn.Conv2d(32, 64, 3, 1)), 134 | ("prelu2", nn.PReLU(64)), 135 | ("pool2", nn.MaxPool2d(3, 2, ceil_mode=True)), 136 | ("conv3", nn.Conv2d(64, 64, 3, 1)), 137 | ("prelu3", nn.PReLU(64)), 138 | ("pool3", nn.MaxPool2d(2, 2, ceil_mode=True)), 139 | ("conv4", nn.Conv2d(64, 128, 2, 1)), 140 | ("prelu4", nn.PReLU(128)), 141 | ("flatten", Flatten()), 142 | ("conv5", nn.Linear(1152, 256)), 143 | ("drop5", nn.Dropout(0.25)), 144 | ("prelu5", nn.PReLU(256)), 145 | ] 146 | ) 147 | ) 148 | 149 | self.conv6_1 = nn.Linear(256, 2) 150 | self.conv6_2 = nn.Linear(256, 4) 151 | self.conv6_3 = nn.Linear(256, 10) 152 | 153 | weights = np.load(ONET_PATH, allow_pickle=True)[()] 154 | for n, p in self.named_parameters(): 155 | p.data = torch.FloatTensor(weights[n]) 156 | 157 | def forward(self, x): 158 | """ 159 | Arguments: 160 | x: a float tensor with shape [batch_size, 3, h, w]. 161 | Returns: 162 | c: a float tensor with shape [batch_size, 10]. 163 | b: a float tensor with shape [batch_size, 4]. 164 | a: a float tensor with shape [batch_size, 2]. 165 | """ 166 | x = self.features(x) 167 | a = self.conv6_1(x) 168 | b = self.conv6_2(x) 169 | c = self.conv6_3(x) 170 | a = F.softmax(a, dim=-1) 171 | return c, b, a 172 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([(b[0], b[1]), (b[2], b[3])], outline="white") 21 | 22 | for p in facial_landmarks: 23 | for i in range(5): 24 | draw.ellipse( 25 | [(p[i] - 1.0, p[i + 5] - 1.0), (p[i] + 1.0, p[i + 5] + 1.0)], 26 | outline="blue", 27 | ) 28 | 29 | return img_copy 30 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /models/psp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/psp/__init__.py -------------------------------------------------------------------------------- /models/psp/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/psp/encoders/__init__.py -------------------------------------------------------------------------------- /models/psp/encoders/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from collections import namedtuple 4 | from torch.nn import ( 5 | Conv2d, 6 | BatchNorm2d, 7 | PReLU, 8 | ReLU, 9 | Sigmoid, 10 | MaxPool2d, 11 | AdaptiveAvgPool2d, 12 | Sequential, 13 | Module, 14 | ) 15 | 16 | """ 17 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 18 | """ 19 | 20 | 21 | class Flatten(Module): 22 | def forward(self, input): 23 | return input.view(input.size(0), -1) 24 | 25 | 26 | def l2_norm(input, axis=1): 27 | norm = torch.norm(input, 2, axis, True) 28 | output = torch.div(input, norm) 29 | return output 30 | 31 | 32 | class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])): 33 | """A named tuple describing a ResNet block.""" 34 | 35 | 36 | def get_block(in_channel, depth, num_units, stride=2): 37 | return [Bottleneck(in_channel, depth, stride)] + [ 38 | Bottleneck(depth, depth, 1) for i in range(num_units - 1) 39 | ] 40 | 41 | 42 | def get_blocks(num_layers): 43 | if num_layers == 50: 44 | blocks = [ 45 | get_block(in_channel=64, depth=64, num_units=3), 46 | get_block(in_channel=64, depth=128, num_units=4), 47 | get_block(in_channel=128, depth=256, 
num_units=14), 48 | get_block(in_channel=256, depth=512, num_units=3), 49 | ] 50 | elif num_layers == 100: 51 | blocks = [ 52 | get_block(in_channel=64, depth=64, num_units=3), 53 | get_block(in_channel=64, depth=128, num_units=13), 54 | get_block(in_channel=128, depth=256, num_units=30), 55 | get_block(in_channel=256, depth=512, num_units=3), 56 | ] 57 | elif num_layers == 152: 58 | blocks = [ 59 | get_block(in_channel=64, depth=64, num_units=3), 60 | get_block(in_channel=64, depth=128, num_units=8), 61 | get_block(in_channel=128, depth=256, num_units=36), 62 | get_block(in_channel=256, depth=512, num_units=3), 63 | ] 64 | else: 65 | raise ValueError( 66 | "Invalid number of layers: {}. Must be one of [50, 100, 152]".format( 67 | num_layers 68 | ) 69 | ) 70 | return blocks 71 | 72 | 73 | class SEModule(Module): 74 | def __init__(self, channels, reduction): 75 | super(SEModule, self).__init__() 76 | self.avg_pool = AdaptiveAvgPool2d(1) 77 | self.fc1 = Conv2d( 78 | channels, channels // reduction, kernel_size=1, padding=0, bias=False 79 | ) 80 | self.relu = ReLU(inplace=True) 81 | self.fc2 = Conv2d( 82 | channels // reduction, channels, kernel_size=1, padding=0, bias=False 83 | ) 84 | self.sigmoid = Sigmoid() 85 | 86 | def forward(self, x): 87 | module_input = x 88 | x = self.avg_pool(x) 89 | x = self.fc1(x) 90 | x = self.relu(x) 91 | x = self.fc2(x) 92 | x = self.sigmoid(x) 93 | return module_input * x 94 | 95 | 96 | class bottleneck_IR(Module): 97 | def __init__(self, in_channel, depth, stride): 98 | super(bottleneck_IR, self).__init__() 99 | if in_channel == depth: 100 | self.shortcut_layer = MaxPool2d(1, stride) 101 | else: 102 | self.shortcut_layer = Sequential( 103 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 104 | BatchNorm2d(depth), 105 | ) 106 | self.res_layer = Sequential( 107 | BatchNorm2d(in_channel), 108 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 109 | PReLU(depth), 110 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 111 | BatchNorm2d(depth), 112 | ) 113 | 114 | def forward(self, x): 115 | shortcut = self.shortcut_layer(x) 116 | res = self.res_layer(x) 117 | return res + shortcut 118 | 119 | 120 | class bottleneck_IR_SE(Module): 121 | def __init__(self, in_channel, depth, stride): 122 | super(bottleneck_IR_SE, self).__init__() 123 | if in_channel == depth: 124 | self.shortcut_layer = MaxPool2d(1, stride) 125 | else: 126 | self.shortcut_layer = Sequential( 127 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 128 | BatchNorm2d(depth), 129 | ) 130 | self.res_layer = Sequential( 131 | BatchNorm2d(in_channel), 132 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 133 | PReLU(depth), 134 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 135 | BatchNorm2d(depth), 136 | SEModule(depth, 16), 137 | ) 138 | 139 | def forward(self, x): 140 | shortcut = self.shortcut_layer(x) 141 | res = self.res_layer(x) 142 | return res + shortcut 143 | 144 | 145 | def _upsample_add(x, y): 146 | """Upsample and add two feature maps. 147 | Args: 148 | x: (Variable) top feature map to be upsampled. 149 | y: (Variable) lateral feature map. 150 | Returns: 151 | (Variable) added feature map. 152 | Note in PyTorch, when input size is odd, the upsampled feature map 153 | with `F.upsample(..., scale_factor=2, mode='nearest')` 154 | maybe not equal to the lateral feature map size. 155 | e.g. 
156 | original input size: [N,_,15,15] -> 157 | conv2d feature map size: [N,_,8,8] -> 158 | upsampled feature map size: [N,_,16,16] 159 | So we choose bilinear upsample which supports arbitrary output sizes. 160 | """ 161 | _, _, H, W = y.size() 162 | return F.interpolate(x, size=(H, W), mode="bilinear", align_corners=True) + y 163 | -------------------------------------------------------------------------------- /models/psp/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import ( 2 | Linear, 3 | Conv2d, 4 | BatchNorm1d, 5 | BatchNorm2d, 6 | PReLU, 7 | Dropout, 8 | Sequential, 9 | Module, 10 | ) 11 | from models.psp.encoders.helpers import ( 12 | get_blocks, 13 | Flatten, 14 | bottleneck_IR, 15 | bottleneck_IR_SE, 16 | l2_norm, 17 | ) 18 | 19 | """ 20 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 21 | """ 22 | 23 | 24 | class Backbone(Module): 25 | def __init__(self, input_size, num_layers, mode="ir", drop_ratio=0.4, affine=True): 26 | super(Backbone, self).__init__() 27 | assert input_size in [112, 224], "input_size should be 112 or 224" 28 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 29 | assert mode in ["ir", "ir_se"], "mode should be ir or ir_se" 30 | blocks = get_blocks(num_layers) 31 | if mode == "ir": 32 | unit_module = bottleneck_IR 33 | elif mode == "ir_se": 34 | unit_module = bottleneck_IR_SE 35 | self.input_layer = Sequential( 36 | Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64) 37 | ) 38 | if input_size == 112: 39 | self.output_layer = Sequential( 40 | BatchNorm2d(512), 41 | Dropout(drop_ratio), 42 | Flatten(), 43 | Linear(512 * 7 * 7, 512), 44 | BatchNorm1d(512, affine=affine), 45 | ) 46 | else: 47 | self.output_layer = Sequential( 48 | BatchNorm2d(512), 49 | Dropout(drop_ratio), 50 | Flatten(), 51 | Linear(512 * 14 * 14, 512), 52 | BatchNorm1d(512, affine=affine), 53 | ) 54 | 55 | modules = [] 56 | for block in blocks: 57 | for bottleneck in block: 58 | modules.append( 59 | unit_module( 60 | bottleneck.in_channel, bottleneck.depth, bottleneck.stride 61 | ) 62 | ) 63 | self.body = Sequential(*modules) 64 | 65 | def forward(self, x): 66 | x = self.input_layer(x) 67 | x = self.body(x) 68 | x = self.output_layer(x) 69 | return l2_norm(x) 70 | 71 | 72 | def IR_50(input_size): 73 | """Constructs a ir-50 model.""" 74 | model = Backbone(input_size, num_layers=50, mode="ir", drop_ratio=0.4, affine=False) 75 | return model 76 | 77 | 78 | def IR_101(input_size): 79 | """Constructs a ir-101 model.""" 80 | model = Backbone( 81 | input_size, num_layers=100, mode="ir", drop_ratio=0.4, affine=False 82 | ) 83 | return model 84 | 85 | 86 | def IR_152(input_size): 87 | """Constructs a ir-152 model.""" 88 | model = Backbone( 89 | input_size, num_layers=152, mode="ir", drop_ratio=0.4, affine=False 90 | ) 91 | return model 92 | 93 | 94 | def IR_SE_50(input_size): 95 | """Constructs a ir_se-50 model.""" 96 | model = Backbone( 97 | input_size, num_layers=50, mode="ir_se", drop_ratio=0.4, affine=False 98 | ) 99 | return model 100 | 101 | 102 | def IR_SE_101(input_size): 103 | """Constructs a ir_se-101 model.""" 104 | model = Backbone( 105 | input_size, num_layers=100, mode="ir_se", drop_ratio=0.4, affine=False 106 | ) 107 | return model 108 | 109 | 110 | def IR_SE_152(input_size): 111 | """Constructs a ir_se-152 model.""" 112 | model = Backbone( 113 | input_size, num_layers=152, mode="ir_se", drop_ratio=0.4, affine=False 114 
| ) 115 | return model 116 | -------------------------------------------------------------------------------- /models/psp/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/models/psp/stylegan2/__init__.py -------------------------------------------------------------------------------- /models/psp/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /models/psp/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | module_path = os.path.dirname(__file__) 8 | fused = load( 9 | "fused", 10 | sources=[ 11 | os.path.join(module_path, "fused_bias_act.cpp"), 12 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 13 | ], 14 | ) 15 | 16 | class FusedLeakyReLUFunctionBackward(Function): 17 | @staticmethod 18 | def forward(ctx, grad_output, out, negative_slope, scale): 19 | ctx.save_for_backward(out) 20 | ctx.negative_slope = negative_slope 21 | ctx.scale = scale 22 | 23 | empty = grad_output.new_empty(0) 24 | 25 | grad_input = fused.fused_bias_act( 26 | grad_output, empty, out, 3, 1, negative_slope, scale 27 | ) 28 | 29 | dim = [0] 30 | 31 | if grad_input.ndim > 2: 32 | dim += list(range(2, grad_input.ndim)) 33 | 34 | grad_bias = grad_input.sum(dim).detach() 35 | 36 | return grad_input, grad_bias 37 | 38 | @staticmethod 39 | def backward(ctx, gradgrad_input, gradgrad_bias): 40 | (out,) = ctx.saved_tensors 41 | gradgrad_out = fused.fused_bias_act( 42 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 43 | ) 44 | 45 | return gradgrad_out, None, None, None 46 | 47 | 48 | class FusedLeakyReLUFunction(Function): 49 | @staticmethod 50 | def forward(ctx, input, bias, negative_slope, scale): 51 | empty = input.new_empty(0) 52 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 53 | ctx.save_for_backward(out) 54 | ctx.negative_slope = negative_slope 55 | ctx.scale = scale 56 | 57 | return out 58 | 59 | @staticmethod 60 | def backward(ctx, grad_output): 61 | (out,) = ctx.saved_tensors 62 | 63 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 64 | grad_output, out, ctx.negative_slope, ctx.scale 65 | ) 66 | 67 | return grad_input, grad_bias, None, None 68 | 69 | 70 | class FusedLeakyReLU(nn.Module): 71 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 72 | super().__init__() 73 | 74 | self.bias = nn.Parameter(torch.zeros(channel)) 75 | self.negative_slope = negative_slope 76 | self.scale = scale 77 | 78 | def forward(self, input): 79 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 80 | 81 | 82 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 83 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 84 | -------------------------------------------------------------------------------- /models/psp/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const 
torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/psp/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/psp/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /models/psp/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | "upfirdn2d", 10 | sources=[ 11 | os.path.join(module_path, "upfirdn2d.cpp"), 12 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | up_x, up_y = up 23 | down_x, down_y = down 24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 25 | 26 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 27 | 28 | grad_input = upfirdn2d_op.upfirdn2d( 29 | grad_output, 30 | grad_kernel, 31 | down_x, 32 | down_y, 33 | up_x, 34 | up_y, 35 | g_pad_x0, 36 | g_pad_x1, 37 | g_pad_y0, 38 | g_pad_y1, 39 | ) 40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 41 | 42 | ctx.save_for_backward(kernel) 43 | 44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 45 | 46 | ctx.up_x = up_x 47 | ctx.up_y = up_y 48 | ctx.down_x = down_x 49 | ctx.down_y = down_y 50 | ctx.pad_x0 = pad_x0 51 | ctx.pad_x1 = pad_x1 52 | ctx.pad_y0 = pad_y0 53 | ctx.pad_y1 = pad_y1 54 | ctx.in_size = in_size 55 | ctx.out_size = out_size 56 | 57 | return grad_input 58 | 59 | @staticmethod 60 | def backward(ctx, gradgrad_input): 61 | (kernel,) = ctx.saved_tensors 62 | 63 | gradgrad_input = gradgrad_input.reshape(-1, 
ctx.in_size[2], ctx.in_size[3], 1) 64 | 65 | gradgrad_out = upfirdn2d_op.upfirdn2d( 66 | gradgrad_input, 67 | kernel, 68 | ctx.up_x, 69 | ctx.up_y, 70 | ctx.down_x, 71 | ctx.down_y, 72 | ctx.pad_x0, 73 | ctx.pad_x1, 74 | ctx.pad_y0, 75 | ctx.pad_y1, 76 | ) 77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 78 | gradgrad_out = gradgrad_out.view( 79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 80 | ) 81 | 82 | return gradgrad_out, None, None, None, None, None, None, None, None 83 | 84 | 85 | class UpFirDn2d(Function): 86 | @staticmethod 87 | def forward(ctx, input, kernel, up, down, pad): 88 | up_x, up_y = up 89 | down_x, down_y = down 90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 91 | 92 | kernel_h, kernel_w = kernel.shape 93 | batch, channel, in_h, in_w = input.shape 94 | ctx.in_size = input.shape 95 | 96 | input = input.reshape(-1, in_h, in_w, 1) 97 | 98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 99 | 100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 102 | ctx.out_size = (out_h, out_w) 103 | 104 | ctx.up = (up_x, up_y) 105 | ctx.down = (down_x, down_y) 106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 107 | 108 | g_pad_x0 = kernel_w - pad_x0 - 1 109 | g_pad_y0 = kernel_h - pad_y0 - 1 110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 112 | 113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 114 | 115 | out = upfirdn2d_op.upfirdn2d( 116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 117 | ) 118 | # out = out.view(major, out_h, out_w, minor) 119 | out = out.view(-1, channel, out_h, out_w) 120 | 121 | return out 122 | 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | kernel, grad_kernel = ctx.saved_tensors 126 | 127 | grad_input = UpFirDn2dBackward.apply( 128 | grad_output, 129 | kernel, 130 | grad_kernel, 131 | ctx.up, 132 | ctx.down, 133 | ctx.pad, 134 | ctx.g_pad, 135 | ctx.in_size, 136 | ctx.out_size, 137 | ) 138 | 139 | return grad_input, None, None, None, None 140 | 141 | 142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 143 | out = UpFirDn2d.apply( 144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 145 | ) 146 | 147 | return out 148 | 149 | 150 | def upfirdn2d_native( 151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 152 | ): 153 | _, in_h, in_w, minor = input.shape 154 | kernel_h, kernel_w = kernel.shape 155 | 156 | out = input.view(-1, in_h, 1, in_w, 1, minor) 157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 159 | 160 | out = F.pad( 161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 162 | ) 163 | out = out[ 164 | :, 165 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 166 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 167 | :, 168 | ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape( 172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 173 | ) 174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 175 | out = F.conv2d(out, w) 176 | out = out.reshape( 177 | -1, 178 | minor, 179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 181 | ) 182 | out = out.permute(0, 2, 3, 1) 183 | 184 | return out[:, 
::down_y, ::down_x, :] 185 | -------------------------------------------------------------------------------- /notebook/images/gosling.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/notebook/images/gosling.jpg -------------------------------------------------------------------------------- /notebook/images/robert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/notebook/images/robert.png -------------------------------------------------------------------------------- /notebook/images/robert_aligned_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/notebook/images/robert_aligned_mask.jpg -------------------------------------------------------------------------------- /notebook/images/scarlet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/notebook/images/scarlet.jpg -------------------------------------------------------------------------------- /notebook/images/smith.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/notebook/images/smith.jpg -------------------------------------------------------------------------------- /notebook/images/watson.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/notebook/images/watson.jpeg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gdown==4.7.1 2 | wandb==0.15.2 3 | omegaconf==2.1.2 4 | tqdm==4.66.3 5 | Pillow==10.3.0 6 | pytorch_fid==0.3.0 7 | piq==0.8.0 8 | fsspec==2024.3.1 9 | networkx==3.3 10 | einops==0.7.0 11 | scipy==1.10.1 12 | clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33 13 | pyfacer==0.0.4 14 | timm==1.0.3 15 | dlib==19.24.4 16 | pandas==2.2.2 17 | fpie -------------------------------------------------------------------------------- /runners/base_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import json 5 | import omegaconf 6 | import wandb 7 | import glob 8 | 9 | from pathlib import Path 10 | from editings.latent_editor import LatentEditor 11 | 12 | from models.methods import methods_registry 13 | from metrics.metrics import metrics_registry 14 | from utils.model_utils import get_stylespace_from_w 15 | 16 | 17 | class BaseRunner: 18 | def __init__(self, config): 19 | self.config = config 20 | self.method_config = config.methods_args[config.model.method] 21 | 22 | def setup(self): 23 | self._setup_device() 24 | self._setup_latent_editor() 25 | self._setup_method() 26 | 27 | def get_edited_latent(self, original_latent, editing_name, editing_degrees, original_image=None): 28 | if editing_name in 
self.latent_editor.stylespace_directions: 29 | stylespace_latent = get_stylespace_from_w(original_latent, self.method.decoder) 30 | edited_latents = ( 31 | self.latent_editor.get_stylespace_edits( 32 | stylespace_latent, editing_degrees, editing_name 33 | )) 34 | elif editing_name in self.latent_editor.interfacegan_directions: 35 | edited_latents = ( 36 | self.latent_editor.get_interface_gan_edits( 37 | original_latent, editing_degrees, editing_name 38 | )) 39 | 40 | elif editing_name in self.latent_editor.styleclip_directions: 41 | edited_latents = self.latent_editor.get_styleclip_mapper_edits( 42 | original_latent, editing_degrees, editing_name 43 | ) 44 | 45 | elif editing_name in self.latent_editor.ganspace_directions: 46 | edited_latents = ( 47 | self.latent_editor.get_ganspace_edits( 48 | original_latent, editing_degrees, editing_name 49 | ) 50 | ) 51 | elif editing_name in self.latent_editor.fs_directions.keys(): 52 | edited_latents = self.latent_editor.get_fs_edits( 53 | original_latent, editing_degrees, editing_name 54 | ) 55 | elif editing_name.startswith("styleclip_global_"): 56 | stylespace_latent = get_stylespace_from_w(original_latent, self.method.decoder) 57 | edited_latents = ( 58 | self.latent_editor.get_styleclip_global_edits( 59 | stylespace_latent, editing_degrees, editing_name.replace("styleclip_global_", "") 60 | )) 61 | elif editing_name.startswith("deltaedit_"): 62 | assert original_image is not None 63 | stylespace_latent = get_stylespace_from_w(original_latent, self.method.decoder) 64 | edited_latents = ( 65 | self.latent_editor.get_deltaedit_edits( 66 | stylespace_latent, editing_degrees, editing_name.replace("deltaedit_", ""), original_image 67 | )) 68 | else: 69 | raise ValueError(f'Edit name {editing_name} is not available') 70 | return edited_latents 71 | 72 | def _setup_latent_editor(self): 73 | self.latent_editor = LatentEditor(self.config.exp.domain) 74 | 75 | def _setup_device(self): 76 | config_device = self.config.model["device"].lower() 77 | 78 | if config_device == "cpu": 79 | device = "cpu" 80 | elif config_device.isdigit(): 81 | device = "cuda:{}".format(config_device) 82 | elif config_device.startswith("cuda:"): 83 | device = config_device 84 | else: 85 | raise ValueError("Incorrect Device Type") 86 | 87 | try: 88 | torch.randn(1).to(device) 89 | print("Device: {}".format(device)) 90 | except Exception as e: 91 | print("Could not use device {}, {}".format(device, e)) 92 | print("Set device to CPU") 93 | device = "cpu" 94 | 95 | self.device = torch.device(device) 96 | 97 | def _setup_method(self): 98 | method_name = self.config.model.method 99 | self.method = methods_registry[method_name]( 100 | checkpoint_path=self.config.model.checkpoint_path, 101 | **self.config.methods_args[method_name], 102 | ).to(self.device) 103 | -------------------------------------------------------------------------------- /scripts/calculate_metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import random 4 | 5 | sys.path = ['.'] + sys.path 6 | 7 | from argparse import ArgumentParser 8 | from pathlib import Path 9 | from metrics.metrics import metrics_registry 10 | from utils.common_utils import setup_seed 11 | 12 | 13 | setup_seed(777) 14 | 15 | 16 | def run(test_opts): 17 | metrics = [] 18 | for metric_name in test_opts.metrics: 19 | metrics.append( 20 | metrics_registry[metric_name]() 21 | ) 22 | 23 | out_path = None 24 | for metric in metrics: 25 | print("Calculating", metric.get_name()) 26 
| if test_opts.metrics_dir != "": 27 | out_path = Path(test_opts.metrics_dir) / metric.get_name() 28 | out_path = f"{out_path}.json" 29 | _, value, _ = metric( 30 | test_opts.orig_path, 31 | test_opts.reconstr_path, 32 | out_path=str(out_path) if out_path else None, 33 | ) 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = ArgumentParser() 38 | parser.add_argument("--metrics", nargs="+", help="List of calculated metrics") 39 | parser.add_argument( 40 | "--orig_path", type=str, help="Path to directory of original images to evaluate" 41 | ) 42 | parser.add_argument( 43 | "--reconstr_path", 44 | type=str, 45 | help="Path to directory of reconstructions of images to evaluate", 46 | ) 47 | parser.add_argument( 48 | "--batch_size", default=4, type=int, help="Batch size for testing and inference" 49 | ) 50 | parser.add_argument( 51 | "--workers", 52 | default=4, 53 | type=int, 54 | help="Number of test/inference dataloader workers", 55 | ) 56 | parser.add_argument( 57 | "--metrics_dir", 58 | default="", 59 | type=str, 60 | help="Directory to save .json metrics info", 61 | ) 62 | 63 | test_opts = parser.parse_args() 64 | run(test_opts) 65 | -------------------------------------------------------------------------------- /scripts/fid_calculation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | 4 | sys.path = ['.'] + sys.path 5 | 6 | from argparse import ArgumentParser 7 | from pathlib import Path 8 | from metrics.metrics import metrics_registry 9 | from datasets.transforms import transforms_registry 10 | from datasets.datasets import CelebaAttributeDataset 11 | from utils.common_utils import tensor2im, setup_seed 12 | 13 | 14 | setup_seed(777) 15 | 16 | 17 | def inferece_fid_editing(opts): 18 | fid_metric = metrics_registry["fid"]() 19 | transform = transforms_registry[opts.transforms]().get_transforms()["test"] 20 | 21 | attr_name = opts.attr_name 22 | 23 | attr_dataset = CelebaAttributeDataset( 24 | opts.orig_path, 25 | attr_name, 26 | transform, 27 | opts.celeba_attr_table_pth, 28 | use_attr=not opts.attr_is_reversed 29 | ) 30 | 31 | not_attr_dataset = CelebaAttributeDataset( 32 | opts.synt_path, 33 | attr_name, 34 | transform, 35 | opts.celeba_attr_table_pth, 36 | use_attr=opts.attr_is_reversed 37 | ) 38 | 39 | print(f"Percent of Images of attribute {opts.attr_name} is " 40 | f"{len(attr_dataset) / (len(attr_dataset) + len(not_attr_dataset))}") 41 | 42 | attr_images = [] 43 | for attr_image in attr_dataset: 44 | img = tensor2im(attr_image).convert("RGB") 45 | attr_images.append(img) 46 | 47 | edited_images = [] 48 | for not_attr_image in not_attr_dataset: 49 | img = tensor2im(not_attr_image).convert("RGB") 50 | edited_images.append(img) 51 | 52 | from_data_arg = { 53 | "inp_data": attr_images, 54 | "fake_data": edited_images, 55 | "paths": [], 56 | } 57 | 58 | _, fid_value, _ = fid_metric("", "", out_path="", from_data=from_data_arg) 59 | print(f"FID for {opts.attr_name} is {fid_value:.4f}") 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = ArgumentParser() 64 | parser.add_argument( 65 | "--orig_path", type=str, help="Path to directory of original Celeba images " 66 | ) 67 | parser.add_argument( 68 | "--synt_path", 69 | type=str, 70 | help="Path to synthesized edited images", 71 | ) 72 | parser.add_argument( 73 | "--attr_name", 74 | type=str, 75 | help="Name of Celeba attribute that is added during editing.", 76 | ) 77 | parser.add_argument( 78 | "--attr_is_reversed", 79 | action='store_true', 80 | help="Means 
that attribute was not added but removed during editing", 81 | ) 82 | parser.add_argument( 83 | "--celeba_attr_table_pth", 84 | default="CelebAMask-HQ-attribute-anno.txt", 85 | type=str, 86 | help="Path to celeba attributes .txt", 87 | ) 88 | parser.add_argument( 89 | "--transforms", 90 | default="face_1024", 91 | type=str, 92 | help="Which transforms from datasets.transforms.transforms_registry should be used", 93 | ) 94 | 95 | opts = parser.parse_args() 96 | inferece_fid_editing(opts) 97 | -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | 4 | sys.path = ['.'] + sys.path 5 | 6 | from arguments import inference_arguments 7 | from runners.inference_runners import inference_runner_registry 8 | from utils.common_utils import printer, setup_seed 9 | 10 | 11 | def run_inference(config): 12 | inference_runner = inference_runner_registry[config.inference.inference_runner]( 13 | config 14 | ) 15 | inference_runner.setup() 16 | inference_runner.run() 17 | 18 | 19 | if __name__ == "__main__": 20 | config = inference_arguments.load_config() 21 | setup_seed(config.exp.seed) 22 | 23 | printer(config) 24 | 25 | run_inference(config) 26 | -------------------------------------------------------------------------------- /scripts/simple_inference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import random 4 | 5 | sys.path = ['.'] + sys.path 6 | 7 | from argparse import ArgumentParser 8 | from utils.common_utils import setup_seed 9 | from runners.simple_runner import SimpleRunner 10 | 11 | 12 | setup_seed(777) 13 | 14 | 15 | def run(opts): 16 | runner = SimpleRunner( 17 | editor_ckpt_pth="pretrained_models/sfe_editor_light.pt", 18 | ) 19 | 20 | runner.edit( 21 | orig_img_pth=opts.orig_img_pth, 22 | editing_name=opts.editing_name, 23 | edited_power=opts.edited_power, 24 | save_pth=opts.save_pth, 25 | align=opts.align, 26 | use_mask=opts.use_mask, 27 | mask_trashold=opts.mask_trashold, 28 | mask_path=opts.mask_path 29 | ) 30 | runner.available_editings() 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = ArgumentParser() 35 | parser.add_argument( 36 | "--orig_img_pth", type=str, help="Path to original image" 37 | ) 38 | parser.add_argument( 39 | "--editing_name", 40 | type=str, 41 | help="Name of desired editing", 42 | ) 43 | parser.add_argument( 44 | "--edited_power", 45 | type=float, 46 | help="Power of desired editing, float", 47 | ) 48 | parser.add_argument( 49 | "--save_pth", 50 | type=str, 51 | help="Path where to save edited (and aligned) image", 52 | ) 53 | parser.add_argument( 54 | "--align", 55 | action='store_true', 56 | help="Flag to align image if it was not", 57 | ) 58 | parser.add_argument( 59 | "--use_mask", 60 | action='store_true', 61 | help="Flag to edit only masked zone. 
May be useful to remove background artefacts.", 62 | ) 63 | parser.add_argument( 64 | "--mask_trashold", 65 | type=float, 66 | default=0.095, 67 | help="Threshold in range (0, 1) to separate face from background", 68 | ) 69 | parser.add_argument( 70 | "--mask_path", 71 | type=str, 72 | default=None, 73 | help="Path to custom background mask", 74 | ) 75 | 76 | opts = parser.parse_args() 77 | run(opts) 78 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | 4 | sys.path = ['.'] + sys.path 5 | 6 | import torch 7 | from arguments import training_arguments 8 | from runners.training_runners import training_runners 9 | from utils.common_utils import printer, setup_seed 10 | 11 | 12 | if __name__ == "__main__": 13 | config = training_arguments.load_config() 14 | setup_seed(config.exp.seed) 15 | 16 | printer(config) 17 | 18 | trainer = training_runners[config.train.train_runner](config) 19 | trainer.setup() 20 | trainer.run() 21 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/training/__init__.py -------------------------------------------------------------------------------- /training/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch import nn 5 | from criteria import id_loss, moco_loss, id_vit_loss 6 | from criteria.lpips.lpips import LPIPS 7 | from utils.class_registry import ClassRegistry 8 | from configs.paths import DefaultPaths 9 | 10 | 11 | losses = ClassRegistry() 12 | adv_losses = ClassRegistry() 13 | disc_losses = ClassRegistry() 14 | other_losses = ClassRegistry() 15 | 16 | 17 | class LossBuilder: 18 | def __init__(self, enc_losses_dict, disc_losses_dict, device): 19 | self.coefs_dict = enc_losses_dict 20 | self.losses_names = [k for k, v in enc_losses_dict.items() if v > 0] 21 | self.losses = {} 22 | self.adv_losses = {} 23 | self.other_losses = {} 24 | self.device = device 25 | 26 | for loss in self.losses_names: 27 | if loss in losses.classes.keys(): 28 | self.losses[loss] = losses[loss]().to(self.device).eval() 29 | elif loss in adv_losses.classes.keys(): 30 | self.adv_losses[loss] = adv_losses[loss]() 31 | elif loss in other_losses.classes.keys(): 32 | self.other_losses[loss] = other_losses[loss]() 33 | else: 34 | raise ValueError(f'Unexpected loss: {loss}') 35 | 36 | self.disc_losses = [] 37 | for loss_name, loss_args in disc_losses_dict.items(): 38 | if loss_args.coef > 0: 39 | self.disc_losses.append(disc_losses[loss_name](**loss_args)) 40 | 41 | 42 | def encoder_loss(self, batch_data): 43 | loss_dict = {} 44 | global_loss = 0.0 45 | 46 | for loss_name, loss in self.losses.items(): 47 | loss_val = loss(batch_data["y_hat"], batch_data["x"]) 48 | global_loss += self.coefs_dict[loss_name] * loss_val 49 | loss_dict[loss_name] = float(loss_val) 50 | 51 | for loss_name, loss in self.other_losses.items(): 52 | loss_val = loss(batch_data) 53 | assert torch.isfinite(loss_val) 54 | global_loss += self.coefs_dict[loss_name] * loss_val 55 | loss_dict[loss_name] = float(loss_val) 56 | 57 | if batch_data["use_adv_loss"]: 58 | for loss_name, loss in self.adv_losses.items(): 
59 | loss_val = loss(batch_data["fake_preds"]) 60 | global_loss += self.coefs_dict[loss_name] * loss_val 61 | loss_dict[loss_name] = float(loss_val) 62 | 63 | return global_loss, loss_dict 64 | 65 | def disc_loss(self, D, batch_data): 66 | disc_losses = {} 67 | total_disc_loss = torch.tensor([0.], device=self.device) 68 | 69 | for loss in self.disc_losses: 70 | disc_loss, disc_loss_dict = loss(D, batch_data) 71 | 72 | total_disc_loss += disc_loss 73 | disc_losses.update(disc_loss_dict) 74 | 75 | return total_disc_loss, disc_losses 76 | 77 | 78 | 79 | @losses.add_to_registry(name="l2") 80 | class L2Loss(nn.MSELoss): 81 | pass 82 | 83 | 84 | @losses.add_to_registry(name="lpips") 85 | class LPIPSLoss(LPIPS): 86 | pass 87 | 88 | 89 | @losses.add_to_registry(name="lpips_scale") 90 | class LPIPSScaleLoss(nn.Module): 91 | def __init__(self): 92 | super().__init__() 93 | self.loss_fn = LPIPSLoss() 94 | 95 | def forward(self, x, y): 96 | out = 0 97 | for res in [256, 128, 64]: 98 | x_scale = F.interpolate(x, size=(res, res), mode="bilinear", align_corners=False) 99 | y_scale = F.interpolate(y, size=(res, res), mode="bilinear", align_corners=False) 100 | out += self.loss_fn.forward(x_scale, y_scale).mean() 101 | return out 102 | 103 | 104 | @other_losses.add_to_registry(name="feat_rec") 105 | class FeatReconLoss(nn.Module): 106 | def __init__(self): 107 | super().__init__() 108 | self.loss_fn = nn.MSELoss() 109 | 110 | def forward(self, batch): 111 | return self.loss_fn(batch["feat_recon"], batch["feat_real"]).mean() 112 | 113 | 114 | @other_losses.add_to_registry(name="feat_rec_l1") 115 | class FeatReconL1Loss(nn.Module): 116 | def __init__(self): 117 | super().__init__() 118 | self.loss_fn = nn.L1Loss() 119 | 120 | def forward(self, batch): 121 | return self.loss_fn(batch["feat_recon"], batch["feat_real"]).mean() 122 | 123 | 124 | 125 | @other_losses.add_to_registry(name="l2_latent") 126 | class LatentMSELoss(nn.Module): 127 | def __init__(self): 128 | super().__init__() 129 | self.loss_fn = nn.MSELoss() 130 | 131 | def forward(self, batch): 132 | return self.loss_fn(batch["latent"], batch["latent_rec"]).mean() 133 | 134 | 135 | 136 | @losses.add_to_registry(name="id") 137 | class IDLoss(id_loss.IDLoss): 138 | pass 139 | 140 | 141 | @losses.add_to_registry(name="id_vit") 142 | class IDVitLoss(id_vit_loss.IDVitLoss): 143 | pass 144 | 145 | 146 | @losses.add_to_registry(name="moco") 147 | class MocoLoss(moco_loss.MocoLoss): 148 | pass 149 | 150 | 151 | @adv_losses.add_to_registry(name="adv") 152 | class EncoderAdvLoss: 153 | def __call__(self, fake_preds): 154 | loss_G_adv = F.softplus(-fake_preds).mean() 155 | return loss_G_adv 156 | 157 | 158 | @disc_losses.add_to_registry(name="main") 159 | class AdvLoss: 160 | def __init__(self, coef=0.0): 161 | self.coef = coef 162 | 163 | def __call__(self, disc, loss_input): 164 | real_images = loss_input["x"].detach() 165 | generated_images = loss_input["y_hat"].detach() 166 | loss_dict = {} 167 | 168 | fake_preds = disc(generated_images, None) 169 | real_preds = disc(real_images, None) 170 | loss = self.d_logistic_loss(real_preds, fake_preds) 171 | loss_dict["disc/main_loss"] = float(loss) 172 | 173 | return loss, loss_dict 174 | 175 | def d_logistic_loss(self, real_preds, fake_preds): 176 | real_loss = F.softplus(-real_preds) 177 | fake_loss = F.softplus(fake_preds) 178 | 179 | return (real_loss.mean() + fake_loss.mean()) / 2 180 | 181 | 182 | @disc_losses.add_to_registry(name="r1") 183 | class R1Loss: 184 | def __init__(self, coef=0.0, 
hyper_d_reg_every=16): 185 | self.coef = coef 186 | self.hyper_d_reg_every = hyper_d_reg_every 187 | 188 | def __call__(self, disc, loss_input): 189 | real_images = loss_input["x"] 190 | step = loss_input["step"] 191 | if step % self.hyper_d_reg_every != 0: # use r1 only once per 'hyper_d_reg_every' steps 192 | return torch.tensor([0.], requires_grad=True, device='cuda'), {} 193 | 194 | real_images.requires_grad = True 195 | loss_dict = {} 196 | 197 | real_preds = disc(real_images, None) 198 | real_preds = real_preds.view(real_images.size(0), -1) 199 | real_preds = real_preds.mean(dim=1).unsqueeze(1) 200 | r1_loss = self.d_r1_loss(real_preds, real_images) 201 | 202 | loss_D_R1 = self.coef / 2 * r1_loss * self.hyper_d_reg_every + 0 * real_preds[0] 203 | loss_dict["disc/r1_reg"] = float(loss_D_R1) 204 | return loss_D_R1, loss_dict 205 | 206 | def d_r1_loss(self, real_pred, real_img): 207 | (grad_real,) = torch.autograd.grad( 208 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 209 | ) 210 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 211 | 212 | return grad_penalty 213 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/utils/__init__.py -------------------------------------------------------------------------------- /utils/class_registry.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing 3 | import omegaconf 4 | import dataclasses 5 | 6 | 7 | class ClassRegistry: 8 | def __init__(self): 9 | self.classes = dict() 10 | self.args = dict() 11 | self.arg_keys = None 12 | 13 | def __getitem__(self, item): 14 | return self.classes[item] 15 | 16 | def make_dataclass_from_init(self, func, name, arg_keys, stop_args): 17 | args = inspect.signature(func).parameters 18 | args = [ 19 | (k, typing.Any, omegaconf.MISSING) 20 | if v.default is inspect.Parameter.empty 21 | else (k, typing.Optional[typing.Any], None) 22 | if v.default is None 23 | else ( 24 | k, 25 | type(v.default), 26 | dataclasses.field(default=v.default), 27 | ) 28 | for k, v in args.items() 29 | ] 30 | args = [arg for arg in args if arg[0] not in stop_args] 31 | if arg_keys: 32 | self.arg_keys = arg_keys 33 | arg_classes = dict() 34 | for key in arg_keys: 35 | arg_classes[key] = dataclasses.make_dataclass(key, args) 36 | return dataclasses.make_dataclass( 37 | name, 38 | [ 39 | (k, v, dataclasses.field(default=v())) 40 | for k, v in arg_classes.items() 41 | ], 42 | ) 43 | return dataclasses.make_dataclass(name, args) 44 | 45 | def make_dataclass_from_classes(self, name): 46 | return dataclasses.make_dataclass( 47 | name, 48 | [(k, v, dataclasses.field(default=v())) for k, v in self.classes.items()], 49 | ) 50 | 51 | def make_dataclass_from_args(self, name): 52 | return dataclasses.make_dataclass( 53 | name, 54 | [(k, v, dataclasses.field(default=v())) for k, v in self.args.items()], 55 | ) 56 | 57 | def add_to_registry( 58 | self, name, arg_keys=None, stop_args=("self", "args", "kwargs") 59 | ): 60 | def add_class_by_name(cls): 61 | self.classes[name] = cls 62 | self.args[name] = self.make_dataclass_from_init( 63 | cls.__init__, name, arg_keys, stop_args 64 | ) 65 | return cls 66 | 67 | return add_class_by_name 68 | 
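Note: ClassRegistry above is the glue between config keys and concrete classes throughout the codebase — training/losses.py registers every loss under a name, and runners/base_runner.py resolves the chosen method by name before instantiating it. The snippet below is a minimal, hypothetical usage sketch (the registry `example_registry` and the `DummyLoss` class are illustrative only and not part of the repository):

from utils.class_registry import ClassRegistry

example_registry = ClassRegistry()

@example_registry.add_to_registry(name="dummy")
class DummyLoss:
    def __init__(self, coef: float = 1.0, mode: str = "mean"):
        self.coef = coef
        self.mode = mode

# Registered classes are looked up by name, typically driven by a config key.
loss = example_registry["dummy"](coef=0.5)

# add_to_registry also records a dataclass derived from the __init__ defaults;
# make_dataclass_from_args / make_dataclass_from_classes later assemble these
# per-class dataclasses into the experiment config schema.
defaults = example_registry.args["dummy"]()  # dummy(coef=1.0, mode='mean')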
-------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torch.nn import functional as F 4 | from PIL import Image 5 | 6 | 7 | class AlignerCantFindFaceError(Exception): 8 | pass 9 | 10 | class MaskerCantFindFaceError(Exception): 11 | pass 12 | 13 | 14 | def tensor2im(var): 15 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 16 | var = (var + 1) / 2 17 | var[var < 0] = 0 18 | var[var > 1] = 1 19 | var = var * 255 20 | return Image.fromarray(var.astype("uint8")) 21 | 22 | 23 | def tensor2im_no_tfm(var): 24 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 25 | var = var * 255 26 | return Image.fromarray(var.astype("uint8")) 27 | 28 | 29 | def printer(obj, tabs=0): 30 | for (key, value) in obj.items(): 31 | try: 32 | _ = value.items() 33 | print(" " * tabs + str(key) + ":") 34 | printer(value, tabs + 4) 35 | except: 36 | print(f" " * tabs + str(key) + " : " + str(value)) 37 | 38 | 39 | def get_keys(d, name, key="state_dict"): 40 | if key in d: 41 | d = d[key] 42 | d_filt = {k[len(name) + 1 :]: v for k, v in d.items() if k[: len(name) + 1] == name + '.'} 43 | return d_filt 44 | 45 | 46 | def setup_seed(seed): 47 | random.seed(seed) 48 | torch.random.manual_seed(seed) 49 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adopted from pix2pixHD: 3 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py 4 | """ 5 | import os 6 | 7 | IMG_EXTENSIONS = [ 8 | ".jpg", 9 | ".JPG", 10 | ".jpeg", 11 | ".JPEG", 12 | ".png", 13 | ".PNG", 14 | ".ppm", 15 | ".PPM", 16 | ".bmp", 17 | ".BMP", 18 | ".tiff", 19 | ] 20 | 21 | 22 | def is_image_file(filename): 23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 24 | 25 | 26 | def make_dataset(dir): 27 | images = [] 28 | assert os.path.isdir(dir), "%s is not a valid directory" % dir 29 | for root, _, fnames in sorted(os.walk(dir)): 30 | for fname in fnames: 31 | if is_image_file(fname): 32 | path = os.path.join(root, fname) 33 | images.append(path) 34 | return images 35 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | RESNET_MAPPING = { 2 | "layer1.0": "body.0", 3 | "layer1.1": "body.1", 4 | "layer1.2": "body.2", 5 | "layer2.0": "body.3", 6 | "layer2.1": "body.4", 7 | "layer2.2": "body.5", 8 | "layer2.3": "body.6", 9 | "layer3.0": "body.7", 10 | "layer3.1": "body.8", 11 | "layer3.2": "body.9", 12 | "layer3.3": "body.10", 13 | "layer3.4": "body.11", 14 | "layer3.5": "body.12", 15 | "layer4.0": "body.13", 16 | "layer4.1": "body.14", 17 | "layer4.2": "body.15", 18 | } 19 | 20 | 21 | def count_parameters(model): 22 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 23 | 24 | 25 | def toogle_grad(model, flag=True): 26 | for p in model.parameters(): 27 | p.requires_grad = flag 28 | 29 | 30 | def stylegan_to_classifier(x, out_size=(224, 224)): 31 | """Clip image to range(0,1)""" 32 | img_tmp = x.clone() 33 | img_tmp = torch.clamp((0.5*img_tmp + 0.5), 0, 1) 34 | img_tmp = F.interpolate(img_tmp, size=out_size, mode='bilinear') 35 | img_tmp[:,0] = (img_tmp[:,0] - 0.485)/0.229 36 | img_tmp[:,1] = (img_tmp[:,1] - 
0.456)/0.224 37 | img_tmp[:,2] = (img_tmp[:,2] - 0.406)/0.225 38 | return img_tmp 39 | 40 | 41 | def get_stylespace_from_w(w, G): 42 | style_space = [] 43 | to_rgb_stylespaces = [] 44 | 45 | noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)] 46 | latent = w 47 | style_space.append(G.conv1.conv.modulation(latent[:, 0])) 48 | to_rgb_stylespaces.append(G.to_rgb1.conv.modulation(latent[:, 1])) 49 | 50 | i = 1 51 | for conv1, conv2, noise1, noise2, to_rgb in zip( 52 | G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs 53 | ): 54 | style_space.append(conv1.conv.modulation(latent[:, i])) 55 | style_space.append(conv2.conv.modulation(latent[:, i + 1])) 56 | to_rgb_stylespaces.append(to_rgb.conv.modulation(latent[:, i + 2])) 57 | i += 2 58 | return style_space, to_rgb_stylespaces 59 | 60 | 61 | def get_stylespace_from_w_hyperinv(w, G): 62 | with torch.no_grad(): 63 | style_space = [] 64 | to_rgb_stylespaces = [] 65 | G = G.synthesis 66 | 67 | block_ws = [] 68 | w_idx = 0 69 | for res in G.block_resolutions: 70 | block = getattr(G, f"b{res}") 71 | block_ws.append(w.narrow(1, w_idx, block.num_conv + block.num_torgb)) 72 | w_idx += block.num_conv 73 | 74 | i = 0 75 | for res, cur_ws in zip(G.block_resolutions, block_ws): 76 | block = getattr(G, f"b{res}") 77 | if i != 0: 78 | style_space.append(block.conv0.affine(w[:, i])) 79 | i += 1 80 | style_space.append(block.conv1.affine(w[:, i])) 81 | i += 1 82 | to_rgb_stylespaces.append(block.torgb.affine(w[:, i])) 83 | 84 | return style_space, to_rgb_stylespaces 85 | -------------------------------------------------------------------------------- /utils/torch_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/utils/torch_utils/__init__.py -------------------------------------------------------------------------------- /utils/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import shutil 14 | from pathlib import Path 15 | 16 | import torch 17 | import torch.utils.cpp_extension 18 | from torch.utils.file_baton import FileBaton 19 | 20 | 21 | # ---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = "brief" # Verbosity level: 'none', 'brief', 'full' 25 | 26 | # ---------------------------------------------------------------------------- 27 | # Internal helper funcs. 
28 | 29 | 30 | def _find_compiler_bindir(): 31 | patterns = [ 32 | "C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64", 33 | "C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64", 34 | "C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64", 35 | "C:/Program Files (x86)/Microsoft Visual Studio */vc/bin", 36 | ] 37 | for pattern in patterns: 38 | matches = sorted(glob.glob(pattern)) 39 | if len(matches): 40 | return matches[-1] 41 | return None 42 | 43 | 44 | # ---------------------------------------------------------------------------- 45 | # Main entry point for compiling and loading C++/CUDA plugins. 46 | 47 | _cached_plugins = dict() 48 | 49 | 50 | def get_plugin(module_name, sources, **build_kwargs): 51 | assert verbosity in ["none", "brief", "full"] 52 | 53 | # Already cached? 54 | if module_name in _cached_plugins: 55 | return _cached_plugins[module_name] 56 | 57 | # Print status. 58 | if verbosity == "full": 59 | print(f'Setting up PyTorch plugin "{module_name}"...') 60 | elif verbosity == "brief": 61 | print(f'Setting up PyTorch plugin "{module_name}"... ', end="", flush=True) 62 | 63 | try: # pylint: disable=too-many-nested-blocks 64 | # Make sure we can find the necessary compiler binaries. 65 | if os.name == "nt" and os.system("where cl.exe >nul 2>nul") != 0: 66 | compiler_bindir = _find_compiler_bindir() 67 | if compiler_bindir is None: 68 | raise RuntimeError( 69 | f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".' 70 | ) 71 | os.environ["PATH"] += ";" + compiler_bindir 72 | 73 | # Compile and load. 74 | verbose_build = verbosity == "full" 75 | 76 | # Incremental build md5sum trickery. Copies all the input source files 77 | # into a cached build directory under a combined md5 digest of the input 78 | # source files. Copying is done only if the combined digest has changed. 79 | # This keeps input file timestamps and filenames the same as in previous 80 | # extension builds, allowing for fast incremental rebuilds. 81 | # 82 | # This optimization is done only in case all the source files reside in 83 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 84 | # environment variable is set (we take this as a signal that the user 85 | # actually cares about this.) 86 | source_dirs_set = set(os.path.dirname(source) for source in sources) 87 | if len(source_dirs_set) == 1 and ("TORCH_EXTENSIONS_DIR" in os.environ): 88 | all_source_files = sorted( 89 | list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()) 90 | ) 91 | 92 | # Compute a combined hash digest for all source files in the same 93 | # custom op directory (usually .cu, .cpp, .py and .h files). 
94 | hash_md5 = hashlib.md5() 95 | for src in all_source_files: 96 | with open(src, "rb") as f: 97 | hash_md5.update(f.read()) 98 | build_dir = torch.utils.cpp_extension._get_build_directory( 99 | module_name, verbose=verbose_build 100 | ) # pylint: disable=protected-access 101 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 102 | 103 | if not os.path.isdir(digest_build_dir): 104 | os.makedirs(digest_build_dir, exist_ok=True) 105 | baton = FileBaton(os.path.join(digest_build_dir, "lock")) 106 | if baton.try_acquire(): 107 | try: 108 | for src in all_source_files: 109 | shutil.copyfile( 110 | src, 111 | os.path.join(digest_build_dir, os.path.basename(src)), 112 | ) 113 | finally: 114 | baton.release() 115 | else: 116 | # Someone else is copying source files under the digest dir, 117 | # wait until done and continue. 118 | baton.wait() 119 | digest_sources = [ 120 | os.path.join(digest_build_dir, os.path.basename(x)) for x in sources 121 | ] 122 | torch.utils.cpp_extension.load( 123 | name=module_name, 124 | build_directory=build_dir, 125 | verbose=verbose_build, 126 | sources=digest_sources, 127 | **build_kwargs, 128 | ) 129 | else: 130 | torch.utils.cpp_extension.load( 131 | name=module_name, verbose=verbose_build, sources=sources, **build_kwargs 132 | ) 133 | module = importlib.import_module(module_name) 134 | 135 | except Exception: 136 | if verbosity == "brief": 137 | print("Failed!") 138 | raise 139 | 140 | # Print status and add to cache. 141 | if verbosity == "full": 142 | print(f'Done setting up PyTorch plugin "{module_name}".') 143 | elif verbosity == "brief": 144 | print("Done.") 145 | _cached_plugins[module_name] = module 146 | return module 147 | 148 | 149 | # ---------------------------------------------------------------------------- 150 | -------------------------------------------------------------------------------- /utils/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/StyleFeatureEditor/780e2ae2fd90e4872ea175997603d5b1141587ce/utils/torch_utils/ops/__init__.py -------------------------------------------------------------------------------- /utils/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
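// Summary (inferred from the kernels in bias_act.cu, not authoritative): this op
// adds a bias broadcast along `dim`, applies the activation selected by `act`,
// multiplies by `gain`, and optionally clamps the result to [-clamp, clamp].
// `grad` selects the forward pass (0) or the first/second derivative passes
// (1/2), which reuse the reference tensors xref/yref saved during the forward pass.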
8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 
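// Each thread processes up to p.loopX elements spaced blockDim.x apart, so a
// grid of gridSize blocks of blockSize threads covers all p.sizeX elements.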
84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } -------------------------------------------------------------------------------- /utils/torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? 
one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ -------------------------------------------------------------------------------- /utils/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ -------------------------------------------------------------------------------- /utils/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | 14 | # ---------------------------------------------------------------------------- 15 | 16 | 17 | def fma(a, b, c): # => a * b + c 18 | return _FusedMultiplyAdd.apply(a, b, c) 19 | 20 | 21 | # ---------------------------------------------------------------------------- 22 | 23 | 24 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 25 | @staticmethod 26 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 27 | out = torch.addcmul(c, a, b) 28 | ctx.save_for_backward(a, b) 29 | ctx.c_shape = c.shape 30 | return out 31 | 32 | @staticmethod 33 | def backward(ctx, dout): # pylint: disable=arguments-differ 34 | a, b = ctx.saved_tensors 35 | c_shape = ctx.c_shape 36 | da = None 37 | db = None 38 | dc = None 39 | 40 | if ctx.needs_input_grad[0]: 41 | da = _unbroadcast(dout * b, a.shape) 42 | 43 | if ctx.needs_input_grad[1]: 44 | db = _unbroadcast(dout * a, b.shape) 45 | 46 | if ctx.needs_input_grad[2]: 47 | dc = _unbroadcast(dout, c_shape) 48 | 49 | return da, db, dc 50 | 51 | 52 | # ---------------------------------------------------------------------------- 53 | 54 | 55 | def _unbroadcast(x, shape): 56 | extra_dims = x.ndim - len(shape) 57 | assert extra_dims >= 0 58 | dim = [ 59 | i 60 | for i in range(x.ndim) 61 | if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1) 62 | ] 63 | if len(dim): 64 | x = x.sum(dim=dim, keepdim=True) 65 | if extra_dims: 66 | x = x.reshape(-1, *x.shape[extra_dims + 1 :]) 67 | assert x.shape == shape 68 | return x 69 | 70 | 71 | # ---------------------------------------------------------------------------- 72 | -------------------------------------------------------------------------------- /utils/torch_utils/ops/upfirdn2d.cpp: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel(p); 62 | }); 63 | 64 | // Set looping options. 
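// The major (batch/channel) dimension is split into at most 16384 grid-z blocks;
// each block then iterates loopMajor times to cover the remainder, while the
// loopMinor/loopX tiling factors come from the kernel specialization chosen above.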
65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ -------------------------------------------------------------------------------- /utils/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ --------------------------------------------------------------------------------
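For readers who want the operation in one place: the following pure-PyTorch sketch mirrors the upfirdn2d semantics implemented by the CUDA sources above (upsample by zero insertion, zero-pad, depthwise FIR filter, downsample). The function name upfirdn2d_ref, its simplified signature, and the symmetric pad handling are illustrative only; the real kernels additionally support negative padding, separable filters, and per-axis up/down factors.

import torch
import torch.nn.functional as F

def upfirdn2d_ref(x, f, up=1, down=1, pad=(0, 0), gain=1.0, flip=False):
    """Readable reference for upfirdn2d: x is [N, C, H, W], f is a 2D float filter."""
    n, c, h, w = x.shape
    # 1) Upsample by inserting zeros between samples.
    x = x.reshape(n, c, h, 1, w, 1)
    x = F.pad(x, [0, up - 1, 0, 0, 0, up - 1])
    x = x.reshape(n, c, h * up, w * up)
    # 2) Zero-pad (pad[0] on the leading edge, pad[1] on the trailing edge).
    x = F.pad(x, [pad[0], pad[1], pad[0], pad[1]])
    # 3) Depthwise convolution with the (optionally flipped) FIR filter.
    f = (f * gain).to(x.dtype)
    if not flip:
        f = f.flip([0, 1])
    weight = f[None, None].repeat(c, 1, 1, 1)
    x = F.conv2d(x, weight, groups=c)
    # 4) Downsample by discarding pixels.
    return x[:, :, ::down, ::down]

# Example: 2x upsampling of a random batch with a 3x3 box filter.
y = upfirdn2d_ref(torch.randn(1, 3, 8, 8), torch.ones(3, 3) / 9.0, up=2, pad=(1, 1))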